huggingface_hub-0.31.1/000077500000000000000000000000001500667546600147205ustar00rootroot00000000000000huggingface_hub-0.31.1/.gitignore000066400000000000000000000035501500667546600167130ustar00rootroot00000000000000# Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv .venv* env/ venv/ ENV/ env.bak/ venv.bak/ .venv* # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ .vscode/ .idea/ .DS_Store # Ruff .ruff_cache # Spell checker config cspell.json tmp*huggingface_hub-0.31.1/.pre-commit-config.yaml000066400000000000000000000010471500667546600212030ustar00rootroot00000000000000repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - id: check-yaml exclude: .github/conda/meta.yaml|tests/cassettes/ - id: end-of-file-fixer - id: trailing-whitespace - id: check-case-conflict - id: check-merge-conflict - repo: https://github.com/charliermarsh/ruff-pre-commit # https://github.com/charliermarsh/ruff#usage rev: v0.1.13 hooks: - id: ruff - repo: https://github.com/pappasam/toml-sort rev: v0.23.1 hooks: - id: toml-sort-fix huggingface_hub-0.31.1/CODE_OF_CONDUCT.md000066400000000000000000000121521500667546600175200ustar00rootroot00000000000000 # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 
## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at feedback@huggingface.co. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. 
No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. huggingface_hub-0.31.1/CONTRIBUTING.md000066400000000000000000000266461500667546600171670ustar00rootroot00000000000000 # How to contribute to huggingface_hub, the GitHub repository? Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out and improving the documentations are immensely valuable to the community. It also helps us if you spread the word: reference the library from blog posts on the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply star the repo to say "thank you". Whichever way you choose to contribute, please be mindful to respect our [code of conduct](https://github.com/huggingface/huggingface_hub/blob/main/CODE_OF_CONDUCT.md). > Looking for a good first issue to work on? > Please check out our contributing guide below and then select an issue from our [curated list](https://github.com/huggingface/huggingface_hub/contribute). > Pick one and get started with it! ### The client library, `huggingface_hub` This repository hosts the `huggingface_hub`, the client library that interfaces any Python script with the Hugging Face Hub. Its implementation lives in `src/huggingface_hub`, while the tests are located in `tests/`. There are many ways you can contribute to this client library: - Fixing outstanding issues with the existing code; - Contributing to the examples or to the documentation; - Submitting issues related to bugs or desired new features. ## Submitting a new issue or feature request Do your best to follow these guidelines when submitting an issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback. ### Did you find a bug? The `huggingface_hub` library is robust and reliable thanks to the users who notify us of the problems they encounter. So thank you for reporting an issue. First, we would really appreciate it if you could **make sure the bug was not already reported** (use the search bar on Github under Issues). Did not find it? :( So we can act quickly on it, please follow these steps: - A short, self-contained, code snippet that allows us to reproduce the bug in less than 30s; - Provide the _full_ traceback if an exception is raised by copying the text from your terminal in the issue description. - Include information about your local setup. 
You can dump this information by running `huggingface-cli env` in your terminal; ### Do you want a new feature? A good feature request addresses the following points: 1. Motivation first: - Is it related to a problem/frustration with the library? If so, please explain why and provide a code snippet that demonstrates the problem best. - Is it related to something you would need for a project? We'd love to hear about it! - Is it something you worked on and think could benefit the community? Awesome! Tell us what problem it solved for you. 2. Write a _full paragraph_ describing the feature; 3. Provide a **code snippet** that demonstrates its future use; 4. In case this is related to a paper, please attach a link; 5. Attach any additional information (drawings, screenshots, etc.) you think may help. If your issue is well written, we're already 80% of the way there by the time you post it! ## Submitting a pull request (PR) Before writing code, we strongly advise you to search through the existing PRs or issues to make sure that nobody is already working on the same thing. If you are unsure, it is always a good idea to open an issue to get some feedback. You will need basic `git` proficiency to be able to contribute to `huggingface_hub`. `git` is not the easiest tool to use but it has the greatest manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro Git](https://git-scm.com/book/en/v2) is a very good reference. Follow these steps to start contributing: 1. Fork the [repository](https://github.com/huggingface/huggingface_hub) by clicking on the 'Fork' button on the repository's page. This creates a copy of the code under your GitHub user account. 2. Clone your fork to your local disk, and add the base repository as a remote. The following command assumes you have your public SSH key uploaded to GitHub. See the following guide for more [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository). ```bash $ git clone git@github.com:/huggingface_hub.git $ cd huggingface_hub $ git remote add upstream https://github.com/huggingface/huggingface_hub.git ``` 3. Create a new branch to hold your development changes, and do this for every new PR you work on. Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)): ```bash $ git checkout main $ git fetch upstream $ git merge upstream/main ``` Once your `main` branch is synchronized, create a new branch from it: ```bash $ git checkout -b a-descriptive-name-for-my-changes ``` **Do not** work on the `main` branch. 4. Set up a development environment by running the following command in a [virtual environment](https://docs.python.org/3/library/venv.html#creating-virtual-environments) or a conda environment you've created for working on this library: ```bash $ pip uninstall huggingface_hub # make sure huggingface_hub is not already installed $ pip install -e ".[dev]" # install in editable (-e) mode ``` 5. Develop the features on your branch 6. Test your implementation! To make a good Pull Request you must test the features you have added. To do so, we use the `unittest` framework and run them using `pytest`: ```bash $ pytest tests -k # or $ pytest tests/.py ``` 7. Format your code. `huggingface_hub` relies on [`ruff`](https://github.com/astral-sh/ruff) to format its source code consistently. 
You can apply automatic style corrections and code verifications with the following command: ```bash $ make style ``` This command will update your code to comply with the standards of the `huggingface_hub` repository. A few custom scripts are also run to ensure consistency. Once automatic style corrections have been applied, you must test that it passes the quality checks: ```bash $ make quality ``` Compared to `make style`, `make quality` will never update your code. In addition to the previous code formatter, it also runs [`mypy`](https://github.com/python/mypy) to check for static typing issues. All those tests will also run in the CI once you open your PR but it is recommended to run them locally in order to iterate faster. > For the commands leveraging the `make` utility, we recommend using the WSL system when running on > Windows. More information [here](https://docs.microsoft.com/en-us/windows/wsl/about). 8. (optional) Alternatively, you can install pre-commit hooks so that these styles are applied and checked on files that you have touched in each commit: ```bash pip install pre-commit pre-commit install ``` You only need to do the above once in your repository's environment. If for any reason you would like to disable pre-commit hooks on a commit, you can pass `-n` to your `git commit` command to temporarily disable pre-commit hooks. To permanently disable hooks, you can run the following command: ```bash pre-commit uninstall ``` 9. Once you're happy with your changes, add changed files using `git add` and make a commit with `git commit` to record your changes locally: ```bash $ git add modified_file.py $ git commit ``` Please write [good commit messages](https://chris.beams.io/posts/git-commit/). It is a good idea to sync your copy of the code with the original repository regularly. The following document covers it in length: [github documentation](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork) And here's how you can do it quickly from your `git` commandline: ```bash $ git fetch upstream $ git rebase upstream/main ``` Push the changes to your account using: ```bash $ git push -u origin a-descriptive-name-for-my-changes ``` 10. Once you are satisfied (**and the [checklist below](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md#checklist) is happy too**), go to the webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review. 11. It's ok if maintainers ask you for changes. It happens all the time to core contributors too! So everyone can see the changes in the Pull request, work in your local branch and push the changes to your fork. They will automatically appear in the pull request. 12. Once your changes have been approved, one of the project maintainers will merge your pull request for you. Good job! ### Checklist 1. The title of your pull request should be a summary of its contribution; 2. If your pull request addresses an issue, please mention the issue number in the pull request description to make sure they are linked (and people consulting the issue know you are working on it); 3. To indicate a work in progress please prefix the title with `[WIP]`, or mark the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate it from PRs ready to be merged; 4. Make sure existing tests pass; 5. Add high-coverage tests. No quality testing = no merge. 6. 
Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos and other non-text files. We prefer to leverage a hf.co hosted `dataset` like the ones hosted on [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) in which to place these files and reference them by URL. We recommend putting them in the following dataset: [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images). If an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images to this dataset. ### Tests An extensive test suite is included to test the library behavior and several examples. Library tests can be found in the [tests folder](https://github.com/huggingface/huggingface_hub/tree/main/tests). We use `pytest` in order to run the tests for the library. From the root of the repository they can be run with the following: ```bash $ python -m pytest ./tests ``` You can specify a smaller set of tests in order to test only the feature you're working on. For example, the following will only run the tests in the `test_repository.py` file: ```bash $ python -m pytest ./tests/test_repository.py ``` And the following will only run the tests that include `tag` in their name: ```bash $ python -m pytest ./tests -k tag ``` huggingface_hub-0.31.1/LICENSE000066400000000000000000000261351500667546600157340ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
huggingface_hub-0.31.1/MANIFEST.in000066400000000000000000000002361500667546600164570ustar00rootroot00000000000000include src/huggingface_hub/templates/modelcard_template.md include src/huggingface_hub/templates/datasetcard_template.md include src/huggingface_hub/py.typedhuggingface_hub-0.31.1/Makefile000066400000000000000000000045431500667546600163660ustar00rootroot00000000000000.PHONY: contrib quality style test check_dirs := contrib src tests utils setup.py quality: ruff check $(check_dirs) # linter ruff format --check $(check_dirs) # formatter python utils/check_inference_input_params.py python utils/check_contrib_list.py python utils/check_static_imports.py python utils/check_all_variable.py python utils/generate_async_inference_client.py mypy src style: ruff format $(check_dirs) # formatter ruff check --fix $(check_dirs) # linter python utils/check_contrib_list.py --update python utils/check_static_imports.py --update python utils/check_all_variable.py --update python utils/generate_async_inference_client.py --update inference_check: python utils/generate_inference_types.py python utils/check_task_parameters.py inference_update: python utils/generate_inference_types.py --update python utils/check_task_parameters.py --update repocard: python utils/push_repocard_examples.py test: pytest ./tests/ # Taken from https://stackoverflow.com/a/12110773 # Commands: # make contrib_setup_timm : setup tests for timm # make contrib_test_timm : run tests for timm # make contrib_timm : setup and run tests for timm # make contrib_clear_timm : delete timm virtual env # # make contrib_setup : setup ALL tests # make contrib_test : run ALL tests # make contrib : setup and run ALL tests # make contrib_clear : delete all virtual envs # Use -j4 flag to run jobs in parallel. CONTRIB_LIBS := sentence_transformers spacy timm CONTRIB_JOBS := $(addprefix contrib_,${CONTRIB_LIBS}) CONTRIB_CLEAR_JOBS := $(addprefix contrib_clear_,${CONTRIB_LIBS}) CONTRIB_SETUP_JOBS := $(addprefix contrib_setup_,${CONTRIB_LIBS}) CONTRIB_TEST_JOBS := $(addprefix contrib_test_,${CONTRIB_LIBS}) contrib_clear_%: rm -rf contrib/$*/.venv contrib_setup_%: python3 -m venv contrib/$*/.venv ./contrib/$*/.venv/bin/pip install -r contrib/$*/requirements.txt ./contrib/$*/.venv/bin/pip uninstall -y huggingface_hub ./contrib/$*/.venv/bin/pip install -e .[testing] contrib_test_%: ./contrib/$*/.venv/bin/python -m pytest contrib/$* contrib_%: make contrib_setup_$* make contrib_test_$* contrib: ${CONTRIB_JOBS}; contrib_clear: ${CONTRIB_CLEAR_JOBS}; echo "Successful contrib tests." contrib_setup: ${CONTRIB_SETUP_JOBS}; echo "Successful contrib setup." contrib_test: ${CONTRIB_TEST_JOBS}; echo "Successful contrib tests." huggingface_hub-0.31.1/README.md000066400000000000000000000162371500667546600162100ustar00rootroot00000000000000

huggingface_hub library logo

The official Python client for the Hugging Face Hub.


English | Deutsch | हिंदी | 한국어 | 中文(简体)

--- **Documentation**: https://hf.co/docs/huggingface_hub **Source Code**: https://github.com/huggingface/huggingface_hub --- ## Welcome to the huggingface_hub library The `huggingface_hub` library allows you to interact with the [Hugging Face Hub](https://huggingface.co/), a platform democratizing open-source Machine Learning for creators and collaborators. Discover pre-trained models and datasets for your projects or play with the thousands of machine learning apps hosted on the Hub. You can also create and share your own models, datasets and demos with the community. The `huggingface_hub` library provides a simple way to do all these things with Python. ## Key features - [Download files](https://huggingface.co/docs/huggingface_hub/en/guides/download) from the Hub. - [Upload files](https://huggingface.co/docs/huggingface_hub/en/guides/upload) to the Hub. - [Manage your repositories](https://huggingface.co/docs/huggingface_hub/en/guides/repository). - [Run Inference](https://huggingface.co/docs/huggingface_hub/en/guides/inference) on deployed models. - [Search](https://huggingface.co/docs/huggingface_hub/en/guides/search) for models, datasets and Spaces. - [Share Model Cards](https://huggingface.co/docs/huggingface_hub/en/guides/model-cards) to document your models. - [Engage with the community](https://huggingface.co/docs/huggingface_hub/en/guides/community) through PRs and comments. ## Installation Install the `huggingface_hub` package with [pip](https://pypi.org/project/huggingface-hub/): ```bash pip install huggingface_hub ``` If you prefer, you can also install it with [conda](https://huggingface.co/docs/huggingface_hub/en/installation#install-with-conda). In order to keep the package minimal by default, `huggingface_hub` comes with optional dependencies useful for some use cases. For example, if you want have a complete experience for Inference, run: ```bash pip install huggingface_hub[inference] ``` To learn more installation and optional dependencies, check out the [installation guide](https://huggingface.co/docs/huggingface_hub/en/installation). ## Quick start ### Download files Download a single file ```py from huggingface_hub import hf_hub_download hf_hub_download(repo_id="tiiuae/falcon-7b-instruct", filename="config.json") ``` Or an entire repository ```py from huggingface_hub import snapshot_download snapshot_download("stabilityai/stable-diffusion-2-1") ``` Files will be downloaded in a local cache folder. More details in [this guide](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache). ### Login The Hugging Face Hub uses tokens to authenticate applications (see [docs](https://huggingface.co/docs/hub/security-tokens)). To log in your machine, run the following CLI: ```bash huggingface-cli login # or using an environment variable huggingface-cli login --token $HUGGINGFACE_TOKEN ``` ### Create a repository ```py from huggingface_hub import create_repo create_repo(repo_id="super-cool-model") ``` ### Upload files Upload a single file ```py from huggingface_hub import upload_file upload_file( path_or_fileobj="/home/lysandre/dummy-test/README.md", path_in_repo="README.md", repo_id="lysandre/test-model", ) ``` Or an entire folder ```py from huggingface_hub import upload_folder upload_folder( folder_path="/path/to/local/space", repo_id="username/my-cool-space", repo_type="space", ) ``` For details in the [upload guide](https://huggingface.co/docs/huggingface_hub/en/guides/upload). ## Integrating to the Hub. 
We're partnering with cool open source ML libraries to provide free model hosting and versioning. You can find the existing integrations [here](https://huggingface.co/docs/hub/libraries). The advantages are: - Free model or dataset hosting for libraries and their users. - Built-in file versioning, even with very large files, thanks to a git-based approach. - Serverless inference API for all models publicly available. - In-browser widgets to play with the uploaded models. - Anyone can upload a new model for your library, they just need to add the corresponding tag for the model to be discoverable. - Fast downloads! We use Cloudfront (a CDN) to geo-replicate downloads so they're blazing fast from anywhere on the globe. - Usage stats and more features to come. If you would like to integrate your library, feel free to open an issue to begin the discussion. We wrote a [step-by-step guide](https://huggingface.co/docs/hub/adding-a-library) with ❤️ showing how to do this integration. ## Contributions (feature requests, bugs, etc.) are super welcome 💙💚💛💜🧡❤️ Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out and improving the documentations are immensely valuable to the community. We wrote a [contribution guide](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md) to summarize how to get started to contribute to this repository. huggingface_hub-0.31.1/codecov.yml000066400000000000000000000004661500667546600170730ustar00rootroot00000000000000comment: # https://docs.codecov.com/docs/pull-request-comments#requiring-changes require_changes: true # https://docs.codecov.com/docs/pull-request-comments#after_n_builds after_n_builds: 12 coverage: status: # not in PRs patch: false project: false github_checks: annotations: false huggingface_hub-0.31.1/contrib/000077500000000000000000000000001500667546600163605ustar00rootroot00000000000000huggingface_hub-0.31.1/contrib/README.md000066400000000000000000000046401500667546600176430ustar00rootroot00000000000000# Contrib test suite The contrib folder contains simple end-to-end scripts to test integration of `huggingface_hub` in downstream libraries. The main goal is to proactively notice breaking changes and deprecation warnings. ## Add tests for a new library To add another contrib lib, one must: 1. Create a subfolder with the lib name. Example: `./contrib/transformers` 2. Create a `requirements.txt` file specific to this lib. Example `./contrib/transformers/requirements.txt` 3. Implements tests for this lib. Example: `./contrib/transformers/test_push_to_hub.py` 4. Run `make style`. This will edit both `makefile` and `.github/workflows/contrib-tests.yml` to add the lib to list of libs to test. Make sure changes are accurate before committing. ## Run contrib tests on CI Contrib tests can be [manually triggered in GitHub](https://github.com/huggingface/huggingface_hub/actions) with the `Contrib tests` workflow. Tests are not run in the default test suite (for each PR) as this would slow down development process. The goal is to notice breaking changes, not to avoid them. In particular, it is interesting to trigger it before a release to make sure it will not cause too much friction. ## Run contrib tests locally Tests must be ran individually for each dependent library. Here is an example to run `timm` tests. Tests are separated to avoid conflicts between version dependencies. 
### Run all contrib tests Before running tests, a virtual env must be setup for each contrib library. To do so, run: ```sh # Run setup in parallel to save time make contrib_setup -j4 ``` Then tests can be run ```sh # Optional: -j4 to run in parallel. Output will be messy in that case. make contrib_test -j4 ``` Optionally, it is possible to setup and run all tests in a single command. However this take more time as you don't need to setup the venv each time you run tests. ```sh make contrib -j4 ``` Finally, it is possible to delete all virtual envs to get a fresh start for contrib tests. After running this command, `contrib_setup` will have to re-download/re-install all dependencies. ``` make contrib_clear ``` ### Run contrib tests for a single lib Instead of running tests for all contrib libraries, you can run a specific lib: ```sh # Setup timm tests make contrib_setup_timm # Run timm tests make contrib_test_timm # (or) Setup and run timm tests at once make contrib_timm # Delete timm virtualenv if corrupted make contrib_clear_timm ``` huggingface_hub-0.31.1/contrib/__init__.py000066400000000000000000000000001500667546600204570ustar00rootroot00000000000000huggingface_hub-0.31.1/contrib/conftest.py000066400000000000000000000024241500667546600205610ustar00rootroot00000000000000import os import time import uuid from typing import Generator import pytest from huggingface_hub import delete_repo @pytest.fixture(scope="session") def token() -> str: # Not critical, only usable on the sandboxed CI instance. return "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL" @pytest.fixture(scope="session") def user() -> str: return "__DUMMY_TRANSFORMERS_USER__" @pytest.fixture(autouse=True, scope="session") def login_as_dummy_user(token: str) -> Generator: """Log in with dummy user token.""" # Cannot use `monkeypatch` fixture since we want it to be "session-scoped" old_token = os.environ["HF_TOKEN"] os.environ["HF_TOKEN"] = token yield os.environ["HF_TOKEN"] = old_token @pytest.fixture def repo_name(request) -> None: """ Return a readable pseudo-unique repository name for tests. Example: "repo-2fe93f-16599646671840" """ prefix = request.module.__name__ # example: `test_timm` id = uuid.uuid4().hex[:6] ts = int(time.time() * 10e3) return f"repo-{prefix}-{id}-{ts}" @pytest.fixture def cleanup_repo(user: str, repo_name: str) -> None: """Delete the repo at the end of the tests. 
TODO: Adapt to handle `repo_type` as well """ yield # run test delete_repo(repo_id=f"{user}/{repo_name}") huggingface_hub-0.31.1/contrib/sentence_transformers/000077500000000000000000000000001500667546600227715ustar00rootroot00000000000000huggingface_hub-0.31.1/contrib/sentence_transformers/__init__.py000066400000000000000000000000001500667546600250700ustar00rootroot00000000000000huggingface_hub-0.31.1/contrib/sentence_transformers/requirements.txt000066400000000000000000000001221500667546600262500ustar00rootroot00000000000000git+https://github.com/UKPLab/sentence-transformers.git#egg=sentence-transformers huggingface_hub-0.31.1/contrib/sentence_transformers/test_sentence_transformers.py000066400000000000000000000023561500667546600310210ustar00rootroot00000000000000import time import pytest from sentence_transformers import SentenceTransformer, util from huggingface_hub import model_info from ..utils import production_endpoint @pytest.fixture(scope="module") def multi_qa_model() -> SentenceTransformer: with production_endpoint(): return SentenceTransformer("multi-qa-MiniLM-L6-cos-v1") def test_from_pretrained(multi_qa_model: SentenceTransformer) -> None: # Example taken from https://www.sbert.net/docs/hugging_face.html#using-hugging-face-models. query_embedding = multi_qa_model.encode("How big is London") passage_embedding = multi_qa_model.encode( [ "London has 9,787,426 inhabitants at the 2011 census", "London is known for its financial district", ] ) print("Similarity:", util.dot_score(query_embedding, passage_embedding)) def test_push_to_hub(multi_qa_model: SentenceTransformer, repo_name: str, user: str, cleanup_repo: None) -> None: multi_qa_model.save_to_hub(repo_name, organization=user) # Sleep to ensure that model_info isn't called too soon time.sleep(1) # Check model has been pushed properly model_id = f"{user}/{repo_name}" assert model_info(model_id).library_name == "sentence-transformers" huggingface_hub-0.31.1/contrib/spacy/000077500000000000000000000000001500667546600174775ustar00rootroot00000000000000huggingface_hub-0.31.1/contrib/spacy/__init__.py000066400000000000000000000000001500667546600215760ustar00rootroot00000000000000huggingface_hub-0.31.1/contrib/spacy/requirements.txt000066400000000000000000000001251500667546600227610ustar00rootroot00000000000000git+https://github.com/explosion/spacy-huggingface-hub.git#egg=spacy-huggingface-hub huggingface_hub-0.31.1/contrib/spacy/test_spacy.py000066400000000000000000000024171500667546600222330ustar00rootroot00000000000000import time from spacy_huggingface_hub import push from huggingface_hub import delete_repo, hf_hub_download, model_info from huggingface_hub.errors import HfHubHTTPError from ..utils import production_endpoint def test_push_to_hub(user: str) -> None: """Test equivalent of `python -m spacy huggingface-hub push`. (0. Delete existing repo on the Hub (if any)) 1. Download an example file from production 2. Push the model! 3. Check model pushed the Hub + as spacy library (4. 
Cleanup) """ model_id = f"{user}/en_core_web_sm" _delete_repo(model_id) # Download example file from HF Hub (see https://huggingface.co/spacy/en_core_web_sm) with production_endpoint(): whl_path = hf_hub_download( repo_id="spacy/en_core_web_sm", filename="en_core_web_sm-any-py3-none-any.whl", ) # Push spacy model to Hub push(whl_path) # Sleep to ensure that model_info isn't called too soon time.sleep(1) # Check model has been pushed properly model_id = f"{user}/en_core_web_sm" assert model_info(model_id).library_name == "spacy" # Cleanup _delete_repo(model_id) def _delete_repo(model_id: str) -> None: try: delete_repo(model_id) except HfHubHTTPError: pass huggingface_hub-0.31.1/contrib/timm/000077500000000000000000000000001500667546600173265ustar00rootroot00000000000000huggingface_hub-0.31.1/contrib/timm/__init__.py000066400000000000000000000000001500667546600214250ustar00rootroot00000000000000huggingface_hub-0.31.1/contrib/timm/requirements.txt000066400000000000000000000001261500667546600226110ustar00rootroot00000000000000# Timm git+https://github.com/rwightman/pytorch-image-models.git#egg=timm safetensors huggingface_hub-0.31.1/contrib/timm/test_timm.py000066400000000000000000000007761500667546600217170ustar00rootroot00000000000000import timm from ..utils import production_endpoint MODEL_ID = "timm/mobilenetv3_large_100.ra_in1k" @production_endpoint() def test_load_from_hub() -> None: # Test load only config _ = timm.models.load_model_config_from_hf(MODEL_ID) # Load entire model from Hub _ = timm.create_model("hf_hub:" + MODEL_ID, pretrained=True) def test_push_to_hub(repo_name: str, cleanup_repo: None) -> None: model = timm.create_model("mobilenetv3_rw") timm.models.push_to_hf_hub(model, repo_name) huggingface_hub-0.31.1/contrib/utils.py000066400000000000000000000031171500667546600200740ustar00rootroot00000000000000import contextlib from typing import Generator from unittest.mock import patch @contextlib.contextmanager def production_endpoint() -> Generator: """Patch huggingface_hub to connect to production server in a context manager. Ugly way to patch all constants at once. TODO: refactor when https://github.com/huggingface/huggingface_hub/issues/1172 is fixed. Example: ```py def test_push_to_hub(): # Pull from production Hub with production_endpoint(): model = ...from_pretrained("modelname") # Push to staging Hub model.push_to_hub() ``` """ PROD_ENDPOINT = "https://huggingface.co" ENDPOINT_TARGETS = [ "huggingface_hub.constants", "huggingface_hub._commit_api", "huggingface_hub.hf_api", "huggingface_hub.lfs", "huggingface_hub.commands.user", "huggingface_hub.utils._git_credential", ] PROD_URL_TEMPLATE = PROD_ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}" URL_TEMPLATE_TARGETS = [ "huggingface_hub.constants", "huggingface_hub.file_download", ] from huggingface_hub.hf_api import api patchers = ( [patch(target + ".ENDPOINT", PROD_ENDPOINT) for target in ENDPOINT_TARGETS] + [patch(target + ".HUGGINGFACE_CO_URL_TEMPLATE", PROD_URL_TEMPLATE) for target in URL_TEMPLATE_TARGETS] + [patch.object(api, "endpoint", PROD_URL_TEMPLATE)] ) # Start all patches for patcher in patchers: patcher.start() yield # Stop all patches for patcher in patchers: patcher.stop() huggingface_hub-0.31.1/docs/000077500000000000000000000000001500667546600156505ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/README.md000066400000000000000000000255011500667546600171320ustar00rootroot00000000000000 # Generating the documentation To generate the documentation, you have to build it. 
Several packages are necessary to build the doc. First, you need to install the project itself by running the following command at the root of the code repository: ```bash pip install -e . ``` You also need to install 2 extra packages: ```bash # `hf-doc-builder` to build the docs pip install git+https://github.com/huggingface/doc-builder@main # `watchdog` for live reloads pip install watchdog ``` --- **NOTE** You only need to generate the documentation to inspect it locally (if you're planning changes and want to check how they look before committing for instance). You don't have to commit the built documentation. --- ## Building the documentation Once you have setup the `doc-builder` and additional packages with the pip install command above, you can generate the documentation by typing the following command: ```bash doc-builder build huggingface_hub docs/source/en/ --build_dir ~/tmp/test-build ``` You can adapt the `--build_dir` to set any temporary folder that you prefer. This command will create it and generate the MDX files that will be rendered as the documentation on the main website. You can inspect them in your favorite Markdown editor. ## Previewing the documentation To preview the docs, run the following command: ```bash doc-builder preview huggingface_hub docs/source/en/ ``` The docs will be viewable at [http://localhost:5173](http://localhost:5173). You can also preview the docs once you have opened a PR. You will see a bot add a comment to a link where the documentation with your changes lives. --- **NOTE** The `preview` command only works with existing doc files. When you add a completely new file, you need to update `_toctree.yml` & restart `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again). --- ## Adding a new element to the navigation bar Accepted files are Markdown (.md). Create a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/huggingface_hub/blob/main/docs/source/en/_toctree.yml) file. ## Renaming section headers and moving sections It helps to keep the old links working when renaming the section header and/or moving sections from one document to another. This is because the old links are likely to be used in Issues, Forums, and Social media and it'd make for a much more superior user experience if users reading those months later could still easily navigate to the originally intended information. Therefore, we simply keep a little map of moved sections at the end of the document where the original section was. The key is to preserve the original anchor. So if you renamed a section from: "Section A" to "Section B", then you can add at the end of the file: ``` Sections that were moved: [ Section A ] ``` and of course, if you moved it to another file, then: ``` Sections that were moved: [ Section A ] ``` Use the relative style to link to the new file so that the versioned docs continue to work. For an example of a rich moved section set please see the very end of [the transformers Trainer doc](https://github.com/huggingface/transformers/blob/main/docs/source/en/main_classes/trainer.md). ## Writing Documentation - Specification The `huggingface/huggingface_hub` documentation follows the [Google documentation](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) style for docstrings, although we can write them directly in Markdown. 
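To make this concrete, here is a minimal sketch of a docstring written in that style. The function, its arguments, and its behavior are invented purely for illustration; the argument and return conventions it follows are spelled out under "Writing source documentation" below.

```python
def get_repo_url(repo_id: str, repo_type: str = "model") -> str:
    """
    Build the canonical URL of a repository on the Hub (illustrative example only).

    Args:
        repo_id (`str`):
            A namespace (user or organization) and a repo name separated by a `/`.
        repo_type (`str`, *optional*, defaults to `"model"`):
            The type of repository. Must be one of `"model"`, `"dataset"` or `"space"`.

    Returns:
        `str`: The fully-qualified URL of the repository on the Hub.
    """
    # Models live at the root of the Hub; datasets and Spaces have a prefix.
    prefix = "" if repo_type == "model" else f"{repo_type}s/"
    return f"https://huggingface.co/{prefix}{repo_id}"
```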
### Adding a new tutorial Adding a new tutorial or section is done in two steps: - Add a new Markdown (.md) file under `./source`. - Link that file in `./source/_toctree.yml` on the correct toc-tree. Make sure to put your new file under the proper section. If you have a doubt, feel free to ask in a Github Issue or PR. ### Translating When translating, refer to the guide at [./TRANSLATING.md](https://github.com/huggingface/huggingface_hub/blob/main/docs/TRANSLATING.md). ### Writing source documentation Values that should be put in `code` should either be surrounded by backticks: \`like so\`. Note that argument names and objects like True, None, or any strings should usually be put in `code`. When mentioning a class, function, or method, it is recommended to use our syntax for internal links so that our tool adds a link to its documentation with this syntax: \[\`XXXClass\`\] or \[\`function\`\]. This requires the class or function to be in the main package. If you want to create a link to some internal class or function, you need to provide its path. For instance: \[\`utils.ModelOutput\`\]. This will be converted into a link with `utils.ModelOutput` in the description. To get rid of the path and only keep the name of the object you are linking to in the description, add a ~: \[\`~utils.ModelOutput\`\] will generate a link with `ModelOutput` in the description. The same works for methods so you can either use \[\`XXXClass.method\`\] or \[~\`XXXClass.method\`\]. #### Defining arguments in a method Arguments should be defined with the `Args:` (or `Arguments:` or `Parameters:`) prefix, followed by a line return and an indentation. The argument should be followed by its type, with its shape if it is a tensor, a colon, and its description: ``` Args: n_layers (`int`): The number of layers of the model. ``` If the description is too long to fit in one line, another indentation is necessary before writing the description after the argument. Here's an example showcasing everything so far: ``` Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AlbertTokenizer`]. See [`~PreTrainedTokenizer.encode`] and [`~PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) ``` For optional arguments or arguments with defaults we follow the following syntax: imagine we have a function with the following signature: ``` def my_function(x: str = None, a: float = 1): ``` then its documentation should look like this: ``` Args: x (`str`, *optional*): This argument controls ... a (`float`, *optional*, defaults to 1): This argument is used to ... ``` Note that we always omit the "defaults to \`None\`" when None is the default for any argument. Also note that even if the first line describing your argument type and its default gets long, you can't break it on several lines. You can however write as many lines as you want in the indented description (see the example above with `input_ids`). #### Writing a multi-line code block Multi-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown: ```` ``` # first line of code # second line # etc ``` ```` #### Writing a return block The return block should be introduced with the `Returns:` prefix, followed by a line return and an indentation. The first line should be the type of the return, followed by a line return. 
No need to indent further for the elements building the return. Here's an example of a single value return: ``` Returns: `List[int]`: A list of integers in the range [0, 1] --- 1 for a special token, 0 for a sequence token. ``` Here's an example of a tuple return, comprising several objects: ``` Returns: `tuple(torch.FloatTensor)` comprising various elements depending on the configuration ([`BertConfig`]) and inputs: - ** loss** (*optional*, returned when `masked_lm_labels` is provided) `torch.FloatTensor` of shape `(1,)` -- Total loss is the sum of the masked language modeling loss and the next sequence prediction (classification) loss. - **prediction_scores** (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`) -- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). ``` #### Adding an image Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like the ones hosted on [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) in which to place these files and reference them by URL. We recommend putting them in the following dataset: [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images). If an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images to this dataset. #### Writing documentation examples The syntax for Example docstrings can look as follows: ``` Example: ```python >>> from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC >>> from datasets import load_dataset >>> import torch >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = dataset.sort("id") >>> sampling_rate = dataset.features["audio"].sampling_rate >>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") >>> model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") >>> # audio file is decoded on the fly >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") >>> with torch.no_grad(): ... logits = model(**inputs).logits >>> predicted_ids = torch.argmax(logits, dim=-1) >>> # transcribe speech >>> transcription = processor.batch_decode(predicted_ids) >>> transcription[0] 'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL' ``` ``` The docstring should give a minimal, clear example of how the respective model is to be used in inference and also include the expected (ideally sensible) output. Often, readers will try out the example before even going through the function or class definitions. Therefore, it is of utmost importance that the example works as expected. huggingface_hub-0.31.1/docs/TRANSLATING.md000066400000000000000000000063551500667546600177710ustar00rootroot00000000000000### Translating the `huggingface_hub` documentation into your language As part of our mission to democratize machine learning, we'd love to make the `huggingface_hub` library available in many more languages! Follow the steps below if you want to help translate the documentation into your language 🙏. 
**🗞️ Open an issue** To get started, navigate to the [Issues](https://github.com/huggingface/huggingface_hub/issues) page of this repo and check if anyone else has opened an issue for your language. If not, open a new issue by selecting the "Translation template" from the "New issue" button. Once an issue exists, post a comment to indicate which chapters you'd like to work on, and we'll add your name to the list. **🍴 Fork the repository** First, you'll need to [fork the `huggingface_hub` repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo). You can do this by clicking on the **Fork** button on the top-right corner of this repo's page. Once you've forked the repo, you'll want to get the files on your local machine for editing. You can do that by cloning the fork with Git as follows: ```bash git clone https://github.com/YOUR-USERNAME/huggingface_hub.git ``` **📋 Copy-paste the English version with a new language code** The documentation files are in one leading directory: - [`docs/source`](https://github.com/huggingface/huggingface_hub/tree/main/docs/source): All the documentation materials are organized here by language. You'll only need to copy the files in the [`docs/source/en`](https://github.com/huggingface/huggingface_hub/tree/main/docs/source/en) directory, so first navigate to your fork of the repo and run the following: ```bash cd ~/path/to/huggingface_hub/docs cp -r source/en source/LANG-ID ``` Here, `LANG-ID` should be one of the ISO 639-1 or ISO 639-2 language codes -- see [here](https://www.loc.gov/standards/iso639-2/php/code_list.php) for a handy table. **✍️ Start translating** The fun part comes - translating the text! The first thing we recommend is translating the part of the `_toctree.yml` file that corresponds to your doc chapter. This file is used to render the table of contents on the website. > 🙋 If the `_toctree.yml` file doesn't yet exist for your language, you can create one by copy-pasting from the English version and deleting the sections unrelated to your chapter. Just make sure it exists in the `docs/source/LANG-ID/` directory! The fields you should add are `local` (with the name of the file containing the translation; e.g. `guides/manage-spaces`), and `title` (with the title of the doc in your language; e.g. `Manage your Space`) -- as a reference, here is the `_toctree.yml` for [English](https://github.com/huggingface/huggingface_hub/blob/main/docs/source/en/_toctree.yml): ```yaml - title: "How-to guides" # Translate this! sections: - local: guides/manage-spaces # Do not change this! Use the same name for your .md file title: Manage your Space # Translate this! ... ``` Once you have translated the `_toctree.yml` file, you can start translating the Markdown files associated with your docs chapter. > 🙋 If you'd like others to help you with the translation, you should [open an issue](https://github.com/huggingface/huggingface_hub/issues) and tag @Wauplin. huggingface_hub-0.31.1/docs/dev/000077500000000000000000000000001500667546600164265ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/dev/release.md000066400000000000000000000032041500667546600203670ustar00rootroot00000000000000This document covers all steps that need to be done in order to do a release of the `huggingface_hub` library. 1. On a clone of the main repo, not your fork, checkout the main branch and pull the latest changes: ``` git checkout main git pull ``` 2. Checkout a new branch with the version that you'd like to release: v-release, for example `v0.5-release`. 
All patches will be done to that same branch. 3. Update the `__version__` variable in the `src/huggingface_hub/__init__.py` file to point to the version you're releasing: ``` __version__ = "" ``` 4. Make sure that the conda build works correctly by building it locally: ``` conda install -c defaults anaconda-client conda-build HUB_VERSION= conda-build .github/conda ``` 5. Make sure that the pip wheel works correctly by building it locally and installing it: ``` pip install setuptools wheel python setup.py sdist bdist_wheel pip install dist/huggingface_hub--py3-none-any.whl ``` 6. Commit, tag, and push the branch: ``` git commit -am "Release: v" git tag v -m "Adds tag v for pypi and conda" git push -u --tags origin v-release ``` 7. Verify that the docs have been built correctly. You can check that on the following link: https://huggingface.co/docs/huggingface_hub/v 8. Checkout main once again to update the version in the `__init__.py` file: ``` git checkout main ``` 9. Update the version to contain the `.dev0` suffix: ``` __version__ = ".dev0" # For example, after releasing v0.5.0 or v0.5.1: "0.6.0.dev0". ``` 10. Push the changes! ``` git push origin main ``` huggingface_hub-0.31.1/docs/source/000077500000000000000000000000001500667546600171505ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/cn/000077500000000000000000000000001500667546600175505ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/cn/_toctree.yml000066400000000000000000000011631500667546600221000ustar00rootroot00000000000000- title: "Starten" sections: - local: index title: 索引 - local: quick-start title: 快速入门指南 - local: installation title: 安装 - title: "guides" sections: - local: guides/repository title: 储存库 - local: guides/search title: 搜索 - local: guides/collections title: 集合 - local: guides/community title: 社区 - local: guides/overview title: 概览 - local: guides/hf_file_system title: Hugging Face 文件系统 - title: "concepts" sections: - local: concepts/git_vs_http title: Git vs HTTP 范式 huggingface_hub-0.31.1/docs/source/cn/concepts/000077500000000000000000000000001500667546600213665ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/cn/concepts/git_vs_http.md000066400000000000000000000066041500667546600242500ustar00rootroot00000000000000 # Git 与 HTTP 范式 `huggingface_hub`库是用于与Hugging Face Hub进行交互的库,Hugging Face Hub是一组基于Git的存储库(模型、数据集或Spaces)。使用 `huggingface_hub`有两种主要方式来访问Hub。 第一种方法,即所谓的“基于git”的方法,由[`Repository`]类驱动。这种方法使用了一个包装器,它在 `git`命令的基础上增加了专门与Hub交互的额外函数。第二种选择,称为“基于HTTP”的方法,涉及使用[`HfApi`]客户端进行HTTP请求。让我们来看一看每种方法的优缺点。 ## 存储库:基于历史的 Git 方法 最初,`huggingface_hub`主要围绕 [`Repository`] 类构建。它为常见的 `git` 命令(如 `"git add"`、`"git commit"`、`"git push"`、`"git tag"`、`"git checkout"` 等)提供了 Python 包装器 该库还可以帮助设置凭据和跟踪大型文件,这些文件通常在机器学习存储库中使用。此外,该库允许您在后台执行其方法,使其在训练期间上传数据很有用。 使用 [`Repository`] 的最大优点是它允许你在本地机器上维护整个存储库的本地副本。这也可能是一个缺点,因为它需要你不断更新和维护这个本地副本。这类似于传统软件开发中,每个开发人员都维护自己的本地副本,并在开发功能时推送更改。但是,在机器学习的上下文中,这可能并不总是必要的,因为用户可能只需要下载推理所需的权重,或将权重从一种格式转换为另一种格式,而无需克隆整个存储库。 ## HfApi: 一个功能强大且方便的HTTP客户端 `HfApi` 被开发为本地 git 存储库的替代方案,因为本地 git 存储库在处理大型模型或数据集时可能会很麻烦。`HfApi` 提供与基于 git 的方法相同的功能,例如下载和推送文件以及创建分支和标签,但无需本地文件夹来保持同步。 `HfApi`除了提供 `git` 已经提供的功能外,还提供其他功能,例如: * 管理存储库 * 使用缓存下载文件以进行有效的重复使用 * 在 Hub 中搜索存储库和元数据 * 访问社区功能,如讨论、PR和评论 * 配置Spaces ## 我应该使用什么?以及何时使用? 
总的来说,在大多数情况下,`HTTP 方法`是使用 huggingface_hub 的推荐方法。但是,在以下几种情况下,维护本地 git 克隆(使用 `Repository`)可能更有益: 如果您在本地机器上训练模型,使用传统的 git 工作流程并定期推送更新可能更有效。`Repository` 被优化为此类情况,因为它能够在后台运行。 如果您需要手动编辑大型文件,`git `是最佳选择,因为它只会将文件的差异发送到服务器。使用 `HfAPI` 客户端,每次编辑都会上传整个文件。请记住,大多数大型文件是二进制文件,因此无法从 git 差异中受益。 并非所有 git 命令都通过 [`HfApi`] 提供。有些可能永远不会被实现,但我们一直在努力改进并缩小差距。如果您没有看到您的用例被覆盖。 请在[Github](https://github.com/huggingface/huggingface_hub)打开一个 issue!我们欢迎反馈,以帮助我们与我们的用户一起构建 🤗 生态系统。 huggingface_hub-0.31.1/docs/source/cn/guides/000077500000000000000000000000001500667546600210305ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/cn/guides/collections.md000066400000000000000000000225361500667546600237000ustar00rootroot00000000000000 # 集合(Collections) 集合(collection)是 Hub 上将一组相关项目(模型、数据集、Spaces、论文)组织在同一页面上的一种方式。利用集合,你可以创建自己的作品集、为特定类别的内容添加书签,或呈现你想要分享的精选条目。要了解更多关于集合的概念及其在 Hub 上的呈现方式,请查看这篇 [指南](https://huggingface.co/docs/hub/collections) 你可以直接在浏览器中管理集合,但本指南将重点介绍如何以编程方式进行管理。 ## 获取集合 使用 [`get_collection`] 来获取你的集合或任意公共集合。 你需要提供集合的 *slug* 才能检索到该集合。 slug 是基于集合标题和唯一 ID 的标识符。你可以在集合页面的 URL 中找到它。
让我们获取`"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`这个集合: ```py >>> from huggingface_hub import get_collection >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") >>> collection Collection( slug='TheBloke/recent-models-64f9a55bb3115b4f513ec026', title='Recent models', owner='TheBloke', items=[...], last_updated=datetime.datetime(2023, 10, 2, 22, 56, 48, 632000, tzinfo=datetime.timezone.utc), position=1, private=False, theme='green', upvotes=90, description="Models I've recently quantized. Please note that currently this list has to be updated manually, and therefore is not guaranteed to be up-to-date." ) >>> collection.items[0] CollectionItem( item_object_id='651446103cd773a050bf64c2', item_id='TheBloke/U-Amethyst-20B-AWQ', item_type='model', position=88, note=None ) ``` [`get_collection`] 返回的 [`Collection`] 对象包含以下信息: - 高级元数据: `slug`, `owner`, `title`, `description`等。 - 一个 [`CollectionItem`] 对象列表; 每个条目代表一个模型、数据集、Space 或论文。 所有集合条目(items)都保证具有: - 唯一的 `item_object_id`: 这是集合条目在数据库中的唯一 ID - 一个 `item_id`: Hub 上底层条目的 ID(模型、数据集、Space、论文);此 ID 不一定是唯一的,仅当 `item_id` 与 `item_type` 成对出现时才唯一 - 一个 `item_type`: 如`model`, `dataset`, `Space`, `paper` - 该条目在集合中的 `position`, 可通过后续操作 (参加下文的 [`update_collection_item`])来重新排序集合条目 此外,`note` 可选地附加在条目上。这对为某个条目添加额外信息(评论、博客文章链接等)很有帮助。如果条目没有备注,`note` 的值为 `None`。 除了这些基本属性之外,不同类型的条目可能会返回额外属性,如:`author`、`private`、`lastModified`、`gated`、`title`、`likes`、`upvotes` 等。这些属性不保证一定存在。 ## 列出集合 我们也可以使用 [`list_collections`]来检索集合,并通过一些参数进行过滤。让我们列出用户[`teknium`](https://huggingface.co/teknium)的所有集合: ```py >>> from huggingface_hub import list_collections >>> collections = list_collections(owner="teknium") ``` 这将返回一个 Collection 对象的可迭代序列。我们可以遍历它们,比如打印每个集合的点赞数(upvotes): ```py >>> for collection in collections: ... print("Number of upvotes:", collection.upvotes) Number of upvotes: 1 Number of upvotes: 5 ``` 当列出集合时,每个集合中返回的条目列表最多会被截断为 4 个。若要检索集合中的所有条目,你必须使用 [`get_collection`]. 我们可以进行更高级的过滤。例如,让我们获取所有包含模型 [TheBloke/OpenHermes-2.5-Mistral-7B-GGUF](https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF) 的集合,并按照趋势(trending)进行排序,同时将结果限制为 5 个。 ```py >>> collections = list_collections(item="models/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF", sort="trending", limit=5): >>> for collection in collections: ... print(collection.slug) teknium/quantized-models-6544690bb978e0b0f7328748 AmeerH/function-calling-65560a2565d7a6ef568527af PostArchitekt/7bz-65479bb8c194936469697d8c gnomealone/need-to-test-652007226c6ce4cdacf9c233 Crataco/favorite-7b-models-651944072b4fffcb41f8b568 ``` `sort` 参数必须是 `"last_modified"`、`"trending"` 或 `"upvotes"` 之一。`item` 参数接受任意特定条目,例如: * `"models/teknium/OpenHermes-2.5-Mistral-7B"` * `"spaces/julien-c/open-gpt-rhyming-robot"` * `"datasets/squad"` * `"papers/2311.12983"` 详情请查看 [`list_collections`] 的参考文档。 ## 创建新集合 现在我们已经知道如何获取一个 [`Collection`], 让我们自己创建一个吧! 使用 [`create_collection`],传入一个标题和描述即可。 若要在组织(organization)名下创建集合,可以通过 `namespace="my-cool-org"` 参数指定。同样,你也可以通过传入 `private=True` 创建私有集合。 ```py >>> from huggingface_hub import create_collection >>> collection = create_collection( ... title="ICCV 2023", ... description="Portfolio of models, papers and demos I presented at ICCV 2023", ... 
) ``` 该函数会返回一个包含高级元数据(标题、描述、所有者等)和空条目列表的 [`Collection`] 对象。现在你可以使用返回的 `slug` 来引用该集合。 ```py >>> collection.slug 'owner/iccv-2023-15e23b46cb98efca45' >>> collection.title "ICCV 2023" >>> collection.owner "username" >>> collection.url 'https://huggingface.co/collections/owner/iccv-2023-15e23b46cb98efca45' ``` ## 管理集合中的条目 现在我们有了一个 [`Collection`],接下来要添加条目并进行管理。 ### 添加条目 使用 [`add_collection_item`] 来向集合中添加条目(一次添加一个)。你只需要提供 `collection_slug`、`item_id` 和 `item_type`。可选参数 `note` 用于为该条目添加附加说明(最多 500 个字符)。 ```py >>> from huggingface_hub import create_collection, add_collection_item >>> collection = create_collection(title="OS Week Highlights - Sept 18 - 24", namespace="osanseviero") >>> collection.slug "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" >>> add_collection_item(collection.slug, item_id="coqui/xtts", item_type="space") >>> add_collection_item( ... collection.slug, ... item_id="warp-ai/wuerstchen", ... item_type="model", ... note="Würstchen is a new fast and efficient high resolution text-to-image architecture and model" ... ) >>> add_collection_item(collection.slug, item_id="lmsys/lmsys-chat-1m", item_type="dataset") >>> add_collection_item(collection.slug, item_id="warp-ai/wuerstchen", item_type="space") # same item_id, different item_type ``` 如果一个条目已存在于集合中(相同的 `item_id` 和 `item_type`),将会引发 HTTP 409 错误。你可以通过设置 `exists_ok=True` 来忽略此错误。 ### 为已存在条目添加备注 你可以使用 [`update_collection_item`] 来为已存在条目添加或修改备注。让我们重用上面的示例: ```py >>> from huggingface_hub import get_collection, update_collection_item # Fetch collection with newly added items >>> collection_slug = "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" >>> collection = get_collection(collection_slug) # Add note the `lmsys-chat-1m` dataset >>> update_collection_item( ... collection_slug=collection_slug, ... item_object_id=collection.items[2].item_object_id, ... note="This dataset contains one million real-world conversations with 25 state-of-the-art LLMs.", ... ) ``` ### 重新排序条目 集合中的条目是有序的。该顺序由每个条目的 `position` 属性决定。默认情况下,新添加的条目会被追加到集合末尾。你可以通过 [`update_collection_item`] 来更新顺序。 再次使用之前的示例: ```py >>> from huggingface_hub import get_collection, update_collection_item # Fetch collection >>> collection_slug = "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" >>> collection = get_collection(collection_slug) # Reorder to place the two `Wuerstchen` items together >>> update_collection_item( ... collection_slug=collection_slug, ... item_object_id=collection.items[3].item_object_id, ... position=2, ... 
) ``` ### 删除条目 最后,你也可以使用 [`delete_collection_item`] 来删除集合中的条目。 ```py >>> from huggingface_hub import get_collection, update_collection_item # Fetch collection >>> collection_slug = "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" >>> collection = get_collection(collection_slug) # Remove `coqui/xtts` Space from the list >>> delete_collection_item(collection_slug=collection_slug, item_object_id=collection.items[0].item_object_id) ``` ## 删除集合 可以使用 [`delete_collection`] 来删除集合。 此操作不可逆。删除的集合无法恢复。 ```py >>> from huggingface_hub import delete_collection >>> collection = delete_collection("username/useless-collection-64f9a55bb3115b4f513ec026", missing_ok=True) ``` huggingface_hub-0.31.1/docs/source/cn/guides/community.md000066400000000000000000000134301500667546600233770ustar00rootroot00000000000000 # 互动讨论与拉取请求(Pull Request) huggingface_hub 库提供了一个 Python 接口,用于与 Hub 上的拉取请求(Pull Request)和讨论互动。 访问 [相关的文档页面](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) ,了解有关 Hub 上讨论和拉取请求(Pull Request)的更深入的介绍及其工作原理。 ## 从 Hub 获取讨论和拉取请求(Pull Request) `HfApi` 类允许您获取给定仓库中的讨论和拉取请求(Pull Request): ```python >>> from huggingface_hub import get_repo_discussions >>> for discussion in get_repo_discussions(repo_id="bigscience/bloom"): ... print(f"{discussion.num} - {discussion.title}, pr: {discussion.is_pull_request}") # 11 - Add Flax weights, pr: True # 10 - Update README.md, pr: True # 9 - Training languages in the model card, pr: True # 8 - Update tokenizer_config.json, pr: True # 7 - Slurm training script, pr: False [...] ``` `HfApi.get_repo_discussions` 支持按作者、类型(拉取请求或讨论)和状态(`open` 或 `closed`)进行过滤: ```python >>> from huggingface_hub import get_repo_discussions >>> for discussion in get_repo_discussions( ... repo_id="bigscience/bloom", ... author="ArthurZ", ... discussion_type="pull_request", ... discussion_status="open", ... ): ... print(f"{discussion.num} - {discussion.title} by {discussion.author}, pr: {discussion.is_pull_request}") # 19 - Add Flax weights by ArthurZ, pr: True ``` `HfApi.get_repo_discussions` 返回一个 [生成器](https://docs.python.org/3.7/howto/functional.html#generators) 生成 [`Discussion`] 对象。 要获取所有讨论并存储为列表,可以运行: ```python >>> from huggingface_hub import get_repo_discussions >>> discussions_list = list(get_repo_discussions(repo_id="bert-base-uncased")) ``` [`HfApi.get_repo_discussions`] 返回的 [`Discussion`] 对象提供讨论或拉取请求(Pull Request)的高级概览。您还可以使用 [`HfApi.get_discussion_details`] 获取更详细的信息: ```python >>> from huggingface_hub import get_discussion_details >>> get_discussion_details( ... repo_id="bigscience/bloom-1b3", ... discussion_num=2 ... 
) DiscussionWithDetails( num=2, author='cakiki', title='Update VRAM memory for the V100s', status='open', is_pull_request=True, events=[ DiscussionComment(type='comment', author='cakiki', ...), DiscussionCommit(type='commit', author='cakiki', summary='Update VRAM memory for the V100s', oid='1256f9d9a33fa8887e1c1bf0e09b4713da96773a', ...), ], conflicting_files=[], target_branch='refs/heads/main', merge_commit_oid=None, diff='diff --git a/README.md b/README.md\nindex a6ae3b9294edf8d0eda0d67c7780a10241242a7e..3a1814f212bc3f0d3cc8f74bdbd316de4ae7b9e3 100644\n--- a/README.md\n+++ b/README.md\n@@ -132,7 +132,7 [...]', ) ``` [`HfApi.get_discussion_details`] 返回一个 [`DiscussionWithDetails`] 对象,它是 [`Discussion`] 的子类,包含有关讨论或拉取请求(Pull Request)的更详细信息。详细信息包括所有评论、状态更改以及讨论的重命名信息,可通过 [`DiscussionWithDetails.events`] 获取。 如果是拉取请求(Pull Request),您可以通过 [`DiscussionWithDetails.diff`] 获取原始的 git diff。拉取请求(Pull Request)的所有提交都列在 [`DiscussionWithDetails.events`] 中。 ## 以编程方式创建和编辑讨论或拉取请求 [`HfApi`] 类还提供了创建和编辑讨论及拉取请求(Pull Request)的方法。 您需要一个 [访问令牌](https://huggingface.co/docs/hub/security-tokens) 来创建和编辑讨论或拉取请求(Pull Request)。 在 Hub 上对 repo 提出修改建议的最简单方法是使用 [`create_commit`] API:只需将 `create_pr` 参数设置为 `True` 。此参数也适用于其他封装了 [`create_commit`] 的方法: * [`upload_file`] * [`upload_folder`] * [`delete_file`] * [`delete_folder`] * [`metadata_update`] ```python >>> from huggingface_hub import metadata_update >>> metadata_update( ... repo_id="username/repo_name", ... metadata={"tags": ["computer-vision", "awesome-model"]}, ... create_pr=True, ... ) ``` 您还可以使用 [`HfApi.create_discussion`](或 [`HfApi.create_pull_request`])在仓库上创建讨论(或拉取请求)。以这种方式打开拉取请求在您需要本地处理更改时很有用。以这种方式打开的拉取请求将处于“draft”模式。 ```python >>> from huggingface_hub import create_discussion, create_pull_request >>> create_discussion( ... repo_id="username/repo-name", ... title="Hi from the huggingface_hub library!", ... token="", ... ) DiscussionWithDetails(...) >>> create_pull_request( ... repo_id="username/repo-name", ... title="Hi from the huggingface_hub library!", ... token="", ... 
) DiscussionWithDetails(..., is_pull_request=True) ``` 您可以使用 [`HfApi`] 类来方便地管理拉取请求和讨论。例如: * [`comment_discussion`] 添加评论 * [`edit_discussion_comment`] 编辑评论 * [`rename_discussion`] 重命名讨论或拉取请求 * [`change_discussion_status`] 打开或关闭讨论/拉取请求 * [`merge_pull_request`] 合并拉取请求 请访问 [`HfApi`] 文档页面,获取所有可用方法的完整参考 ## 推送更改到拉取请求(Pull Request) *敬请期待!* ## 参见 有关更详细的参考,请访问 [讨论和拉取请求](../package_reference/community) 和 [hf_api](../package_reference/hf_api) 文档页面。 huggingface_hub-0.31.1/docs/source/cn/guides/hf_file_system.md000066400000000000000000000121361500667546600243550ustar00rootroot00000000000000 # 通过文件系统 API 与 Hub 交互 除了 [`HfApi`],`huggingface_hub` 库还提供了 [`HfFileSystem`],这是一个符合 [fsspec](https://filesystem-spec.readthedocs.io/en/latest/) 规范的 Python 文件接口,用于与 Hugging Face Hub 交互。[`HfFileSystem`] 基于 [`HfApi`] 构建,提供了典型的文件系统操作,如 `cp`、`mv`、`ls`、`du`、`glob`、`get_file` 和 `put_file`。 [`HfFileSystem`] 提供了 fsspec 兼容性,这对于需要它的库(例如,直接使用 `pandas` 读取 Hugging Face 数据集)非常有用。然而,由于这种兼容性层,会引入额外的开销。为了更好的性能和可靠性,建议尽可能使用 [`HfApi`] 方法。 ## 使用方法 ```python >>> from huggingface_hub import HfFileSystem >>> fs = HfFileSystem() >>> # 列出目录中的所有文件 >>> fs.ls("datasets/my-username/my-dataset-repo/data", detail=False) ['datasets/my-username/my-dataset-repo/data/train.csv', 'datasets/my-username/my-dataset-repo/data/test.csv'] >>> # 列出仓库中的所有 ".csv" 文件 >>> fs.glob("datasets/my-username/my-dataset-repo/**/*.csv") ['datasets/my-username/my-dataset-repo/data/train.csv', 'datasets/my-username/my-dataset-repo/data/test.csv'] >>> # 读取远程文件 >>> with fs.open("datasets/my-username/my-dataset-repo/data/train.csv", "r") as f: ... train_data = f.readlines() >>> # 远程文件内容读取为字符串 >>> train_data = fs.read_text("datasets/my-username/my-dataset-repo/data/train.csv", revision="dev") >>> # 写入远程文件 >>> with fs.open("datasets/my-username/my-dataset-repo/data/validation.csv", "w") as f: ... f.write("text,label") ... f.write("Fantastic movie!,good") ``` 可以传递可选的 `revision` 参数,以从特定提交(如分支、标签名或提交哈希)运行操作。 与 Python 内置的 `open` 不同,`fsspec` 的 `open` 默认是二进制模式 `"rb"`。这意味着您必须明确设置模式为 `"r"` 以读取文本模式,或 `"w"` 以写入文本模式。目前不支持追加到文件(模式 `"a"` 和 `"ab"`) ## 集成 [`HfFileSystem`] 可以与任何集成了 `fsspec` 的库一起使用,前提是 URL 遵循以下格式: ``` hf://[][@]/ ```
对于数据集,`repo_type_prefix` 为 `datasets/`,对于Space,`repo_type_prefix`为 `spaces/`,模型不需要在 URL 中使用这样的前缀。 以下是一些 [`HfFileSystem`] 简化与 Hub 交互的有趣集成: * 从 Hub 仓库读取/写入 [Pandas](https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#reading-writing-remote-files) DataFrame : ```python >>> import pandas as pd >>> # 将远程 CSV 文件读取到 DataFrame >>> df = pd.read_csv("hf://datasets/my-username/my-dataset-repo/train.csv") >>> # 将 DataFrame 写入远程 CSV 文件 >>> df.to_csv("hf://datasets/my-username/my-dataset-repo/test.csv") ``` 同样的工作流程也适用于 [Dask](https://docs.dask.org/en/stable/how-to/connect-to-remote-data.html) 和 [Polars](https://pola-rs.github.io/polars/py-polars/html/reference/io.html) DataFrames. * 使用 [DuckDB](https://duckdb.org/docs/guides/python/filesystems) 查询(远程)Hub文件: ```python >>> from huggingface_hub import HfFileSystem >>> import duckdb >>> fs = HfFileSystem() >>> duckdb.register_filesystem(fs) >>> # 查询远程文件并将结果返回为 DataFrame >>> fs_query_file = "hf://datasets/my-username/my-dataset-repo/data_dir/data.parquet" >>> df = duckdb.query(f"SELECT * FROM '{fs_query_file}' LIMIT 10").df() ``` * 使用 [Zarr](https://zarr.readthedocs.io/en/stable/tutorial.html#io-with-fsspec) 将 Hub 作为数组存储: ```python >>> import numpy as np >>> import zarr >>> embeddings = np.random.randn(50000, 1000).astype("float32") >>> # 将数组写入仓库 >>> with zarr.open_group("hf://my-username/my-model-repo/array-store", mode="w") as root: ... foo = root.create_group("embeddings") ... foobar = foo.zeros('experiment_0', shape=(50000, 1000), chunks=(10000, 1000), dtype='f4') ... foobar[:] = embeddings >>> # 从仓库读取数组 >>> with zarr.open_group("hf://my-username/my-model-repo/array-store", mode="r") as root: ... first_row = root["embeddings/experiment_0"][0] ``` ## 认证 在许多情况下,您必须登录 Hugging Face 账户才能与 Hub 交互。请参阅文档的[认证](../quick-start#authentication) 部分,了解有关 Hub 上认证方法的更多信息。 也可以通过将您的 token 作为参数传递给 [`HfFileSystem`] 以编程方式登录: ```python >>> from huggingface_hub import HfFileSystem >>> fs = HfFileSystem(token=token) ``` 如果您以这种方式登录,请注意在共享源代码时不要意外泄露令牌! 
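例如(仅作示意,这里假设令牌保存在名为 `HF_TOKEN` 的环境变量中),可以从环境变量读取令牌,而不是把它硬编码在代码里:

```python
>>> import os
>>> from huggingface_hub import HfFileSystem

>>> # 从环境变量读取令牌,避免在源代码中明文保存(示意写法)
>>> fs = HfFileSystem(token=os.environ.get("HF_TOKEN"))
```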
huggingface_hub-0.31.1/docs/source/cn/guides/overview.md000066400000000000000000000136261500667546600232300ustar00rootroot00000000000000 # 操作指南 在本节中,您将找到帮助您实现特定目标的实用指南。 查看这些指南,了解如何使用 huggingface_hub 解决实际问题: huggingface_hub-0.31.1/docs/source/cn/guides/repository.md000066400000000000000000000231551500667546600235770ustar00rootroot00000000000000 # 创建和管理存储库 Hugging Face Hub是一组 Git 存储库。[Git](https://git-scm.com/)是软件开发中广泛使用的工具,可以在协作工作时轻松对项目进行版本控制。本指南将向您展示如何与 Hub 上的存储库进行交互,特别关注以下内容: - 创建和删除存储库 - 管理分支和标签 - 重命名您的存储库 - 更新您的存储库可见性 - 管理存储库的本地副本 如果您习惯于使用类似于GitLab/GitHub/Bitbucket等平台,您可能首先想到使用 `git`命令行工具来克隆存储库(`git clone`)、提交更改(`git add` , ` git commit`)并推送它们(`git push`)。在使用 Hugging Face Hub 时,这是有效的。然而,软件工程和机器学习并不具有相同的要求和工作流程。模型存储库可能会维护大量模型权重文件以适应不同的框架和工具,因此克隆存储库会导致您维护大量占用空间的本地文件夹。因此,使用我们的自定义HTTP方法可能更有效。您可以阅读我们的[git与HTTP相比较](../concepts/git_vs_http)解释页面以获取更多详细信息 如果你想在Hub上创建和管理一个仓库,你的计算机必须处于登录状态。如果尚未登录,请参考[此部分](../quick-start#login)。在本指南的其余部分,我们将假设你的计算机已登录 ## 仓库创建和删除 第一步是了解如何创建和删除仓库。你只能管理你拥有的仓库(在你的用户名命名空间下)或者你具有写入权限的组织中的仓库 ### 创建一个仓库 使用 [`create_repo`] 创建一个空仓库,并通过 `repo_id`参数为其命名 `repo_id`是你的命名空间,后面跟着仓库名称:`username_or_org/repo_name` 运行以下代码,以创建仓库: ```py >>> from huggingface_hub import create_repo >>> create_repo("lysandre/test-model") 'https://huggingface.co/lysandre/test-model' ``` 默认情况下,[`create_repo`] 会创建一个模型仓库。但是你可以使用 `repo_type`参数来指定其他仓库类型。例如,如果你想创建一个数据集仓库 请运行以下代码: ```py >>> from huggingface_hub import create_repo >>> create_repo("lysandre/test-dataset", repo_type="dataset") 'https://huggingface.co/datasets/lysandre/test-dataset' ``` 创建仓库时,你可以使用 `private`参数设置仓库的可见性 请运行以下代码 ```py >>> from huggingface_hub import create_repo >>> create_repo("lysandre/test-private", private=True) ``` 如果你想在以后更改仓库的可见性,你可以使用[`update_repo_settings`] 函数 ### 删除一个仓库 使用 [`delete_repo`] 删除一个仓库。确保你确实想要删除仓库,因为这是一个不可逆转的过程!做完上述过程后,指定你想要删除的仓库的 `repo_id` 请运行以下代码: ```py >>> delete_repo(repo_id="lysandre/my-corrupted-dataset", repo_type="dataset") ``` ### 克隆一个仓库(仅适用于 Spaces) 在某些情况下,你可能想要复制别人的仓库并根据自己的用例进行调整。对于 Spaces,你可以使用 [`duplicate_space`] 方法来实现。它将复制整个仓库。 你仍然需要配置自己的设置(硬件和密钥)。查看我们的[管理你的Space指南](./manage-spaces)以获取更多详细信息。 请运行以下代码: ```py >>> from huggingface_hub import duplicate_space >>> duplicate_space("multimodalart/dreambooth-training", private=False) RepoUrl('https://huggingface.co/spaces/nateraw/dreambooth-training',...) ``` ## 上传和下载文件 既然您已经创建了您的存储库,您现在也可以推送更改至其中并从中下载文件 这两个主题有它们自己的指南。请[上传指南](./upload) 和[下载指南](./download)来学习如何使用您的存储库。 ## 分支和标签 Git存储库通常使用分支来存储同一存储库的不同版本。标签也可以用于标记存储库的特定状态,例如,在发布版本这个情况下。更一般地说,分支和标签被称为[git引用](https://git-scm.com/book/en/v2/Git-Internals-Git-References). 
### 创建分支和标签 你可以使用[`create_branch`]和[`create_tag`]来创建新的分支和标签: 请运行以下代码: ```py >>> from huggingface_hub import create_branch, create_tag # Create a branch on a Space repo from `main` branch >>> create_branch("Matthijs/speecht5-tts-demo", repo_type="space", branch="handle-dog-speaker") # Create a tag on a Dataset repo from `v0.1-release` branch >>> create_branch("bigcode/the-stack", repo_type="dataset", revision="v0.1-release", tag="v0.1.1", tag_message="Bump release version.") ``` 同时,你可以以相同的方式使用 [`delete_branch`] 和 [`delete_tag`] 函数来删除分支或标签 ### 列出所有的分支和标签 你还可以使用 [`list_repo_refs`] 列出存储库中的现有 Git 引用 请运行以下代码: ```py >>> from huggingface_hub import list_repo_refs >>> api.list_repo_refs("bigcode/the-stack", repo_type="dataset") GitRefs( branches=[ GitRefInfo(name='main', ref='refs/heads/main', target_commit='18edc1591d9ce72aa82f56c4431b3c969b210ae3'), GitRefInfo(name='v1.1.a1', ref='refs/heads/v1.1.a1', target_commit='f9826b862d1567f3822d3d25649b0d6d22ace714') ], converts=[], tags=[ GitRefInfo(name='v1.0', ref='refs/tags/v1.0', target_commit='c37a8cd1e382064d8aced5e05543c5f7753834da') ] ) ``` ## 修改存储库设置 存储库具有一些可配置的设置。大多数情况下,您通常会在浏览器中的存储库设置页面上手动配置这些设置。要配置存储库,您必须具有对其的写访问权限(拥有它或属于组织)。在本节中,我们将看到您还可以使用 `huggingface_hub` 在编程方式上配置的设置。 一些设置是特定于 Spaces(硬件、环境变量等)的。要配置这些设置,请参考我们的[管理Spaces](../guides/manage-spaces)指南。 ### 更新可见性 一个存储库可以是公共的或私有的。私有存储库仅对您或存储库所在组织的成员可见。 请运行以下代码将存储库更改为私有: ```py >>> from huggingface_hub import update_repo_settings >>> update_repo_settings(repo_id=repo_id, private=True) ``` ### 重命名您的存储库 您可以使用 [`move_repo`] 在 Hub 上重命名您的存储库。使用这种方法,您还可以将存储库从一个用户移动到一个组织。在这样做时,有一些[限制](https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo)需要注意。例如,您不能将存储库转移到另一个用户。 请运行以下代码: ```py >>> from huggingface_hub import move_repo >>> move_repo(from_id="Wauplin/cool-model", to_id="huggingface/cool-model") ``` ## 管理存储库的本地副本 上述所有操作都可以通过HTTP请求完成。然而,在某些情况下,您可能希望在本地拥有存储库的副本,并使用您熟悉的Git命令与之交互。 [`Repository`] 类允许您使用类似于Git命令的函数与Hub上的文件和存储库进行交互。它是对Git和Git-LFS方法的包装,以使用您已经了解和喜爱的Git命令。在开始之前,请确保已安装Git-LFS(请参阅[此处](https://git-lfs.github.com/)获取安装说明)。 ### 使用本地存储库 使用本地存储库路径实例化一个 [`Repository`] 对象: 请运行以下代码: ```py >>> from huggingface_hub import Repository >>> repo = Repository(local_dir="//") ``` ### 克隆 `clone_from`参数将一个存储库从Hugging Face存储库ID克隆到由 `local_dir`参数指定的本地目录: 请运行以下代码: ```py >>> from huggingface_hub import Repository >>> repo = Repository(local_dir="w2v2", clone_from="facebook/wav2vec2-large-960h-lv60") ``` `clone_from`还可以使用URL克隆存储库: 请运行以下代码: ```py >>> repo = Repository(local_dir="huggingface-hub", clone_from="https://huggingface.co/facebook/wav2vec2-large-960h-lv60") ``` 你可以将`clone_from`参数与[`create_repo`]结合使用,以创建并克隆一个存储库: 请运行以下代码: ```py >>> repo_url = create_repo(repo_id="repo_name") >>> repo = Repository(local_dir="repo_local_path", clone_from=repo_url) ``` 当你克隆一个存储库时,通过在克隆时指定`git_user`和`git_email`参数,你还可以为克隆的存储库配置Git用户名和电子邮件。当用户提交到该存储库时,Git将知道提交的作者是谁。 请运行以下代码: ```py >>> repo = Repository( ... "my-dataset", ... clone_from="/", ... token=True, ... repo_type="dataset", ... git_user="MyName", ... git_email="me@cool.mail" ... 
) ``` ### 分支 分支对于协作和实验而不影响当前文件和代码非常重要。使用[`~Repository.git_checkout`]来在不同的分支之间切换。例如,如果你想从 `branch1`切换到 `branch2`: 请运行以下代码: ```py >>> from huggingface_hub import Repository >>> repo = Repository(local_dir="huggingface-hub", clone_from="/", revision='branch1') >>> repo.git_checkout("branch2") ``` ### 拉取 [`~Repository.git_pull`] 允许你使用远程存储库的更改更新当前本地分支: 请运行以下代码: ```py >>> from huggingface_hub import Repository >>> repo.git_pull() ``` 如果你希望本地的提交发生在你的分支被远程的新提交更新之后,请设置`rebase=True`: ```py >>> repo.git_pull(rebase=True) ``` huggingface_hub-0.31.1/docs/source/cn/guides/search.md000066400000000000000000000037141500667546600226240ustar00rootroot00000000000000 # 搜索 Hub 在本教程中,您将学习如何使用 `huggingface_hub` 在 Hub 上搜索模型、数据集和Spaces。 ## 如何列出仓库? `huggingface_hub`库包括一个 HTTP 客户端 [`HfApi`],用于与 Hub 交互。 除此之外,它还可以列出存储在 Hub 上的模型、数据集和Spaces: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> models = api.list_models() ``` [`list_models`] 返回一个迭代器,包含存储在 Hub 上的模型。 同样,您可以使用 [`list_datasets`] 列出数据集,使用 [`list_spaces`] 列出 Spaces。 ## 如何过滤仓库? 列出仓库是一个好开始,但现在您可能希望对搜索结果进行过滤。 列出时,可以使用多个属性来过滤结果,例如: - `filter` - `author` - `search` - ... 让我们看一个示例,获取所有在 Hub 上进行图像分类的模型,这些模型已在 imagenet 数据集上训练,并使用 PyTorch 运行。 ```py models = hf_api.list_models( task="image-classification", library="pytorch", trained_dataset="imagenet", ) ``` 在过滤时,您还可以对模型进行排序,并仅获取前几个结果。例如,以下示例获取了 Hub 上下载量最多的前 5 个数据集: ```py >>> list(list_datasets(sort="downloads", direction=-1, limit=5)) [DatasetInfo( id='argilla/databricks-dolly-15k-curated-en', author='argilla', sha='4dcd1dedbe148307a833c931b21ca456a1fc4281', last_modified=datetime.datetime(2023, 10, 2, 12, 32, 53, tzinfo=datetime.timezone.utc), private=False, downloads=8889377, (...) ``` 如果您想要在Hub上探索可用的过滤器, 请在浏览器中访问 [models](https://huggingface.co/models) 和 [datasets](https://huggingface.co/datasets) 页面 ,尝试不同的参数并查看URL中的值。 huggingface_hub-0.31.1/docs/source/cn/index.md000066400000000000000000000075111500667546600212050ustar00rootroot00000000000000 # 🤗 Hub 客户端库 通过`huggingface_hub` 库,您可以与面向机器学习开发者和协作者的平台 [Hugging Face Hub](https://huggingface.co/)进行交互,找到适用于您所在项目的预训练模型和数据集,体验在平台托管的数百个机器学习应用,还可以创建或分享自己的模型和数据集并于社区共享。以上所有都可以用Python在`huggingface_hub` 库中轻松实现。 阅读[快速入门指南](快速入门指南)以开始使用huggingface_hub库。您将学习如何从Hub下载文件,创建存储库以及将文件上传到Hub。继续阅读以了解更多关于如何在🤗Hub上管理您的存储库,如何参与讨论或者甚至如何访问推理API的信息。 通过 `huggingface_hub`库,您可以与面向机器学习开发者和协作者的平台 [Hugging Face Hub](https://huggingface.co/)进行交互,找到适用于您所在项目的预训练模型和数据集,体验在平台托管的数百个机器学习应用,还可以创建或分享自己的模型和数据集并于社区共享。以上所有都可以用Python在 `huggingface_hub`库中轻松实现。 ## 贡献 所有对 huggingface_hub 的贡献都受到欢迎和同等重视!🤗 除了在代码中添加或修复现有问题外,您还可以通过确保其准确且最新来帮助改进文档,在问题上帮助回答问题,并请求您认为可以改进库的新功能。请查看[贡献指南](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md) 了解有关如何提交新问题或功能请求、如何提交拉取请求以及如何测试您的贡献以确保一切正常运行的更多信息。 当然,贡献者也应该尊重我们的[行为准则](https://github.com/huggingface/huggingface_hub/blob/main/CODE_OF_CONDUCT.md),以便为每个人创建一个包容和欢迎的协作空间。 huggingface_hub-0.31.1/docs/source/cn/installation.md000066400000000000000000000152211500667546600225740ustar00rootroot00000000000000 # 安装 在开始之前,您需要通过安装适当的软件包来设置您的环境 huggingface_hub 在 Python 3.8 或更高版本上进行了测试,可以保证在这些版本上正常运行。如果您使用的是 Python 3.7 或更低版本,可能会出现兼容性问题 ## 使用 pip 安装 我们建议将huggingface_hub安装在[虚拟环境](https://docs.python.org/3/library/venv.html)中. 如果你不熟悉 Python虚拟环境,可以看看这个[指南](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/). 
虚拟环境可以更容易地管理不同的项目,避免依赖项之间的兼容性问题 首先在你的项目目录中创建一个虚拟环境,请运行以下代码: ```bash python -m venv .env ``` 在Linux和macOS上,请运行以下代码激活虚拟环境: ```bash source .env/bin/activate ``` 在 Windows 上,请运行以下代码激活虚拟环境: ```bash .env/Scripts/activate ``` 现在您可以从[PyPi注册表](https://pypi.org/project/huggingface-hub/)安装 `huggingface_hub`: ```bash pip install --upgrade huggingface_hub ``` 完成后,[检查安装](#check-installation)是否正常工作 ### 安装可选依赖项 `huggingface_hub`的某些依赖项是 [可选](https://setuptools.pypa.io/en/latest/userguide/dependency_management.html#optional-dependencies) 的,因为它们不是运行`huggingface_hub`的核心功能所必需的.但是,如果没有安装可选依赖项, `huggingface_hub` 的某些功能可能会无法使用 您可以通过`pip`安装可选依赖项,请运行以下代码: ```bash # 安装 TensorFlow 特定功能的依赖项 # /!\ 注意:这不等同于 `pip install tensorflow` pip install 'huggingface_hub[tensorflow]' # 安装 TensorFlow 特定功能和 CLI 特定功能的依赖项 pip install 'huggingface_hub[cli,torch]' ``` 这里列出了 `huggingface_hub` 的可选依赖项: - `cli`:为 `huggingface_hub` 提供更方便的命令行界面 - `fastai`,` torch`, `tensorflow`: 运行框架特定功能所需的依赖项 - `dev`:用于为库做贡献的依赖项。包括 `testing`(用于运行测试)、`typing`(用于运行类型检查器)和 `quality`(用于运行 linter) ### 从源代码安装 在某些情况下,直接从源代码安装`huggingface_hub`会更有趣。因为您可以使用最新的主版本`main`而非最新的稳定版本 `main`版本更有利于跟进平台的最新开发进度,例如,在最近一次官方发布之后和最新的官方发布之前所修复的某个错误 但是,这意味着`main`版本可能不总是稳定的。我们会尽力让其正常运行,大多数问题通常会在几小时或一天内解决。如果您遇到问题,请创建一个 Issue ,以便我们可以更快地解决! ```bash pip install git+https://github.com/huggingface/huggingface_hub # 使用pip从GitHub仓库安装Hugging Face Hub库 ``` 从源代码安装时,您还可以指定特定的分支。如果您想测试尚未合并的新功能或新错误修复,这很有用 ```bash pip install git+https://github.com/huggingface/huggingface_hub@my-feature-branch # 使用pip从指定的GitHub分支(my-feature-branch)安装Hugging Face Hub库 ``` 完成安装后,请[检查安装](#check-installation)是否正常工作 ### 可编辑安装 从源代码安装允许您设置[可编辑安装](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs).如果您计划为`huggingface_hub`做出贡献并需要测试代码更改,这是一个更高级的安装方式。您需要在本机上克隆一个`huggingface_hub`的本地副本 ```bash # 第一,使用以下命令克隆代码库 git clone https://github.com/huggingface/huggingface_hub.git # 然后,使用以下命令启动虚拟环境 cd huggingface_hub pip install -e . 
``` 这些命令将你克隆存储库的文件夹与你的 Python 库路径链接起来。Python 现在将除了正常的库路径之外,还会在你克隆到的文件夹中查找。例如,如果你的 Python 包通常安装在`./.venv/lib/python3.13/site-packages/`中,Python 还会搜索你克隆的文件夹`./huggingface_hub/` ## 通过 conda 安装 如果你更熟悉它,你可以使用[conda-forge channel](https://anaconda.org/conda-forge/huggingface_hub)渠道来安装 `huggingface_hub` 请运行以下代码: ```bash conda install -c conda-forge huggingface_hub ``` 完成安装后,请[检查安装](#check-installation)是否正常工作 ## 验证安装 安装完成后,通过运行以下命令检查`huggingface_hub`是否正常工作: ```bash python -c "from huggingface_hub import model_info; print(model_info('gpt2'))" ``` 这个命令将从 Hub 获取有关 [gpt2](https://huggingface.co/gpt2) 模型的信息。 输出应如下所示: ```text Model Name: gpt2 模型名称 Tags: ['pytorch', 'tf', 'jax', 'tflite', 'rust', 'safetensors', 'gpt2', 'text-generation', 'en', 'doi:10.57967/hf/0039', 'transformers', 'exbert', 'license:mit', 'has_space'] 标签 Task: text-generation 任务:文本生成 ``` ## Windows局限性 为了实现让每个人都能使用机器学习的目标,我们构建了 `huggingface_hub`库,使其成为一个跨平台的库,尤其可以在 Unix 和 Windows 系统上正常工作。但是,在某些情况下,`huggingface_hub`在Windows上运行时会有一些限制。以下是一些已知问题的完整列表。如果您遇到任何未记录的问题,请打开 [Github上的issue](https://github.com/huggingface/huggingface_hub/issues/new/choose).让我们知道 - `huggingface_hub`的缓存系统依赖于符号链接来高效地缓存从Hub下载的文件。在Windows上,您必须激活开发者模式或以管理员身份运行您的脚本才能启用符号链接。如果它们没有被激活,缓存系统仍然可以工作,但效率较低。有关更多详细信息,请阅读[缓存限制](./guides/manage-cache#limitations)部分。 - Hub上的文件路径可能包含特殊字符(例如:`path/to?/my/file`)。Windows对[特殊字符](https://learn.microsoft.com/en-us/windows/win32/intl/character-sets-used-in-file-names)更加严格,这使得在Windows上下载这些文件变得不可能。希望这是罕见的情况。如果您认为这是一个错误,请联系存储库所有者或我们,以找出解决方案。 ## 后记 一旦您在机器上正确安装了`huggingface_hub`,您可能需要[配置环境变量](package_reference/environment_variables)或者[查看我们的指南之一](guides/overview)以开始使用。 huggingface_hub-0.31.1/docs/source/cn/quick-start.md000066400000000000000000000131161500667546600223430ustar00rootroot00000000000000 # 快速入门 [Hugging Face Hub](https://huggingface.co/)是分享机器学习模型、演示、数据集和指标的首选平台`huggingface_hub`库帮助你在不离开开发环境的情况下与 Hub 进行交互。你可以轻松地创建和管理仓库,下载和上传文件,并从 Hub 获取有用的模型和数据集元数据 ## 安装 要开始使用,请安装`huggingface_hub`库: ```bash pip install --upgrade huggingface_hub ``` 更多详细信息,请查看[安装指南](installation) ## 下载文件 Hugging Face 平台上的存储库是使用 git 版本控制的,用户可以下载单个文件或整个存储库。您可以使用 [`hf_hub_download`] 函数下载文件。该函数将下载并将文件缓存在您的本地磁盘上。下次您需要该文件时,它将从您的缓存中加载,因此您无需重新下载它 您将需要填写存储库 ID 和您要下载的文件的文件名。例如,要下载[Pegasus](https://huggingface.co/google/pegasus-xsum)模型配置文件,请运行以下代码: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download(repo_id="google/pegasus-xsum", filename="config.json") repo_id: 仓库的 ID 或路径,这里使用了 "google/pegasus-xsum" filename: 要下载的文件名,这里是 "config.json" ``` 要下载文件的特定版本,请使用`revision`参数指定分支名称、标签或提交哈希。如果您选择使用提交哈希,它必须是完整长度的哈希,而不是较短的7个字符的提交哈希: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download( ... repo_id="google/pegasus-xsum", ... filename="config.json", ... revision="4d33b01d79672f27f001f6abade33f22d993b151" ... 
) ``` 有关更多详细信息和选项,请参阅 [`hf_hub_download`] 的 API 参考文档 ## 登录 在许多情况下,您必须使用 Hugging Face 帐户进行登录后才能与 Hugging Face 模型库进行交互,例如下载私有存储库、上传文件、创建 PR 等。如果您还没有帐户,请[创建一个](https://huggingface.co/join),然后登录以获取您的 [用户访问令牌](https://huggingface.co/docs/hub/security-tokens),security-tokens从您的[设置页面](https://huggingface.co/settings/tokens)进入设置,用户访问令牌用于向模型库进行身份验证 运行以下代码,这将使用您的用户访问令牌登录到Hugging Face模型库 ```bash huggingface-cli login huggingface-cli login --token $HUGGINGFACE_TOKEN ``` 或者,你可以在笔记本电脑或脚本中使用 [`login`] 来进行程序化登录,请运行以下代码: ```py >>> from huggingface_hub import login >>> login() ``` 您还可以直接将令牌传递给 [`login`],如下所示:`login(token="hf_xxx")`。这将使用您的用户访问令牌登录到 Hugging Face 模型库,而无需您输入任何内容。但是,如果您这样做,请在共享源代码时要小心。最好从安全保管库中加载令牌,而不是在代码库/笔记本中显式保存它 您一次只能登录一个帐户。如果您使用另一个帐户登录您的机器,您将会从之前的帐户注销。请确保使用命令 `huggingface-cli whoami`来检查您当前使用的是哪个帐户。如果您想在同一个脚本中处理多个帐户,您可以在调用每个方法时提供您的令牌。这对于您不想在您的机器上存储任何令牌也很有用 一旦您登录了,所有对模型库的请求(即使是不需要认证的方法)都将默认使用您的访问令牌。如果您想禁用对令牌的隐式使用,您应该设置`HF_HUB_DISABLE_IMPLICIT_TOKEN`环境变量 ## 创建存储库 一旦您注册并登录,请使用 [`create_repo`] 函数创建存储库: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.create_repo(repo_id="super-cool-model") ``` 如果您想将存储库设置为私有,请按照以下步骤操作: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.create_repo(repo_id="super-cool-model", private=True) ``` 私有存储库将不会对任何人可见,除了您自己 创建存储库或将内容推送到 Hub 时,必须提供具有`写入`权限的用户访问令牌。您可以在创建令牌时在您的[设置页面](https://huggingface.co/settings/tokens)中选择权限 ## 上传文件 您可以使用 [`upload_file`] 函数将文件添加到您新创建的存储库。您需要指定: 1. 要上传的文件的路径 2. 文件在存储库中的位置 3. 您要将文件添加到的存储库的 ID ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.upload_file( ... path_or_fileobj="/home/lysandre/dummy-test/README.md" ... path_in_repo="README.md" ... repo_id="lysandre/test-model" ... ) ``` 要一次上传多个文件,请查看[上传指南](./guides/upload) ,该指南将向您介绍几种上传文件的方法(有或没有 git)。 ## 下一步 `huggingface_hub`库为用户提供了一种使用Python与Hub 进行交互的简单方法。要了解有关如何在Hub上管理文件和存储库的更多信息,我们建议您阅读我们的[操作方法指南](./guides/overview): - [管理您的存储库](./guides/repository) - [从Hub下载文件](./guides/download) - [将文件上传到Hub](./guides/upload) - [在Hub中搜索您的所需模型或数据集](./guides/search) - [了解如何使用 Inference API 进行快速推理](./guides/inference) huggingface_hub-0.31.1/docs/source/de/000077500000000000000000000000001500667546600175405ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/de/_toctree.yml000066400000000000000000000021041500667546600220640ustar00rootroot00000000000000- title: "Starten" sections: - local: index title: Home - local: quick-start title: Kurzanleitung - local: installation title: Installation - title: "Anleitungen" sections: - local: guides/overview title: Übersicht - local: guides/download title: Dateien herunterladen - local: guides/upload title: Dateien hochladen - local: guides/hf_file_system title: HfFileSystem - local: guides/repository title: Repository - local: guides/search title: Suche - local: guides/inference title: Inferenz - local: guides/community title: Community-Tab - local: guides/manage-cache title: Cache - local: guides/model-cards title: Model Cards - local: guides/manage-spaces title: Verwalten Ihres Spaces - local: guides/integrations title: Integrieren einer Bibliothek - local: guides/webhooks_server title: Webhooks server - title: "Konzeptionelle Anleitungen" sections: - local: concepts/git_vs_http title: Git vs. HTTP-Paradigma huggingface_hub-0.31.1/docs/source/de/concepts/000077500000000000000000000000001500667546600213565ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/de/concepts/git_vs_http.md000066400000000000000000000106011500667546600242300ustar00rootroot00000000000000 # Git vs. 
HTTP-Paradigma Die `huggingface_hub`-Bibliothek ist eine Bibliothek zur Interaktion mit dem Hugging Face Hub, einer Sammlung von auf Git basierenden Repositories (Modelle, Datensätze oder Spaces). Es gibt zwei Hauptmethoden, um auf den Hub mit `huggingface_hub` zuzugreifen. Der erste Ansatz, der sogenannte "Git-basierte" Ansatz, wird von der [`Repository`] Klasse geleitet. Diese Methode verwendet einen Wrapper um den `git`-Befehl mit zusätzlichen Funktionen, die speziell für die Interaktion mit dem Hub entwickelt wurden. Die zweite Option, die als "HTTP-basierter" Ansatz bezeichnet wird, umfasst das Senden von HTTP-Anfragen mit dem [`HfApi`] Client. Schauen wir uns die Vor- und Nachteile jeder Methode an. ## Repository: Der historische git-basierte Ansatz Ursprünglich wurde `huggingface_hub` größtenteils um die [`Repository`] Klasse herum entwickelt. Sie bietet Python-Wrapper für gängige git-Befehle wie `"git add"`, `"git commit"`, `"git push"`, `"git tag"`, `"git checkout"` usw. Die Bibliothek hilft auch beim Festlegen von Zugangsdaten und beim Tracking von großen Dateien, die in Machine-Learning-Repositories häufig verwendet werden. Darüber hinaus ermöglicht die Bibliothek das Ausführen ihrer Methoden im Hintergrund, was nützlich ist, um Daten während des Trainings hochzuladen. Der Hauptvorteil bei der Verwendung einer [`Repository`] besteht darin, dass Sie eine lokale Kopie des gesamten Repositorys auf Ihrem Computer pflegen können. Dies kann jedoch auch ein Nachteil sein, da es erfordert, diese lokale Kopie ständig zu aktualisieren und zu pflegen. Dies ähnelt der traditionellen Softwareentwicklung, bei der jeder Entwickler eine eigene lokale Kopie pflegt und Änderungen überträgt, wenn an einer Funktion gearbeitet wird. Im Kontext des Machine Learning ist dies jedoch nicht immer erforderlich, da Benutzer möglicherweise nur Gewichte für die Inferenz herunterladen oder Gewichte von einem Format in ein anderes konvertieren müssen, ohne das gesamte Repository zu klonen. ## HfApi: Ein flexibler und praktischer HTTP-Client Die [`HfApi`] Klasse wurde entwickelt, um eine Alternative zu lokalen Git-Repositories bereitzustellen, die besonders bei der Arbeit mit großen Modellen oder Datensätzen umständlich zu pflegen sein können. Die [`HfApi`] Klasse bietet die gleiche Funktionalität wie git-basierte Ansätze, wie das Herunterladen und Hochladen von Dateien sowie das Erstellen von Branches und Tags, jedoch ohne die Notwendigkeit eines lokalen Ordners, der synchronisiert werden muss. Zusätzlich zu den bereits von `git` bereitgestellten Funktionen bietet die [`HfApi`] Klasse zusätzliche Features wie die Möglichkeit, Repositories zu verwalten, Dateien mit Caching für effiziente Wiederverwendung herunterzuladen, im Hub nach Repositories und Metadaten zu suchen, auf Community-Funktionen wie Diskussionen, Pull Requests und Kommentare zuzugreifen und Spaces-Hardware und Geheimnisse zu konfigurieren. ## Was sollte ich verwenden ? Und wann ? Insgesamt ist der **HTTP-basierte Ansatz in den meisten Fällen die empfohlene Methode zur Verwendung von** `huggingface_hub`. Es gibt jedoch einige Situationen, in denen es vorteilhaft sein kann, eine lokale Git-Kopie (mit [`Repository`]) zu pflegen: - Wenn Sie ein Modell auf Ihrem Computer trainieren, kann es effizienter sein, einen herkömmlichen git-basierten Workflow zu verwenden und regelmäßige Updates zu pushen. [`Repository`] ist für diese Art von Situation mit seiner Fähigkeit zur Hintergrundarbeit optimiert. 
- Wenn Sie große Dateien manuell bearbeiten müssen, ist `git` die beste Option, da es nur die Differenz an den Server sendet. Mit dem [`HfAPI`] Client wird die gesamte Datei bei jeder Bearbeitung hochgeladen. Beachten Sie jedoch, dass die meisten großen Dateien binär sind und daher sowieso nicht von Git-Diffs profitieren. Nicht alle Git-Befehle sind über [`HfApi`] verfügbar. Einige werden vielleicht nie implementiert, aber wir bemühen uns ständig, die Lücken zu schließen und zu verbessern. Wenn Sie Ihren Anwendungsfall nicht abgedeckt sehen, öffnen Sie bitte [ein Issue auf Github](https://github.com/huggingface/huggingface_hub)! Wir freuen uns über Feedback, um das 🤗-Ökosystem mit und für unsere Benutzer aufzubauen. huggingface_hub-0.31.1/docs/source/de/guides/000077500000000000000000000000001500667546600210205ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/de/guides/community.md000066400000000000000000000135041500667546600233710ustar00rootroot00000000000000 # Interaktion mit Diskussionen und Pull-Requests Die `huggingface_hub`-Bibliothek bietet eine Python-Schnittstelle, um mit Pull-Requests und Diskussionen auf dem Hub zu interagieren. Besuchen Sie [die spezielle Dokumentationsseite](https://huggingface.co/docs/hub/repositories-pull-requests-discussions), um einen tieferen Einblick in Diskussionen und Pull-Requests auf dem Hub zu erhalten und zu erfahren, wie sie im Hintergrund funktionieren. ## Diskussionen und Pull-Requests vom Hub abrufen Die Klasse `HfApi` ermöglicht es Ihnen, Diskussionen und Pull-Requests zu einem gegebenen Repository abzurufen: ```python >>> from huggingface_hub import get_repo_discussions >>> for discussion in get_repo_discussions(repo_id="bigscience/bloom-1b3"): ... print(f"{discussion.num} - {discussion.title}, pr: {discussion.is_pull_request}") # 11 - Add Flax weights, pr: True # 10 - Update README.md, pr: True # 9 - Training languages in the model card, pr: True # 8 - Update tokenizer_config.json, pr: True # 7 - Slurm training script, pr: False [...] ``` `HfApi.get_repo_discussions` gibt einen [Generator](https://docs.python.org/3.7/howto/functional.html#generators) zurück, der [`Diskussion`]-Objekte liefert. Um alle Diskussionen in einer einzelnen Liste zu erhalten, führen Sie den folgenden Befehl aus: ```python >>> from huggingface_hub import get_repo_discussions >>> discussions_list = list(get_repo_discussions(repo_id="bert-base-uncased")) ``` Das von [`HfApi.get_repo_discussions`] zurückgegebene [`Diskussion`]-Objekt enthält einen Überblick über die Diskussion oder Pull-Requests. Sie können auch detailliertere Informationen mit [`HfApi.get_discussion_details`] abrufen: ```python >>> from huggingface_hub import get_discussion_details >>> get_discussion_details( ... repo_id="bigscience/bloom-1b3", ... discussion_num=2 ... 
) DiscussionWithDetails( num=2, author='cakiki', title='Update VRAM memory for the V100s', status='open', is_pull_request=True, events=[ DiscussionComment(type='comment', author='cakiki', ...), DiscussionCommit(type='commit', author='cakiki', summary='Update VRAM memory for the V100s', oid='1256f9d9a33fa8887e1c1bf0e09b4713da96773a', ...), ], conflicting_files=[], target_branch='refs/heads/main', merge_commit_oid=None, diff='diff --git a/README.md b/README.md\nindex a6ae3b9294edf8d0eda0d67c7780a10241242a7e..3a1814f212bc3f0d3cc8f74bdbd316de4ae7b9e3 100644\n--- a/README.md\n+++ b/README.md\n@@ -132,7 +132,7 [...]', ) ``` [`HfApi.get_discussion_details`] gibt ein [`DiskussionMitDetails`]-Objekt zurück, das eine Unterklasse von [`Diskussion`] mit detaillierteren Informationen über die Diskussion oder Pull-Requests ist. Informationen beinhalten alle Kommentare, Statusänderungen und Umbenennungen der Diskussion mittels [`DiskussionMitDetails.events`]. Im Fall eines Pull-Requests können Sie mit [`DiskussionMitDetails.diff`] den rohen git diff abrufen. Alle Commits des Pull-Requests sind in [`DiskussionMitDetails.events`] aufgelistet. ## Diskussion oder Pull-Request programmatisch erstellen und bearbeiten Die [`HfApi`]-Klasse bietet auch Möglichkeiten, Diskussionen und Pull-Requests zu erstellen und zu bearbeiten. Sie benötigen ein [Access Token](https://huggingface.co/docs/hub/security-tokens), um Diskussionen oder Pull-Requests zu erstellen und zu bearbeiten. Die einfachste Möglichkeit, Änderungen an einem Repo auf dem Hub vorzuschlagen, ist über die [`create_commit`]-API: Setzen Sie einfach das `create_pr`-Parameter auf `True`. Dieser Parameter ist auch bei anderen Methoden verfügbar, die [`create_commit`] umfassen: * [`upload_file`] * [`upload_folder`] * [`delete_file`] * [`delete_folder`] * [`metadata_update`] ```python >>> from huggingface_hub import metadata_update >>> metadata_update( ... repo_id="username/repo_name", ... metadata={"tags": ["computer-vision", "awesome-model"]}, ... create_pr=True, ... ) ``` Sie können auch [`HfApi.create_discussion`] (bzw. [`HfApi.create_pull_request`]) verwenden, um eine Diskussion (bzw. einen Pull-Request) für ein Repository zu erstellen. Das Öffnen eines Pull-Requests auf diese Weise kann nützlich sein, wenn Sie lokal an Änderungen arbeiten müssen. Auf diese Weise geöffnete Pull-Requests befinden sich im `"Entwurfs"`-Modus. ```python >>> from huggingface_hub import create_discussion, create_pull_request >>> create_discussion( ... repo_id="username/repo-name", ... title="Hi from the huggingface_hub library!", ... token="", ... ) DiscussionWithDetails(...) >>> create_pull_request( ... repo_id="username/repo-name", ... title="Hi from the huggingface_hub library!", ... token="", ... ) DiscussionWithDetails(..., is_pull_request=True) ``` Das Verwalten von Pull-Requests und Diskussionen kann vollständig mit der [`HfApi`]-Klasse durchgeführt werden. Zum Beispiel: * [`comment_discussion`] zum Hinzufügen von Kommentaren * [`edit_discussion_comment`] zum Bearbeiten von Kommentaren * [`rename_discussion`] zum Umbenennen einer Diskussion oder eines Pull-Requests * [`change_discussion_status`] zum Öffnen oder Schließen einer Diskussion / eines Pull-Requests * [`merge_pull_request`] zum Zusammenführen eines Pull-Requests Besuchen Sie die [`HfApi`]-Dokumentationsseite für eine vollständige Übersicht aller verfügbaren Methoden. 
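Ein kurzes Beispiel (Repo-Name und Diskussionsnummer sind hier nur Platzhalter), das zeigt, wie sich einige dieser Methoden kombinieren lassen:

```python
>>> from huggingface_hub import comment_discussion, change_discussion_status

>>> # Einen Kommentar zu einer bestehenden Diskussion hinzufügen
>>> comment_discussion(
...     repo_id="username/repo-name",
...     discussion_num=2,
...     comment="Danke für den Beitrag!",
... )

>>> # Anschließend die Diskussion schließen
>>> change_discussion_status(
...     repo_id="username/repo-name",
...     discussion_num=2,
...     new_status="closed",
... )
```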
## Änderungen an einen Pull-Request senden *Demnächst verfügbar !* ## Siehe auch Für eine detailliertere Referenz besuchen Sie die [Diskussionen und Pull-Requests](../package_reference/community) und die [hf_api](../package_reference/hf_api)-Dokumentationen. huggingface_hub-0.31.1/docs/source/de/guides/download.md000066400000000000000000000327411500667546600231600ustar00rootroot00000000000000 # Dateien aus dem Hub herunterladen Die `huggingface_hub`-Bibliothek bietet Funktionen zum Herunterladen von Dateien aus den auf dem Hub gespeicherten Repositories. Sie können diese Funktionen unabhängig verwenden oder in Ihre eigene Bibliothek integrieren, um es Ihren Benutzern zu erleichtern, mit dem Hub zu interagieren. In diesem Leitfaden erfahren Sie, wie Sie: * Einzelne Dateien herunterladen und zwischenspeichern. * Ein gesamtes Repository herunterladen und zwischenspeichern. * Dateien in einen lokalen Ordner herunterladen. ## Einzelne Dateien herunterladen Die [`hf_hub_download`]-Funktion ist die Hauptfunktion zum Herunterladen von Dateien aus dem Hub. Sie lädt die Remote-Datei herunter, speichert sie auf der Festplatte (auf eine versionsbewusste Art und Weise) und gibt ihren lokalen Dateipfad zurück. Der zurückgegebene Dateipfad verweist auf den lokalen Cache von HF. Es ist daher wichtig, die Datei nicht zu ändern, um einen beschädigten Cache zu vermeiden. Wenn Sie mehr darüber erfahren möchten, wie Dateien zwischengespeichert werden, lesen Sie bitte unseren [Caching-Leitfaden](./manage-cache). ### Von der neuesten Version Wählen Sie die Datei zum Herunterladen anhand der Parameter `repo_id`, `repo_type` und `filename` aus. Standardmäßig wird davon ausgegangen, dass die Datei Teil einer `model`-Repository ist. ```python >>> from huggingface_hub import hf_hub_download >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json") '/root/.cache/huggingface/hub/models--lysandre--arxiv-nlp/snapshots/894a9adde21d9a3e3843e6d5aeaaf01875c7fade/config.json' # Herunterladen von einem Dataset >>> hf_hub_download(repo_id="google/fleurs", filename="fleurs.py", repo_type="dataset") '/root/.cache/huggingface/hub/datasets--google--fleurs/snapshots/199e4ae37915137c555b1765c01477c216287d34/fleurs.py' ``` ### Von einer spezifischen Version Standardmäßig wird die neueste Version vom Hauptzweig `main` heruntergeladen. In einigen Fällen möchten Sie jedoch eine Datei in einer bestimmten Version herunterladen (z. B. aus einem bestimmten Zweig, einem PR, einem Tag oder einem Commit-Hash). Verwenden Sie dazu den Parameter `revision`: ```python # Herunterladen vom Tag `v1.0` >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="v1.0") # Herunterladen vom Zweig `test-branch` >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="test-branch") # Herunterladen von Pull Request #3 >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="refs/pr/3") # Herunterladen von einem spezifischen Commit-Hash >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="877b84a8f93f2d619faa2a6e514a32beef88ab0a") ``` **Hinweis:** Bei Verwendung des Commit-Hashs muss der vollständige Hash anstelle eines 7-Zeichen-Commit-Hashs verwendet werden. ### URL zum Herunterladen erstellen Falls Sie die URL erstellen möchten, die zum Herunterladen einer Datei aus einem Repo verwendet wird, können Sie [`hf_hub_url`] verwenden, das eine URL zurückgibt. Beachten Sie, dass es intern von [`hf_hub_download`] verwendet wird. 
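Ein minimales Beispiel (die zurückgegebene URL dient hier nur der Veranschaulichung):

```python
>>> from huggingface_hub import hf_hub_url
>>> hf_hub_url(repo_id="lysandre/arxiv-nlp", filename="config.json")
'https://huggingface.co/lysandre/arxiv-nlp/resolve/main/config.json'
```

Mit dem Parameter `revision` lässt sich analog zu [`hf_hub_download`] auch eine URL für eine bestimmte Version erzeugen.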
## Gesamte Repository herunterladen [`snapshot_download`] lädt ein gesamtes Repository zu einer bestimmten Revision herunter. Es verwendet intern [`hf_hub_download`], was bedeutet, dass alle heruntergeladenen Dateien auch auf Ihrer lokalen Festplatte zwischengespeichert werden. Die Downloads werden gleichzeitig durchgeführt, um den Prozess zu beschleunigen. Um ein ganzes Repository herunterzuladen, geben Sie einfach die `repo_id` und `repo_type` an: ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="lysandre/arxiv-nlp") '/home/lysandre/.cache/huggingface/hub/models--lysandre--arxiv-nlp/snapshots/894a9adde21d9a3e3843e6d5aeaaf01875c7fade' # Oder von einem Dataset >>> snapshot_download(repo_id="google/fleurs", repo_type="dataset") '/home/lysandre/.cache/huggingface/hub/datasets--google--fleurs/snapshots/199e4ae37915137c555b1765c01477c216287d34' ``` [`snapshot_download`] lädt standardmäßig die neueste Revision herunter. Wenn Sie eine spezifische Repository-Revision wünschen, verwenden Sie den Parameter `revision`: ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="lysandre/arxiv-nlp", revision="refs/pr/1") ``` ### Dateien filtern zum Herunterladen [`snapshot_download`] bietet eine einfache Möglichkeit, ein Repository herunterzuladen. Sie möchten jedoch nicht immer den gesamten Inhalt eines Repositories herunterladen. Beispielsweise möchten Sie möglicherweise verhindern, dass alle `.bin`-Dateien heruntergeladen werden, wenn Sie wissen, dass Sie nur die `.safetensors`-Gewichtungen verwenden werden. Dies können Sie mit den Parametern `allow_patterns` und `ignore_patterns` tun. Diese Parameter akzeptieren entweder ein einzelnes Muster oder eine Liste von Mustern. Muster sind Standard-Wildcards (globbing patterns) wie [hier](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm) dokumentiert. Die Mustervergleichung basiert auf [`fnmatch`](https://docs.python.org/3/library/fnmatch.html). Beispielsweise können Sie `allow_patterns` verwenden, um nur JSON-Konfigurationsdateien herunterzuladen: ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="lysandre/arxiv-nlp", allow_patterns="*.json") ``` Andererseits können Sie mit `ignore_patterns` bestimmte Dateien vom Herunterladen ausschließen. Im folgenden Beispiel werden die Dateierweiterungen `.msgpack` und `.h5` ignoriert: ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="lysandre/arxiv-nlp", ignore_patterns=["*.msgpack", "*.h5"]) ``` Schließlich können Sie beide kombinieren, um Ihren Download genau zu filtern. Hier ist ein Beispiel, wie man alle json- und markdown-Dateien herunterlädt, außer `vocab.json`. ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="gpt2", allow_patterns=["*.md", "*.json"], ignore_patterns="vocab.json") ``` ## Datei(en) in lokalen Ordner herunterladen Die empfohlene (und standardmäßige) Methode zum Herunterladen von Dateien aus dem Hub besteht darin, das [Cache-System](./manage-cache) zu verwenden. Sie können Ihren Cache-Ort festlegen, indem Sie den `cache_dir`-Parameter setzen (sowohl in [`hf_hub_download`] als auch in [`snapshot_download`]). In einigen Fällen möchten Sie jedoch Dateien herunterladen und in einen bestimmten Ordner verschieben. Dies ist nützlich, um einen Workflow zu erhalten, der den `git`-Befehlen ähnelt. 
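Vorab ein minimales Beispiel (der Zielordner ist frei gewählt); die beteiligten Parameter werden direkt im Anschluss genauer erläutert:

```python
>>> from huggingface_hub import hf_hub_download
>>> # Lädt die Datei in einen lokalen Ordner statt (nur) in den Cache
>>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", local_dir="./arxiv-nlp")
```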
Sie können dies mit den Parametern `local_dir` und `local_dir_use_symlinks` tun:

- `local_dir` muss ein Pfad zu einem Ordner auf Ihrem System sein. Die heruntergeladenen Dateien behalten dieselbe Dateistruktur wie im Repository. Wenn zum Beispiel `filename="data/train.csv"` und `local_dir="pfad/zum/ordner"` ist, wird der zurückgegebene Dateipfad `"pfad/zum/ordner/data/train.csv"` sein.
- `local_dir_use_symlinks` definiert, wie die Datei in Ihrem lokalen Ordner gespeichert werden soll.
  - Das Standardverhalten (`"auto"`) besteht darin, kleine Dateien (<5MB) zu duplizieren und für größere Dateien Symlinks zu verwenden. Symlinks ermöglichen die Optimierung von Bandbreite und Speicherplatz. Das manuelle Bearbeiten einer verlinkten Datei könnte jedoch den Cache beschädigen, daher die Duplizierung für kleine Dateien. Die 5-MB-Schwelle kann mit der Umgebungsvariable `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD` konfiguriert werden.
  - Wenn `local_dir_use_symlinks=True` gesetzt ist, werden alle Dateien verlinkt, um den Speicherplatz optimal zu nutzen. Dies ist zum Beispiel nützlich, wenn ein riesiges Dataset mit Tausenden von kleinen Dateien heruntergeladen wird.
  - Wenn Sie überhaupt keine Symlinks möchten, können Sie sie deaktivieren (`local_dir_use_symlinks=False`). Das Cache-Verzeichnis wird weiterhin verwendet, um zu überprüfen, ob die Datei bereits im Cache ist oder nicht. Wenn sie bereits im Cache ist, wird die Datei aus dem Cache **dupliziert** (d.h. Bandbreite wird gespart, aber der Speicherplatzverbrauch steigt). Wenn die Datei noch nicht im Cache ist, wird sie heruntergeladen und direkt in das lokale Verzeichnis verschoben. Das bedeutet: Wenn Sie sie später woanders wiederverwenden müssen, wird sie **erneut heruntergeladen**.

Hier ist eine Tabelle, die die verschiedenen Optionen zusammenfasst, um Ihnen zu helfen, die Parameter zu wählen, die am besten zu Ihrem Anwendungsfall passen.

| Parameter | Datei schon im Cache | Zurückgegebener Pfad | Pfad lesbar? | Kann im Pfad speichern? | Optimierter Datendurchsatz | Optimierter Speicherplatz |
|---|:---:|:---:|:---:|:---:|:---:|:---:|
| `local_dir=None` |  | Symlink im Cache | ✅ | ❌<br>_(Speichern würde den Cache beschädigen)_ | ✅ | ✅ |
| `local_dir="path/to/folder"`<br>`local_dir_use_symlinks="auto"` |  | Datei oder Symlink im Ordner | ✅ | ✅ _(für kleine Dateien)_<br>⚠️ _(für große Dateien den Pfad nicht auflösen vor dem Speichern)_ | ✅ | ✅ |
| `local_dir="path/to/folder"`<br>`local_dir_use_symlinks=True` |  | Symlink im Ordner | ✅ | ⚠️<br>_(den Pfad nicht auflösen vor dem Speichern)_ | ✅ | ✅ |
| `local_dir="path/to/folder"`<br>`local_dir_use_symlinks=False` | Nein | Datei im Ordner | ✅ | ✅ | ❌<br>_(bei erneutem Ausführen wird die Datei erneut heruntergeladen)_ | ⚠️<br>_(mehrere Kopien, wenn in mehreren Ordnern ausgeführt)_ |
| `local_dir="path/to/folder"`<br>`local_dir_use_symlinks=False` | Ja | Datei im Ordner | ✅ | ✅ | ⚠️<br>_(Datei muss zuerst im Cache gespeichert werden)_ | ❌<br>_(Datei wird dupliziert)_ |

**Hinweis**: Wenn Sie einen Windows-Computer verwenden, müssen Sie den Entwicklermodus aktivieren oder `huggingface_hub` als Administrator ausführen, um Symlinks zu aktivieren. Weitere Details finden Sie im Abschnitt über [Cache-Beschränkungen](../guides/manage-cache#limitations).

## Herunterladen mit dem CLI

Sie können den Befehl `huggingface-cli download` im Terminal verwenden, um Dateien direkt aus dem Hub herunterzuladen. Intern verwendet er dieselben [`hf_hub_download`]- und [`snapshot_download`]-Helfer, die oben beschrieben wurden, und gibt den zurückgegebenen Pfad im Terminal aus:

```bash
>>> huggingface-cli download gpt2 config.json
/home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json
```

Standardmäßig wird das lokal gespeicherte Token (mit `huggingface-cli login`) verwendet. Wenn Sie sich ausdrücklich authentifizieren möchten, verwenden Sie die Option `--token`:

```bash
>>> huggingface-cli download gpt2 config.json --token=hf_****
/home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json
```

Sie können mehrere Dateien gleichzeitig herunterladen, wobei eine Fortschrittsleiste angezeigt und der Snapshot-Pfad zurückgegeben wird, in dem sich die Dateien befinden:

```bash
>>> huggingface-cli download gpt2 config.json model.safetensors
Fetching 2 files: 100%|████████████████████████████████████████████| 2/2 [00:00<00:00, 23831.27it/s]
/home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10
```

Wenn Sie die Fortschrittsleisten und mögliche Warnungen stummschalten möchten, verwenden Sie die Option `--quiet`. Dies kann nützlich sein, wenn Sie die Ausgabe an einen anderen Befehl in einem Skript weitergeben möchten.

```bash
>>> huggingface-cli download gpt2 --quiet config.json model.safetensors
/home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10
```

Standardmäßig werden Dateien in das Cache-Verzeichnis heruntergeladen, das durch die Umgebungsvariable `HF_HOME` definiert ist (oder `~/.cache/huggingface/hub`, wenn nicht angegeben). Sie können dies mit der Option `--cache-dir` überschreiben:

```bash
>>> huggingface-cli download gpt2 config.json --cache-dir=./cache
./cache/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json
```

Wenn Sie Dateien in einen lokalen Ordner herunterladen möchten, ohne die Cache-Verzeichnisstruktur, können Sie `--local-dir` verwenden. Das Herunterladen in einen lokalen Ordner hat seine Einschränkungen, die in dieser [Tabelle](https://huggingface.co/docs/huggingface_hub/guides/download#download-files-to-local-folder) aufgeführt sind.
```bash
>>> huggingface-cli download gpt2 config.json --local-dir=./models/gpt2
./models/gpt2/config.json
```

Es gibt weitere Argumente, die Sie angeben können, um aus verschiedenen Repo-Typen oder Revisionen herunterzuladen und Dateien zum Herunterladen mit Glob-Mustern ein- oder auszuschließen:

```bash
>>> huggingface-cli download bigcode/the-stack --repo-type=dataset --revision=v1.2 --include="data/python/*" --exclude="*.json" --exclude="*.zip"
Fetching 206 files: 100%|████████████████████████████████████████████| 206/206 [02:31<2:31, ?it/s]
/home/wauplin/.cache/huggingface/hub/datasets--bigcode--the-stack/snapshots/9ca8fa6acdbc8ce920a0cb58adcdafc495818ae7
```

Für eine vollständige Liste der Argumente führen Sie bitte den folgenden Befehl aus:

```bash
huggingface-cli download --help
```

huggingface_hub-0.31.1/docs/source/de/guides/hf_file_system.md000066400000000000000000000121331500667546600243420ustar00rootroot00000000000000

# Interagieren mit dem Hub über die Filesystem API

Zusätzlich zur [`HfApi`] bietet die `huggingface_hub` Bibliothek [`HfFileSystem`], eine pythonische, [fsspec-kompatible](https://filesystem-spec.readthedocs.io/en/latest/) Dateischnittstelle zum Hugging Face Hub. Das [`HfFileSystem`] basiert auf der [`HfApi`] und bietet typische Dateisystemoperationen wie `cp`, `mv`, `ls`, `du`, `glob`, `get_file` und `put_file`.

## Verwendung

```python
>>> from huggingface_hub import HfFileSystem
>>> fs = HfFileSystem()

>>> # Alle Dateien in einem Verzeichnis auflisten
>>> fs.ls("datasets/my-username/my-dataset-repo/data", detail=False)
['datasets/my-username/my-dataset-repo/data/train.csv', 'datasets/my-username/my-dataset-repo/data/test.csv']

>>> # Alle ".csv"-Dateien in einem Repo auflisten
>>> fs.glob("datasets/my-username/my-dataset-repo/**.csv")
['datasets/my-username/my-dataset-repo/data/train.csv', 'datasets/my-username/my-dataset-repo/data/test.csv']

>>> # Eine entfernte Datei lesen
>>> with fs.open("datasets/my-username/my-dataset-repo/data/train.csv", "r") as f:
...     train_data = f.readlines()

>>> # Den Inhalt einer entfernten Datei als Zeichenkette / String lesen
>>> train_data = fs.read_text("datasets/my-username/my-dataset-repo/data/train.csv", revision="dev")

>>> # Eine entfernte Datei schreiben
>>> with fs.open("datasets/my-username/my-dataset-repo/data/validation.csv", "w") as f:
...     f.write("text,label")
...     f.write("Fantastic movie!,good")
```

Das optionale Argument `revision` kann übergeben werden, um eine Operation von einem spezifischen Commit auszuführen, wie z.B. einem Branch, Tag-Namen oder einem Commit-Hash.

Anders als bei Pythons eingebautem `open` ist der Standardmodus von `fsspec`'s `open` binär, `"rb"`. Das bedeutet, dass Sie den Modus explizit auf `"r"` zum Lesen und `"w"` zum Schreiben im Textmodus setzen müssen. Das Anhängen an eine Datei (Modi `"a"` und `"ab"`) wird noch nicht unterstützt.

## Integrationen

Das [`HfFileSystem`] kann mit jeder Bibliothek verwendet werden, die `fsspec` integriert, vorausgesetzt die URL folgt dem Schema:

```
hf://[<repo_type_prefix>]<repo_id>[@<revision>]/<path/in/repo>
```

Der `repo_type_prefix` ist `datasets/` für Datensätze, `spaces/` für Spaces, und Modelle benötigen kein Präfix in der URL.
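Zur Veranschaulichung des Schemas hier eine kleine Skizze mit frei erfundenen Repo-Namen; je nach Repo-Typ ändert sich nur das Präfix, eine optionale Revision folgt nach `@`:

```python
>>> from huggingface_hub import HfFileSystem
>>> fs = HfFileSystem()

>>> # Modell-Repo: kein Präfix nötig
>>> fs.ls("hf://my-username/my-model-repo", detail=False)

>>> # Dataset-Repo: Präfix "datasets/", hier mit expliziter Revision "main"
>>> fs.ls("hf://datasets/my-username/my-dataset-repo@main/data", detail=False)

>>> # Space-Repo: Präfix "spaces/"
>>> fs.ls("hf://spaces/my-username/my-space-repo", detail=False)
```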
Einige interessante Integrationen, bei denen [`HfFileSystem`] die Interaktion mit dem Hub vereinfacht, sind unten aufgeführt: * Lesen/Schreiben eines [Pandas](https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#reading-writing-remote-files) DataFrame aus/in ein Hub-Repository: ```python >>> import pandas as pd >>> # Eine entfernte CSV-Datei in einen DataFrame lesen >>> df = pd.read_csv("hf://datasets/my-username/my-dataset-repo/train.csv") >>> # Einen DataFrame in eine entfernte CSV-Datei schreiben >>> df.to_csv("hf://datasets/my-username/my-dataset-repo/test.csv") ``` Der gleiche Arbeitsablauf kann auch für [Dask](https://docs.dask.org/en/stable/how-to/connect-to-remote-data.html) und [Polars](https://pola-rs.github.io/polars/py-polars/html/reference/io.html) verwendet werden. * Abfrage von (entfernten) Hub-Dateien mit [DuckDB](https://duckdb.org/docs/guides/python/filesystems): ```python >>> from huggingface_hub import HfFileSystem >>> import duckdb >>> fs = HfFileSystem() >>> duckdb.register_filesystem(fs) >>> # Eine entfernte Datei abfragen und das Ergebnis als DataFrame zurückbekommen >>> fs_query_file = "hf://datasets/my-username/my-dataset-repo/data_dir/data.parquet" >>> df = duckdb.query(f"SELECT * FROM '{fs_query_file}' LIMIT 10").df() ``` * Verwendung des Hub als Array-Speicher mit [Zarr](https://zarr.readthedocs.io/en/stable/tutorial.html#io-with-fsspec): ```python >>> import numpy as np >>> import zarr >>> embeddings = np.random.randn(50000, 1000).astype("float32") >>> # Ein Array in ein Repo schreiben >>> with zarr.open_group("hf://my-username/my-model-repo/array-store", mode="w") as root: ... foo = root.create_group("embeddings") ... foobar = foo.zeros('experiment_0', shape=(50000, 1000), chunks=(10000, 1000), dtype='f4') ... foobar[:] = embeddings >>> # Ein Array aus einem Repo lesen >>> with zarr.open_group("hf://my-username/my-model-repo/array-store", mode="r") as root: ... first_row = root["embeddings/experiment_0"][0] ``` ## Authentifizierung In vielen Fällen müssen Sie mit einem Hugging Face-Konto angemeldet sein, um mit dem Hub zu interagieren. Lesen Sie den [Login](../quick-start#login)-Abschnitt der Dokumentation, um mehr über Authentifizierungsmethoden auf dem Hub zu erfahren. Es ist auch möglich, sich programmatisch anzumelden, indem Sie Ihr `token` als Argument an [`HfFileSystem`] übergeben: ```python >>> from huggingface_hub import HfFileSystem >>> fs = HfFileSystem(token=token) ``` Wenn Sie sich auf diese Weise anmelden, seien Sie vorsichtig, das Token nicht versehentlich zu veröffentlichen, wenn Sie Ihren Quellcode teilen! huggingface_hub-0.31.1/docs/source/de/guides/inference.md000066400000000000000000000423741500667546600233120ustar00rootroot00000000000000 # Inferenz auf Servern ausführen Inferenz ist der Prozess, bei dem ein trainiertes Modell verwendet wird, um Vorhersagen für neue Daten zu treffen. Da dieser Prozess rechenintensiv sein kann, kann die Ausführung auf einem dedizierten Server eine interessante Option sein. Die `huggingface_hub` Bibliothek bietet eine einfache Möglichkeit, einen Dienst aufzurufen, der die Inferenz für gehostete Modelle durchführt. Es gibt mehrere Dienste, mit denen Sie sich verbinden können: - [Inferenz API](https://huggingface.co/docs/api-inference/index): ein Service, der Ihnen ermöglicht, beschleunigte Inferenz auf der Infrastruktur von Hugging Face kostenlos auszuführen. Dieser Service ist eine schnelle Möglichkeit, um anzufangen, verschiedene Modelle zu testen und AI-Produkte zu prototypisieren. 
- [Inferenz Endpunkte](https://huggingface.co/inference-endpoints/index): ein Produkt zur einfachen Bereitstellung von Modellen im Produktivbetrieb. Die Inferenz wird von Hugging Face in einer dedizierten, vollständig verwalteten Infrastruktur auf einem Cloud-Anbieter Ihrer Wahl durchgeführt. Diese Dienste können mit dem [`InferenceClient`] Objekt aufgerufen werden. Dieser fungiert als Ersatz für den älteren [`InferenceApi`] Client und fügt spezielle Unterstützung für Aufgaben und das Ausführen von Inferenz hinzu, sowohl auf [Inferenz API](https://huggingface.co/docs/api-inference/index) als auch auf [Inferenz Endpunkten](https://huggingface.co/docs/inference-endpoints/index). Im Abschnitt [Legacy InferenceAPI client](#legacy-inferenceapi-client) erfahren Sie, wie Sie zum neuen Client migrieren können. [`InferenceClient`] ist ein Python-Client, der HTTP-Anfragen an unsere APIs stellt. Wenn Sie die HTTP-Anfragen direkt mit Ihrem bevorzugten Tool (curl, postman,...) durchführen möchten, lesen Sie bitte die Dokumentationsseiten der [Inferenz API](https://huggingface.co/docs/api-inference/index) oder der [Inferenz Endpunkte](https://huggingface.co/docs/inference-endpoints/index). Für die Webentwicklung wurde ein [JS-Client](https://huggingface.co/docs/huggingface.js/inference/README) veröffentlicht. Wenn Sie sich für die Spieleentwicklung interessieren, sollten Sie einen Blick auf unser [C#-Projekt](https://github.com/huggingface/unity-api) werfen. ## Erste Schritte Los geht's mit einer Text-zu-Bild-Aufgabe: ```python >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> image = client.text_to_image("An astronaut riding a horse on the moon.") >>> image.save("astronaut.png") ``` Wir haben einen [`InferenceClient`] mit den Standardparametern initialisiert. Das Einzige, was Sie wissen müssen, ist die [Aufgabe](#unterstützte-aufgaben), die Sie ausführen möchten. Standardmäßig wird der Client sich mit der Inferenz API verbinden und ein Modell auswählen, um die Aufgabe abzuschließen. In unserem Beispiel haben wir ein Bild aus einem Textprompt generiert. Der zurückgegebene Wert ist ein `PIL.Image`-Objekt, das in eine Datei gespeichert werden kann. Die API ist darauf ausgelegt, einfach zu sein. Nicht alle Parameter und Optionen sind für den Endbenutzer verfügbar oder beschrieben. Schauen Sie auf [dieser Seite](https://huggingface.co/docs/api-inference/detailed_parameters) nach, wenn Sie mehr über alle verfügbaren Parameter für jede Aufgabe erfahren möchten. ### Verwendung eines spezifischen Modells Was ist, wenn Sie ein bestimmtes Modell verwenden möchten? Sie können es entweder als Parameter angeben oder direkt auf Instanzebene spezifizieren: ```python >>> from huggingface_hub import InferenceClient # Client für ein spezifisches Modell initialisieren >>> client = InferenceClient(model="prompthero/openjourney-v4") >>> client.text_to_image(...) # Oder nutzen Sie einen generischen Client, geben aber Ihr Modell als Argument an >>> client = InferenceClient() >>> client.text_to_image(..., model="prompthero/openjourney-v4") ``` Es gibt mehr als 200k Modelle im Hugging Face Hub! Jede Aufgabe im [`InferenceClient`] kommt mit einem empfohlenen Modell. Beachten Sie, dass die HF-Empfehlung sich im Laufe der Zeit ohne vorherige Ankündigung ändern kann. Daher ist es am besten, ein Modell explizit festzulegen, sobald Sie sich entschieden haben. In den meisten Fällen werden Sie daran interessiert sein, ein Modell zu finden, das speziell auf _Ihre_ Bedürfnisse zugeschnitten ist. 
Besuchen Sie die [Modelle](https://huggingface.co/models)-Seite im Hub, um Ihre Möglichkeiten zu erkunden. ### Verwendung einer spezifischen URL Die oben gesehenen Beispiele nutzen die kostenfrei gehostete Inferenz API. Dies erweist sich als sehr nützlich für Prototyping und schnelles Testen. Wenn Sie bereit sind, Ihr Modell in die Produktion zu übernehmen, müssen Sie eine dedizierte Infrastruktur verwenden. Hier kommen [Inferenz Endpunkte](https://huggingface.co/docs/inference-endpoints/index) ins Spiel. Es ermöglicht Ihnen, jedes Modell zu implementieren und als private API freizugeben. Nach der Implementierung erhalten Sie eine URL, zu der Sie mit genau dem gleichen Code wie zuvor eine Verbindung herstellen können, wobei nur der `Modell`-Parameter geändert wird: ```python >>> from huggingface_hub import InferenceClient >>> client = InferenceClient(model="https://uu149rez6gw9ehej.eu-west-1.aws.endpoints.huggingface.cloud/deepfloyd-if") # oder >>> client = InferenceClient() >>> client.text_to_image(..., model="https://uu149rez6gw9ehej.eu-west-1.aws.endpoints.huggingface.cloud/deepfloyd-if") ``` ### Authentifizierung Aufrufe, die mit dem [`InferenceClient`] gemacht werden, können mit einem [User Access Token](https://huggingface.co/docs/hub/security-tokens) authentifiziert werden. Standardmäßig wird das auf Ihrem Computer gespeicherte Token verwendet, wenn Sie angemeldet sind (sehen Sie hier, [wie Sie sich anmelden können](https://huggingface.co/docs/huggingface_hub/quick-start#login)). Wenn Sie nicht angemeldet sind, können Sie Ihr Token als Instanzparameter übergeben: ```python >>> from huggingface_hub import InferenceClient >>> client = InferenceClient(token="hf_***") ``` Die Authentifizierung ist NICHT zwingend erforderlich, wenn Sie die Inferenz API verwenden. Authentifizierte Benutzer erhalten jedoch ein höheres kostenloses Kontingent, um mit dem Service zu arbeiten. Ein Token ist auch zwingend erforderlich, wenn Sie Inferenz auf Ihren privaten Modellen oder auf privaten Endpunkten ausführen möchten. ## Unterstützte Aufgaben Das Ziel von [`InferenceClient`] ist es, die einfachste Schnittstelle zum Ausführen von Inferenzen auf Hugging Face-Modellen bereitzustellen. Es verfügt über eine einfache API, die die gebräuchlichsten Aufgaben unterstützt. 
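Jede dieser Aufgaben ist als eigene Methode auf dem Client verfügbar. Ein kurzer Eindruck als Skizze (die Modellwahl wird hier der API überlassen, die konkreten Rückgabewerte hängen vom jeweils verwendeten Modell ab):

```python
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()

>>> # Textklassifikation: gibt Labels mit zugehörigen Scores zurück
>>> client.text_classification("I love this library!")

>>> # Zusammenfassung: gibt den zusammengefassten Text zurück
>>> client.summarization("The Hugging Face Hub is a platform with over 200k models ...")
```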
Hier ist eine Liste der derzeit unterstützten Aufgaben: | Domäne | Aufgabe | Unterstützt | Dokumentation | |--------|--------------------------------|--------------|------------------------------------| | Audio | [Audio Classification](https://huggingface.co/tasks/audio-classification) | ✅ | [`~InferenceClient.audio_classification`] | | | [Automatic Speech Recognition](https://huggingface.co/tasks/automatic-speech-recognition) | ✅ | [`~InferenceClient.automatic_speech_recognition`] | | | [Text-to-Speech](https://huggingface.co/tasks/text-to-speech) | ✅ | [`~InferenceClient.text_to_speech`] | | Computer Vision | [Image Classification](https://huggingface.co/tasks/image-classification) | ✅ | [`~InferenceClient.image_classification`] | | | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | ✅ | [`~InferenceClient.image_segmentation`] | | | [Image-to-Image](https://huggingface.co/tasks/image-to-image) | ✅ | [`~InferenceClient.image_to_image`] | | | [Image-to-Text](https://huggingface.co/tasks/image-to-text) | ✅ | [`~InferenceClient.image_to_text`] | | | [Object Detection](https://huggingface.co/tasks/object-detection) | ✅ | [`~InferenceClient.object_detection`] | | | [Text-to-Image](https://huggingface.co/tasks/text-to-image) | ✅ | [`~InferenceClient.text_to_image`] | | | [Zero-Shot-Image-Classification](https://huggingface.co/tasks/zero-shot-image-classification) | ✅ | [`~InferenceClient.zero_shot_image_classification`] | | Multimodal | [Documentation Question Answering](https://huggingface.co/tasks/document-question-answering) | ✅ | [`~InferenceClient.document_question_answering`] | | | [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | ✅ | [`~InferenceClient.visual_question_answering`] | | NLP | [Conversational](https://huggingface.co/tasks/conversational) | ✅ | [`~InferenceClient.conversational`] | | | [Feature Extraction](https://huggingface.co/tasks/feature-extraction) | ✅ | [`~InferenceClient.feature_extraction`] | | | [Fill Mask](https://huggingface.co/tasks/fill-mask) | ✅ | [`~InferenceClient.fill_mask`] | | | [Question Answering](https://huggingface.co/tasks/question-answering) | ✅ | [`~InferenceClient.question_answering`] | | | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | ✅ | [`~InferenceClient.sentence_similarity`] | | | [Summarization](https://huggingface.co/tasks/summarization) | ✅ | [`~InferenceClient.summarization`] | | | [Table Question Answering](https://huggingface.co/tasks/table-question-answering) | ✅ | [`~InferenceClient.table_question_answering`] | | | [Text Classification](https://huggingface.co/tasks/text-classification) | ✅ | [`~InferenceClient.text_classification`] | | | [Text Generation](https://huggingface.co/tasks/text-generation) | ✅ | [`~InferenceClient.text_generation`] | | | [Token Classification](https://huggingface.co/tasks/token-classification) | ✅ | [`~InferenceClient.token_classification`] | | | [Translation](https://huggingface.co/tasks/translation) | ✅ | [`~InferenceClient.translation`] | | | [Zero Shot Classification](https://huggingface.co/tasks/zero-shot-classification) | ✅ | [`~InferenceClient.zero_shot_classification`] | | Tabular | [Tabular Classification](https://huggingface.co/tasks/tabular-classification) | ✅ | [`~InferenceClient.tabular_classification`] | | | [Tabular Regression](https://huggingface.co/tasks/tabular-regression) | ✅ | [`~InferenceClient.tabular_regression`] | Schauen Sie sich die [Aufgaben](https://huggingface.co/tasks)-Seite an, um mehr über 
jede Aufgabe zu erfahren, wie man sie verwendet und die beliebtesten Modelle für jede Aufgabe. ## Asynchroner Client Eine asynchrone Version des Clients wird ebenfalls bereitgestellt, basierend auf `asyncio` und `aiohttp`. Sie können entweder `aiohttp` direkt installieren oder das `[inference]` Extra verwenden: ```sh pip install aiohttp # oder pip install --upgrade huggingface_hub[inference] ``` Nach der Installation sind alle asynchronen API-Endpunkte über [`AsyncInferenceClient`] verfügbar. Seine Initialisierung und APIs sind genau gleich wie die synchronisierte Version. ```py # Der Code muss in einem asyncio-konkurrenten Kontext ausgeführt werden. # $ python -m asyncio >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> image = await client.text_to_image("An astronaut riding a horse on the moon.") >>> image.save("astronaut.png") >>> async for token in await client.text_generation("The Huggingface Hub is", stream=True): ... print(token, end="") a platform for sharing and discussing ML-related content. ``` Für weitere Informationen zum `asyncio`-Modul konsultieren Sie bitte die [offizielle Dokumentation](https://docs.python.org/3/library/asyncio.html). ## Fortgeschrittene Tipps Im obigen Abschnitt haben wir die Hauptaspekte von [`InferenceClient`] betrachtet. Lassen Sie uns in einige fortgeschrittene Tipps eintauchen. ### Zeitüberschreitung Bei der Inferenz gibt es zwei Hauptursachen für eine Zeitüberschreitung: - Der Inferenzprozess dauert lange, um abgeschlossen zu werden. - Das Modell ist nicht verfügbar, beispielsweise wenn die Inferenz API es zum ersten Mal lädt. Der [`InferenceClient`] verfügt über einen globalen Zeitüberschreitungsparameter (`timeout`), um diese beiden Aspekte zu behandeln. Standardmäßig ist er auf `None` gesetzt, was bedeutet, dass der Client unendlich lange auf den Abschluss der Inferenz warten wird. Wenn Sie mehr Kontrolle in Ihrem Arbeitsablauf wünschen, können Sie ihn auf einen bestimmten Wert in Sekunden setzen. Wenn die Zeitüberschreitungsverzögerung abläuft, wird ein [`InferenceTimeoutError`] ausgelöst. Sie können diesen Fehler abfangen und in Ihrem Code behandeln: ```python >>> from huggingface_hub import InferenceClient, InferenceTimeoutError >>> client = InferenceClient(timeout=30) >>> try: ... client.text_to_image(...) ... except InferenceTimeoutError: ... print("Inference timed out after 30s.") ``` ### Binäre Eingaben Einige Aufgaben erfordern binäre Eingaben, zum Beispiel bei der Arbeit mit Bildern oder Audiodateien. In diesem Fall versucht der [`InferenceClient] so permissiv wie möglich zu sein und akzeptiert verschiedene Typen: - rohe `Bytes` - ein Datei-ähnliches Objekt, geöffnet als Binär (`with open("audio.flac", "rb") as f: ...`) - ein Pfad (`str` oder `Path`) zu einer lokalen Datei - eine URL (`str`) zu einer entfernten Datei (z.B. `https://...`). In diesem Fall wird die Datei lokal heruntergeladen, bevor sie an die Inferenz API gesendet wird. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") [{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...] ``` ## Legacy InferenceAPI client Der [`InferenceClient`] dient als Ersatz für den veralteten [`InferenceApi`]-Client. 
Er bietet spezifische Unterstützung für Aufgaben und behandelt Inferenz sowohl auf der [Inferenz API](https://huggingface.co/docs/api-inference/index) als auch auf den [Inferenz Endpunkten](https://huggingface.co/docs/inference-endpoints/index). Hier finden Sie eine kurze Anleitung, die Ihnen hilft, von [`InferenceApi`] zu [`InferenceClient`] zu migrieren. ### Initialisierung Ändern Sie von ```python >>> from huggingface_hub import InferenceApi >>> inference = InferenceApi(repo_id="bert-base-uncased", token=API_TOKEN) ``` zu ```python >>> from huggingface_hub import InferenceClient >>> inference = InferenceClient(model="bert-base-uncased", token=API_TOKEN) ``` ### Ausführen einer bestimmten Aufgabe Ändern Sie von ```python >>> from huggingface_hub import InferenceApi >>> inference = InferenceApi(repo_id="paraphrase-xlm-r-multilingual-v1", task="feature-extraction") >>> inference(...) ``` zu ```python >>> from huggingface_hub import InferenceClient >>> inference = InferenceClient() >>> inference.feature_extraction(..., model="paraphrase-xlm-r-multilingual-v1") ``` Dies ist der empfohlene Weg, um Ihren Code an [`InferenceClient`] anzupassen. Dadurch können Sie von den aufgabenspezifischen Methoden wie `feature_extraction` profitieren. ### Eigene Anfragen ausführen Ändern Sie von ```python >>> from huggingface_hub import InferenceApi >>> inference = InferenceApi(repo_id="bert-base-uncased") >>> inference(inputs="The goal of life is [MASK].") [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}] ``` zu ```python >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> response = client.post(json={"inputs": "The goal of life is [MASK]."}, model="bert-base-uncased") >>> response.json() [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}] ``` ### Mit Parametern ausführen Ändern Sie von ```python >>> from huggingface_hub import InferenceApi >>> inference = InferenceApi(repo_id="typeform/distilbert-base-uncased-mnli") >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!" >>> params = {"candidate_labels":["refund", "legal", "faq"]} >>> inference(inputs, params) {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]} ``` zu ```python >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!" >>> params = {"candidate_labels":["refund", "legal", "faq"]} >>> response = client.post(json={"inputs": inputs, "parameters": params}, model="typeform/distilbert-base-uncased-mnli") >>> response.json() {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]} ``` huggingface_hub-0.31.1/docs/source/de/guides/integrations.md000066400000000000000000000350241500667546600240540ustar00rootroot00000000000000 # Integrieren Sie jedes ML-Framework mit dem Hub Der Hugging Face Hub erleichtert das Hosten und Teilen von Modellen mit der Community. 
Er unterstützt [Dutzende von Bibliotheken](https://huggingface.co/docs/hub/models-libraries) im Open Source-Ökosystem. Wir arbeiten ständig daran, diese Unterstützung zu erweitern, um kollaboratives Machine Learning voranzutreiben. Die `huggingface_hub`-Bibliothek spielt eine Schlüsselrolle in diesem Prozess und ermöglicht es jedem Python-Skript, Dateien einfach hochzuladen und zu laden. Es gibt vier Hauptwege, eine Bibliothek mit dem Hub zu integrieren: 1. **Push to Hub**: Implementieren Sie eine Methode, um ein Modell auf den Hub hochzuladen. Dies beinhaltet das Modellgewicht sowie [die Modellkarte](https://huggingface.co/docs/huggingface_hub/how-to-model-cards) und alle anderen relevanten Informationen oder Daten, die für den Betrieb des Modells erforderlich sind (zum Beispiel Trainingsprotokolle). Diese Methode wird oft `push_to_hub()` genannt. 2. **Download from Hub**: Implementieren Sie eine Methode, um ein Modell vom Hub zu laden. Die Methode sollte die Modellkonfiguration/-gewichte herunterladen und das Modell laden. Diese Methode wird oft `from_pretrained` oder `load_from_hub()` genannt. 3. **Inference API**: Nutzen Sie unsere Server, um Inferenz auf von Ihrer Bibliothek unterstützten Modellen kostenlos auszuführen. 4. **Widgets**: Zeigen Sie ein Widget auf der Landing Page Ihrer Modelle auf dem Hub an. Dies ermöglicht es Benutzern, ein Modell schnell aus dem Browser heraus auszuprobieren. In diesem Leitfaden konzentrieren wir uns auf die ersten beiden Themen. Wir werden die beiden Hauptansätze vorstellen, die Sie zur Integration einer Bibliothek verwenden können, mit ihren Vor- und Nachteilen. Am Ende des Leitfadens ist alles zusammengefasst, um Ihnen bei der Auswahl zwischen den beiden zu helfen. Bitte beachten Sie, dass dies nur Richtlinien sind, die Sie an Ihre Anforderungen anpassen können. Wenn Sie sich für Inferenz und Widgets interessieren, können Sie [diesem Leitfaden](https://huggingface.co/docs/hub/models-adding-libraries#set-up-the-inference-api) folgen. In beiden Fällen können Sie sich an uns wenden, wenn Sie eine Bibliothek mit dem Hub integrieren und [in unserer Dokumentation](https://huggingface.co/docs/hub/models-libraries) aufgeführt haben möchten. ## Ein flexibler Ansatz: Helfer Der erste Ansatz zur Integration einer Bibliothek in den Hub besteht tatsächlich darin, die `push_to_hub` und `from_pretrained` Methoden selbst zu implementieren. Dies gibt Ihnen volle Flexibilität hinsichtlich der Dateien, die Sie hoch-/herunterladen möchten, und wie Sie Eingaben, die speziell für Ihr Framework sind, behandeln. Sie können sich die beiden Leitfäden [Dateien hochladen](./upload) und [Dateien herunterladen](./download) ansehen, um mehr darüber zu erfahren, wie dies funktioniert. Dies ist zum Beispiel die Art und Weise, wie die FastAI-Integration implementiert ist (siehe [`push_to_hub_fastai`] und [`from_pretrained_fastai`]). Die Implementierung kann zwischen den Bibliotheken variieren, aber der Workflow ist oft ähnlich. ### from_pretrained So sieht eine `from_pretrained` Methode normalerweise aus: ```python def from_pretrained(model_id: str) -> MyModelClass: # Modell vom Hub herunterladen cached_model = hf_hub_download( repo_id=repo_id, filename="model.pkl", library_name="fastai", library_version=get_fastai_version(), ) # Modell laden return load_model(cached_model) ``` ### push_to_hub Die `push_to_hub` Methode erfordert oft etwas mehr Komplexität, um die Repo-Erstellung, die Generierung der Modellkarte und das Speichern von Gewichten zu behandeln. 
Ein üblicher Ansatz besteht darin, all diese Dateien in einem temporären Ordner zu speichern, ihn hochzuladen und dann zu löschen. ```python def push_to_hub(model: MyModelClass, repo_name: str) -> None: api = HfApi() # Repo erstellen, wenn noch nicht vorhanden und die zugehörige repo_id erhalten repo_id = api.create_repo(repo_name, exist_ok=True) # Modell in temporärem Ordner speichern und in einem enzigen Commit pushen with TemporaryDirectory() as tmpdir: tmpdir = Path(tmpdir) # Gewichte speichern save_model(model, tmpdir / "model.safetensors") # Modellkarte generieren card = generate_model_card(model) (tmpdir / "README.md").write_text(card) # Logs speichern # Diagramme speichern # Evaluationsmetriken speichern # ... # Auf den Hub pushen return api.upload_folder(repo_id=repo_id, folder_path=tmpdir) ``` Dies ist natürlich nur ein Beispiel. Wenn Sie an komplexeren Manipulationen interessiert sind (entfernen von entfernten Dateien, hochladen von Gewichten on-the-fly, lokales Speichern von Gewichten, usw.), beachten Sie bitte den [Dateien hochladen](./upload) Leitfaden. ### Einschränkungen Obwohl dieser Ansatz flexibel ist, hat er einige Nachteile, insbesondere in Bezug auf die Wartung. Hugging Face-Benutzer sind oft an zusätzliche Funktionen gewöhnt, wenn sie mit `huggingface_hub` arbeiten. Zum Beispiel ist es beim Laden von Dateien aus dem Hub üblich, Parameter wie folgt anzubieten: - `token`: zum Herunterladen aus einem privaten Repository - `revision`: zum Herunterladen von einem spezifischen Branch - `cache_dir`: um Dateien in einem spezifischen Verzeichnis zu cachen - `force_download`/`resume_download`/`local_files_only`: um den Cache wieder zu verwenden oder nicht - `api_endpoint`/`proxies`: HTTP-Session konfigurieren Beim Pushen von Modellen werden ähnliche Parameter unterstützt: - `commit_message`: benutzerdefinierte Commit-Nachricht - `private`: ein privates Repository erstellen, falls nicht vorhanden - `create_pr`: erstellen Sie einen PR anstatt auf `main` zu pushen - `branch`: auf einen Branch pushen anstatt auf den `main` Branch - `allow_patterns`/`ignore_patterns`: filtern, welche Dateien hochgeladen werden sollen - `token` - `api_endpoint` - ... Alle diese Parameter können den zuvor gesehenen Implementierungen hinzugefügt und an die `huggingface_hub`-Methoden übergeben werden. Wenn sich jedoch ein Parameter ändert oder eine neue Funktion hinzugefügt wird, müssen Sie Ihr Paket aktualisieren. Die Unterstützung dieser Parameter bedeutet auch mehr Dokumentation, die Sie auf Ihrer Seite pflegen müssen. Um zu sehen, wie man diese Einschränkungen mildert, springen wir zu unserem nächsten Abschnitt **Klassenvererbung**. ## Ein komplexerer Ansatz: Klassenvererbung Wie wir oben gesehen haben, gibt es zwei Hauptmethoden, um Ihre Bibliothek mit dem Hub zu integrieren: Dateien hochladen (`push_to_hub`) und Dateien herunterladen (`from_pretrained`). Sie können diese Methoden selbst implementieren, aber das hat seine Tücken. Um dies zu bewältigen, bietet `huggingface_hub` ein Werkzeug an, das Klassenvererbung verwendet. Schauen wir uns an, wie es funktioniert! In vielen Fällen implementiert eine Bibliothek ihr Modell bereits mit einer Python-Klasse. Die Klasse enthält die Eigenschaften des Modells und Methoden zum Laden, Ausführen, Trainieren und Evaluieren. Unser Ansatz besteht darin, diese Klasse zu erweitern, um Upload- und Download-Funktionen mit Mixins hinzuzufügen. 
Ein [Mixin](https://stackoverflow.com/a/547714) ist eine Klasse, die dazu bestimmt ist, eine vorhandene Klasse mit einem Satz spezifischer Funktionen durch Mehrfachvererbung zu erweitern. `huggingface_hub` bietet sein eigenes Mixin, das [`ModelHubMixin`]. Der Schlüssel hier ist zu verstehen, wie es funktioniert und wie man es anpassen kann. Die Klasse [ModelHubMixin] implementiert 3 *öffentliche* Methoden (`push_to_hub`, `save_pretrained` und `from_pretrained`). Dies sind die Methoden, die Ihre Benutzer aufrufen werden, um Modelle mit Ihrer Bibliothek zu laden/speichern. [`ModelHubMixin`] definiert auch 2 private Methoden (`_save_pretrained` und `_from_pretrained`). Diese müssen Sie implementieren. Um Ihre Bibliothek zu integrieren, sollten Sie: 1. Lassen Sie Ihre Modell-Klasse von [`ModelHubMixin`] erben. 2. Implementieren Sie die privaten Methoden: - [`~ModelHubMixin._save_pretrained`]: Methode, die als Eingabe einen Pfad zu einem Verzeichnis nimmt und das Modell dort speichert. Sie müssen die gesamte Logik zum Speichern Ihres Modells in dieser Methode schreiben: Modellkarte, Modellgewichte, Konfigurationsdateien, Trainingsprotokolle und Diagramme. Alle relevanten Informationen für dieses Modell müssen von dieser Methode behandelt werden. [Model Cards](https://huggingface.co/docs/hub/model-cards) sind besonders wichtig, um Ihr Modell zu beschreiben. Weitere Details finden Sie in [unserem Implementierungsleitfaden](./model-cards). - [~ModelHubMixin._from_pretrained]: **Klassenmethode**, die als Eingabe eine `model_id` nimmt und ein instanziiertes Modell zurückgibt. Die Methode muss die relevanten Dateien herunterladen und laden. 3. Sie sind fertig! Der Vorteil der Verwendung von [`ModelHubMixin`] besteht darin, dass Sie, sobald Sie sich um die Serialisierung/das Laden der Dateien gekümmert haben, bereit sind los zu legen. Sie müssen sich keine Gedanken über Dinge wie Repository-Erstellung, Commits, PRs oder Revisionen machen. All dies wird von dem Mixin gehandhabt und steht Ihren Benutzern zur Verfügung. Das Mixin stellt auch sicher, dass öffentliche Methoden gut dokumentiert und typisiert sind. ### Ein konkretes Beispiel: PyTorch Ein gutes Beispiel für das, was wir oben gesehen haben, ist [`PyTorchModelHubMixin`], unsere Integration für das PyTorch-Framework. Dies ist eine einsatzbereite Integration. #### Wie verwendet man es? Hier ist, wie jeder Benutzer ein PyTorch-Modell vom/auf den Hub laden/speichern kann: ```python >>> import torch >>> import torch.nn as nn >>> from huggingface_hub import PyTorchModelHubMixin # 1. Definieren Sie Ihr Pytorch-Modell genau so, wie Sie es gewohnt sind >>> class MyModel(nn.Module, PyTorchModelHubMixin): # Mehrfachvererbung ... def __init__(self): ... super().__init__() ... self.param = nn.Parameter(torch.rand(3, 4)) ... self.linear = nn.Linear(4, 5) ... def forward(self, x): ... return self.linear(x + self.param) >>> model = MyModel() # 2. (optional) Modell in lokales Verzeichnis speichern >>> model.save_pretrained("path/to/my-awesome-model") # 3. Modellgewichte an den Hub übertragen >>> model.push_to_hub("my-awesome-model") # 4. Modell vom Hub initialisieren >>> model = MyModel.from_pretrained("username/my-awesome-model") ``` #### Implementierung Die Implementierung ist tatsächlich sehr einfach, und die vollständige Implementierung finden Sie [hier](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/hub_mixin.py). 1. 
Zuerst, erben Ihrer Klasse von `ModelHubMixin`: ```python from huggingface_hub import ModelHubMixin class PyTorchModelHubMixin(ModelHubMixin): (...) ``` 2. Implementieren der `_save_pretrained` Methode: ```py from huggingface_hub import ModelCard, ModelCardData class PyTorchModelHubMixin(ModelHubMixin): (...) def _save_pretrained(self, save_directory: Path): """Generiere Modellkarte und speichere Gewichte von einem Pytorch-Modell in einem lokalen Verzeichnis.""" model_card = ModelCard.from_template( card_data=ModelCardData( license='mit', library_name="pytorch", ... ), model_summary=..., model_type=..., ... ) (save_directory / "README.md").write_text(str(model)) torch.save(obj=self.module.state_dict(), f=save_directory / "pytorch_model.bin") ``` 3. Implementieren der `_from_pretrained` Methode: ```python class PyTorchModelHubMixin(ModelHubMixin): (...) @classmethod # Muss eine Klassenmethode sein! def _from_pretrained( cls, *, model_id: str, revision: str, cache_dir: str, force_download: bool, proxies: Optional[Dict], resume_download: bool, local_files_only: bool, token: Union[str, bool, None], map_location: str = "cpu", # zusätzliches Argument strict: bool = False, # zusätzliches Argument **model_kwargs, ): """Load Pytorch pretrained weights and return the loaded model.""" if os.path.isdir(model_id): # Kann entweder ein lokales Verzeichnis sein print("Loading weights from local directory") model_file = os.path.join(model_id, "pytorch_model.bin") else: # Oder ein Modell am Hub model_file = hf_hub_download( # Herunterladen vom Hub, gleiche Eingabeargumente repo_id=model_id, filename="pytorch_model.bin", revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) # Modell laden und zurückgeben - benutzerdefinierte Logik je nach Ihrem Framework model = cls(**model_kwargs) state_dict = torch.load(model_file, map_location=torch.device(map_location)) model.load_state_dict(state_dict, strict=strict) model.eval() return model ``` Und das war's! Ihre Bibliothek ermöglicht es Benutzern nun, Dateien vom und zum Hub hoch- und herunterzuladen. ## Kurzer Vergleich Lassen Sie uns die beiden Ansätze, die wir gesehen haben, schnell mit ihren Vor- und Nachteilen zusammenfassen. Die untenstehende Tabelle ist nur indikativ. Ihr Framework könnte einige Besonderheiten haben, die Sie berücksichtigen müssen. Dieser Leitfaden soll nur Richtlinien und Ideen geben, wie Sie die Integration handhaben können. Kontaktieren Sie uns in jedem Fall, wenn Sie Fragen haben! | Integration | Mit Helfern | Mit [`ModelHubMixin`] | |:---:|:---:|:---:| | Benutzererfahrung | `model = load_from_hub(...)`
`push_to_hub(model, ...)` | `model = MyModel.from_pretrained(...)`
`model.push_to_hub(...)` | | Flexibilität | Sehr flexibel.
Sie haben die volle Kontrolle über die Implementierung. | Weniger flexibel.
Ihr Framework muss eine Modellklasse haben. | | Wartung | Mehr Wartung, um Unterstützung für Konfiguration und neue Funktionen hinzuzufügen. Könnte auch das Beheben von Benutzerproblemen erfordern. | Weniger Wartung, da die meisten Interaktionen mit dem Hub in `huggingface_hub` implementiert sind. | | Dokumentation/Typ-Annotation| Manuell zu schreiben. | Teilweise durch `huggingface_hub` behandelt. | huggingface_hub-0.31.1/docs/source/de/guides/manage-cache.md000066400000000000000000000673321500667546600236460ustar00rootroot00000000000000 # Verwalten des `huggingface_hub` Cache-Systems ## Caching verstehen Das Hugging Face Hub Cache-System wurde entwickelt, um der zentrale Cache zu sein, der zwischen Bibliotheken geteilt wird, welche vom Hub abhängen. Es wurde in v0.8.0 aktualisiert, um das erneute Herunterladen von Dateien zwischen Revisionen zu verhindern. Das Cache-System ist wie folgt aufgebaut: ``` ├─ ├─ ├─ ``` Der `` ist normalerweise das Home-Verzeichnis Ihres Benutzers. Es kann jedoch mit dem `cache_dir`-Argument in allen Methoden oder durch Angabe der Umgebungsvariablen `HF_HOME` oder `HF_HUB_CACHE` angepasst werden. Modelle, Datensätze und Räume teilen eine gemeinsame Wurzel. Jedes dieser Repositories enthält den Repository-Typ, den Namensraum (Organisation oder Benutzername), falls vorhanden, und den Repository-Namen: ``` ├─ models--julien-c--EsperBERTo-small ├─ models--lysandrejik--arxiv-nlp ├─ models--bert-base-cased ├─ datasets--glue ├─ datasets--huggingface--DataMeasurementsFiles ├─ spaces--dalle-mini--dalle-mini ``` Innerhalb dieser Ordner werden nun alle Dateien vom Hub heruntergeladen. Das Caching stellt sicher, dass eine Datei nicht zweimal heruntergeladen wird, wenn sie bereits existiert und nicht aktualisiert wurde; wurde sie jedoch aktualisiert und Sie fordern die neueste Datei an, wird die neueste Datei heruntergeladen (während die vorherige Datei intakt bleibt, falls Sie sie erneut benötigen). Um dies zu erreichen, enthalten alle Ordner dasselbe Grundgerüst: ``` ├─ datasets--glue │ ├─ refs │ ├─ blobs │ ├─ snapshots ... ``` Jeder Ordner ist so gestaltet, dass er das Folgende enthält: ### Refs Der Ordner `refs` enthält Dateien, die die neueste Revision des gegebenen Verweises anzeigen. Zum Beispiel, wenn wir zuvor eine Datei aus dem `main`-Branch eines Repositories abgerufen haben, wird der Ordner `refs` eine Datei namens `main` enthalten, die selbst den Commit-Identifikator der aktuellen HEAD-Branch enthält. Wenn der neueste Commit von `main` den Identifikator `aaaaaa` hat, dann enthält er `aaaaaa`. Wenn derselbe Zweig mit einem neuen Commit aktualisiert wird, der den Identifikator `bbbbbb` hat, wird das erneute Herunterladen einer Datei von diesem Verweis die Datei `refs/main` aktualisieren, um `bbbbbb` zu enthalten. ### Blobs Der Ordner `blobs` enthält die tatsächlichen Dateien, die wir heruntergeladen haben. Der Name jeder Datei ist ihr Hash. ### Snapshots Der Ordner `snapshots` enthält Symlinks zu den oben erwähnten Blobs. Er besteht selbst aus mehreren Ordnern: einem pro bekannter Revision! In der obigen Erklärung hatten wir zunächst eine Datei von der Revision `aaaaaa` abgerufen, bevor wir eine Datei von der Revision `bbbbbb` abgerufen haben. In dieser Situation hätten wir jetzt zwei Ordner im Ordner `snapshots`: `aaaaaa` und `bbbbbb`. In jedem dieser Ordner leben Symlinks, die die Namen der Dateien haben, die wir heruntergeladen haben. 
Wenn wir zum Beispiel die Datei `README.md` in der Revision `aaaaaa` heruntergeladen hätten, hätten wir den folgenden Pfad: ``` //snapshots/aaaaaa/README.md ``` Diese `README.md`-Datei ist tatsächlich ein Symlink, der auf den Blob verweist, der den Hash der Datei hat. Durch das Erstellen des Grundgerüsts auf diese Weise ermöglichen wir den Mechanismus der Dateifreigabe: Wenn dieselbe Datei in der Revision `bbbbbb` abgerufen wurde, hätte sie denselben Hash und die Datei müsste nicht erneut heruntergeladen werden. ### .no_exist (fortgeschritten) Zusätzlich zu den Ordnern `blobs`, `refs` und `snapshots` könnten Sie in Ihrem Cache auch einen `.no_exist` Ordner finden. Dieser Ordner hält fest, welche Dateien Sie einmal versucht haben herunterzuladen, die jedoch nicht auf dem Hub vorhanden sind. Seine Struktur ist dieselbe wie der `snapshots` Ordner mit einem Unterordner pro bekannter Revision: ``` //.no_exist/aaaaaa/config_that_does_not_exist.json ``` Im Gegensatz zum `snapshots` Ordner handelt es sich bei den Dateien um einfache leere Dateien (keine Symlinks). In diesem Beispiel existiert die Datei `"config_that_does_not_exist.json"` nicht auf dem Hub für die Revision `"aaaaaa"`. Da dieser Ordner nur leere Dateien speichert, ist sein Speicherplatzverbrauch vernachlässigbar. Sie fragen sich jetzt vielleicht, warum diese Information überhaupt relevant ist? In einigen Fällen versucht ein Framework, optionale Dateien für ein Modell zu laden. Das Speichern der Nicht-Existenz optionaler Dateien beschleunigt das Laden eines Modells, da 1 HTTP-Anfrage pro möglicher optionaler Datei gespart wird. Dies ist zum Beispiel bei `transformers` der Fall, wo jeder Tokenizer zusätzliche Dateien unterstützen kann. Beim ersten Laden des Tokenizers auf Ihrem Gerät wird im Cache gespeichert, welche optionalen Dateien vorhanden sind (und welche nicht), um die Ladezeit bei den nächsten Initialisierungen zu beschleunigen. Um zu testen, ob eine Datei lokal im Cache gespeichert ist (ohne eine HTTP-Anfrage zu senden), können Sie die [`try_to_load_from_cache`] Hilfsfunktion verwenden. Sie gibt entweder den Dateipfad zurück (falls vorhanden und im Cache gespeichert), das Objekt `_CACHED_NO_EXIST` (wenn die Nicht-Existenz im Cache gespeichert ist) oder `None` (wenn wir es nicht wissen). ```python from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST filepath = try_to_load_from_cache() if isinstance(filepath, str): # file exists and is cached ... elif filepath is _CACHED_NO_EXIST: # non-existence of file is cached ... else: # file is not cached ... ``` ### In der Praxis In der Praxis sollte Ihr Cache folgendermaßen aussehen: ```text [ 96] . └── [ 160] models--julien-c--EsperBERTo-small ├── [ 160] blobs │ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd │ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e │ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812 ├── [ 96] refs │ └── [ 40] main └── [ 128] snapshots ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd └── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48 ├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd ``` ### Einschränkungen Um ein effizientes Cache-System zu haben, verwendet `huggingface-hub` Symlinks. 
Allerdings werden Symlinks nicht auf allen Maschinen unterstützt. Dies ist eine bekannte Einschränkung, insbesondere bei Windows. Wenn dies der Fall ist, verwendet `huggingface_hub` nicht das `blobs/` Verzeichnis, sondern speichert die Dateien direkt im `snapshots/` Verzeichnis. Dieser Workaround ermöglicht es den Nutzern, Dateien vom Hub auf genau die gleiche Weise herunterzuladen und zu cachen. Auch Werkzeuge zur Überprüfung und Löschung des Caches (siehe unten) werden unterstützt. Allerdings ist das Cache-System weniger effizient, da eine einzelne Datei möglicherweise mehrmals heruntergeladen wird, wenn mehrere Revisionen des gleichen Repos heruntergeladen werden. Wenn Sie von dem Symlink-basierten Cache-System auf einem Windows-Gerät profitieren möchten, müssen Sie entweder den [Entwicklermodus aktivieren](https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development) oder Python als Administrator ausführen. Wenn Symlinks nicht unterstützt werden, wird dem Nutzer eine Warnmeldung angezeigt, um ihn darauf hinzuweisen, dass er eine eingeschränkte Version des Cache-Systems verwendet. Diese Warnung kann durch Setzen der Umgebungsvariable `HF_HUB_DISABLE_SYMLINKS_WARNING` auf true deaktiviert werden. ## Assets zwischenspeichern Zusätzlich zum Zwischenspeichern von Dateien aus dem Hub benötigen nachgelagerte Bibliotheken oft das Zwischenspeichern von anderen Dateien, die in Verbindung mit HF stehen, aber nicht direkt von `huggingface_hub` behandelt werden (zum Beispiel: Dateien, die von GitHub heruntergeladen werden, vorverarbeitete Daten, Protokolle,...). Um diese Dateien, die als `assets` bezeichnet werden, zwischenzuspeichern, kann man [`cached_assets_path`] verwenden. Dieser kleine Helfer generiert Pfade im HF-Cache auf eine einheitliche Weise, basierend auf dem Namen der anfragenden Bibliothek und optional auf einem Namensraum und einem Unterordnernamen. Das Ziel ist, dass jede nachgelagerte Bibliothek ihre Assets auf ihre eigene Weise verwaltet (z.B. keine Regelung über die Struktur), solange sie im richtigen Assets-Ordner bleibt. Diese Bibliotheken können dann die Werkzeuge von `huggingface_hub` nutzen, um den Cache zu verwalten, insbesondere um Teile der Assets über einen CLI-Befehl zu scannen und zu löschen. ```py from huggingface_hub import cached_assets_path assets_path = cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="download") something_path = assets_path / "something.json" # Machen Sie, was Sie möchten, in Ihrem Assets-Ordner! ``` [`cached_assets_path`] ist der empfohlene Weg, um Assets zu speichern, ist jedoch nicht verpflichtend. Wenn Ihre Bibliothek bereits ihren eigenen Cache verwendet, können Sie diesen gerne nutzen! ### Assets in der Praxis In der Praxis sollte Ihr Assets-Cache wie der folgende Verzeichnisbaum aussehen: ```text assets/ └── datasets/ │ ├── SQuAD/ │ │ ├── downloaded/ │ │ ├── extracted/ │ │ └── processed/ │ ├── Helsinki-NLP--tatoeba_mt/ │ ├── downloaded/ │ ├── extracted/ │ └── processed/ └── transformers/ ├── default/ │ ├── something/ ├── bert-base-cased/ │ ├── default/ │ └── training/ hub/ └── models--julien-c--EsperBERTo-small/ ├── blobs/ │ ├── (...) │ ├── (...) ├── refs/ │ └── (...) └── [ 128] snapshots/ ├── 2439f60ef33a0d46d85da5001d52aeda5b00ce9f/ │ ├── (...) └── bbc77c8132af1cc5cf678da3f1ddf2de43606d48/ └── (...) 
``` ## Cache scannen Derzeit werden zwischengespeicherte Dateien nie aus Ihrem lokalen Verzeichnis gelöscht: Wenn Sie eine neue Revision eines Zweiges herunterladen, werden vorherige Dateien aufbewahrt, falls Sie sie wieder benötigen. Daher kann es nützlich sein, Ihr Cache-Verzeichnis zu scannen, um zu erfahren, welche Repos und Revisionen den meisten Speicherplatz beanspruchen. `huggingface_hub` bietet einen Helfer dafür, der über `huggingface-cli` oder in einem Python-Skript verwendet werden kann. ### Cache vom Terminal aus scannen Die einfachste Möglichkeit, Ihr HF-Cache-System zu scannen, besteht darin, den Befehl `scan-cache` aus dem `huggingface-cli`-Tool zu verwenden. Dieser Befehl scannt den Cache und gibt einen Bericht mit Informationen wie Repo-ID, Repo-Typ, Speicherverbrauch, Referenzen und vollständigen lokalen Pfad aus. Im folgenden Ausschnitt wird ein Scan-Bericht in einem Ordner angezeigt, in dem 4 Modelle und 2 Datensätze gecached sind. ```text ➜ huggingface-cli scan-cache REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH --------------------------- --------- ------------ -------- ------------- ------------- ------------------- ------------------------------------------------------------------------- glue dataset 116.3K 15 4 days ago 4 days ago 2.4.0, main, 1.17.0 /home/wauplin/.cache/huggingface/hub/datasets--glue google/fleurs dataset 64.9M 6 1 week ago 1 week ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs Jean-Baptiste/camembert-ner model 441.0M 7 2 weeks ago 16 hours ago main /home/wauplin/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner bert-base-cased model 1.9G 13 1 week ago 2 years ago /home/wauplin/.cache/huggingface/hub/models--bert-base-cased t5-base model 10.1K 3 3 months ago 3 months ago main /home/wauplin/.cache/huggingface/hub/models--t5-base t5-small model 970.7M 11 3 days ago 3 days ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/models--t5-small Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G. Got 1 warning(s) while scanning. Use -vvv to print details. ``` Um einen detaillierteren Bericht zu erhalten, verwenden Sie die Option `--verbose`. Für jedes Repository erhalten Sie eine Liste aller heruntergeladenen Revisionen. Wie oben erläutert, werden Dateien, die sich zwischen 2 Revisionen nicht ändern, dank der symbolischen Links geteilt. Das bedeutet, dass die Größe des Repositorys auf der Festplatte voraussichtlich kleiner ist als die Summe der Größe jeder einzelnen Revision. Zum Beispiel hat hier `bert-base-cased` 2 Revisionen von 1,4G und 1,5G, aber der gesamte Festplattenspeicher beträgt nur 1,9G. 
```text ➜ huggingface-cli scan-cache -v REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH --------------------------- --------- ---------------------------------------- ------------ -------- ------------- ----------- ---------------------------------------------------------------------------------------------------------------------------- glue dataset 9338f7b671827df886678df2bdd7cc7b4f36dffd 97.7K 14 4 days ago main, 2.4.0 /home/wauplin/.cache/huggingface/hub/datasets--glue/snapshots/9338f7b671827df886678df2bdd7cc7b4f36dffd glue dataset f021ae41c879fcabcf823648ec685e3fead91fe7 97.8K 14 1 week ago 1.17.0 /home/wauplin/.cache/huggingface/hub/datasets--glue/snapshots/f021ae41c879fcabcf823648ec685e3fead91fe7 google/fleurs dataset 129b6e96cf1967cd5d2b9b6aec75ce6cce7c89e8 25.4K 3 2 weeks ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs/snapshots/129b6e96cf1967cd5d2b9b6aec75ce6cce7c89e8 google/fleurs dataset 24f85a01eb955224ca3946e70050869c56446805 64.9M 4 1 week ago main /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs/snapshots/24f85a01eb955224ca3946e70050869c56446805 Jean-Baptiste/camembert-ner model dbec8489a1c44ecad9da8a9185115bccabd799fe 441.0M 7 16 hours ago main /home/wauplin/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner/snapshots/dbec8489a1c44ecad9da8a9185115bccabd799fe bert-base-cased model 378aa1bda6387fd00e824948ebe3488630ad8565 1.5G 9 2 years ago /home/wauplin/.cache/huggingface/hub/models--bert-base-cased/snapshots/378aa1bda6387fd00e824948ebe3488630ad8565 bert-base-cased model a8d257ba9925ef39f3036bfc338acf5283c512d9 1.4G 9 3 days ago main /home/wauplin/.cache/huggingface/hub/models--bert-base-cased/snapshots/a8d257ba9925ef39f3036bfc338acf5283c512d9 t5-base model 23aa4f41cb7c08d4b05c8f327b22bfa0eb8c7ad9 10.1K 3 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-base/snapshots/23aa4f41cb7c08d4b05c8f327b22bfa0eb8c7ad9 t5-small model 98ffebbb27340ec1b1abd7c45da12c253ee1882a 726.2M 6 1 week ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/98ffebbb27340ec1b1abd7c45da12c253ee1882a t5-small model d0a119eedb3718e34c648e594394474cf95e0617 485.8M 6 4 weeks ago /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d0a119eedb3718e34c648e594394474cf95e0617 t5-small model d78aea13fa7ecd06c29e3e46195d6341255065d5 970.7M 9 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d78aea13fa7ecd06c29e3e46195d6341255065d5 Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G. Got 1 warning(s) while scanning. Use -vvv to print details. ``` #### Grep-Beispiel Da die Ausgabe im Tabellenformat erfolgt, können Sie sie mit `grep`-ähnlichen Tools kombinieren, um die Einträge zu filtern. Hier ein Beispiel, um nur Revisionen vom Modell "t5-small" auf einem Unix-basierten Gerät zu filtern. 
```text ➜ eval "huggingface-cli scan-cache -v" | grep "t5-small" t5-small model 98ffebbb27340ec1b1abd7c45da12c253ee1882a 726.2M 6 1 week ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/98ffebbb27340ec1b1abd7c45da12c253ee1882a t5-small model d0a119eedb3718e34c648e594394474cf95e0617 485.8M 6 4 weeks ago /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d0a119eedb3718e34c648e594394474cf95e0617 t5-small model d78aea13fa7ecd06c29e3e46195d6341255065d5 970.7M 9 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d78aea13fa7ecd06c29e3e46195d6341255065d5 ``` ### Den Cache von Python aus scannen Für eine erweiterte Nutzung verwenden Sie [`scan_cache_dir`], welches das von dem CLI-Tool aufgerufene Python-Dienstprogramm ist. Sie können es verwenden, um einen detaillierten Bericht zu erhalten, der um 4 Datenklassen herum strukturiert ist: - [`HFCacheInfo`]: vollständiger Bericht, der von [`scan_cache_dir`] zurückgegeben wird - [`CachedRepoInfo`]: Informationen über ein gecachtes Repo - [`CachedRevisionInfo`]: Informationen über eine gecachtes Revision (z.B. "snapshot) in einem Repo - [`CachedFileInfo`]: Informationen über eine gecachte Datei in einem Snapshot Hier ist ein einfaches Anwendungs-Beispiel in Python. Siehe Referenz für Details. ```py >>> from huggingface_hub import scan_cache_dir >>> hf_cache_info = scan_cache_dir() HFCacheInfo( size_on_disk=3398085269, repos=frozenset({ CachedRepoInfo( repo_id='t5-small', repo_type='model', repo_path=PosixPath(...), size_on_disk=970726914, nb_files=11, last_accessed=1662971707.3567169, last_modified=1662971107.3567169, revisions=frozenset({ CachedRevisionInfo( commit_hash='d78aea13fa7ecd06c29e3e46195d6341255065d5', size_on_disk=970726339, snapshot_path=PosixPath(...), # No `last_accessed` as blobs are shared among revisions last_modified=1662971107.3567169, files=frozenset({ CachedFileInfo( file_name='config.json', size_on_disk=1197 file_path=PosixPath(...), blob_path=PosixPath(...), blob_last_accessed=1662971707.3567169, blob_last_modified=1662971107.3567169, ), CachedFileInfo(...), ... }), ), CachedRevisionInfo(...), ... }), ), CachedRepoInfo(...), ... }), warnings=[ CorruptedCacheException("Snapshots dir doesn't exist in cached repo: ..."), CorruptedCacheException(...), ... ], ) ``` ## Cache leeren Das Durchsuchen Ihres Caches ist interessant, aber was Sie normalerweise als Nächstes tun möchten, ist einige Teile zu löschen, um etwas Speicherplatz auf Ihrem Laufwerk freizugeben. Dies ist möglich mit dem `delete-cache` CLI-Befehl. Man kann auch programmatisch den [`~HFCacheInfo.delete_revisions`] Helfer vom [`HFCacheInfo`] Objekt verwenden, das beim Durchsuchen des Caches zurückgegeben wird. ### Löschstrategie Um einige Cache zu löschen, müssen Sie eine Liste von Revisionen übergeben, die gelöscht werden sollen. Das Tool wird eine Strategie definieren, um den Speicherplatz auf der Grundlage dieser Liste freizugeben. Es gibt ein [`DeleteCacheStrategy`] Objekt zurück, das beschreibt, welche Dateien und Ordner gelöscht werden. Die [`DeleteCacheStrategy`] zeigt Ihnen, wie viel Speicherplatz voraussichtlich frei wird. Sobald Sie mit der Löschung einverstanden sind, müssen Sie sie ausführen, um die Löschung wirksam zu machen. Um Abweichungen zu vermeiden, können Sie ein Strategieobjekt nicht manuell bearbeiten. Die Strategie zur Löschung von Revisionen ist folgende: - Der Ordner `snapshot`, der die Revisions-Symlinks enthält, wird gelöscht. 
- Blob-Dateien, die nur von zu löschenden Revisionen verlinkt werden, werden ebenfalls gelöscht. - Wenn eine Revision mit 1 oder mehreren `refs` verknüpft ist, werden die Referenzen gelöscht. - Werden alle Revisionen aus einem Repo gelöscht, wird das gesamte zwischengespeicherte Repository gelöscht. Revisions-Hashes sind eindeutig über alle Repositories hinweg. Das bedeutet, dass Sie keine `repo_id` oder `repo_type` angeben müssen, wenn Sie Revisionen entfernen. Wenn eine Revision im Cache nicht gefunden wird, wird sie stillschweigend ignoriert. Außerdem wird, wenn eine Datei oder ein Ordner beim Versuch, ihn zu löschen, nicht gefunden wird, eine Warnung protokolliert, aber es wird kein Fehler ausgelöst. Die Löschung wird für andere Pfade im [`DeleteCacheStrategy`] Objekt fortgesetzt. ### Cache vom Terminal aus leeren Der einfachste Weg, einige Revisionen aus Ihrem HF-Cache-System zu löschen, ist die Verwendung des `delete-cache` Befehls vom `huggingface-cli` Tool. Der Befehl hat zwei Modi. Standardmäßig wird dem Benutzer eine TUI (Terminal User Interface) angezeigt, um auszuwählen, welche Revisionen gelöscht werden sollen. Diese TUI befindet sich derzeit in der Beta-Phase, da sie nicht auf allen Plattformen getestet wurde. Wenn die TUI auf Ihrem Gerät nicht funktioniert, können Sie sie mit dem Flag `--disable-tui` deaktivieren. #### Verwendung der TUI Dies ist der Standardmodus. Um ihn zu nutzen, müssen Sie zuerst zusätzliche Abhängigkeiten installieren, indem Sie den folgenden Befehl ausführen: ``` pip install huggingface_hub["cli"] ``` Führen Sie dann den Befehl aus: ``` huggingface-cli delete-cache ``` Sie sollten jetzt eine Liste von Revisionen sehen, die Sie auswählen/abwählen können:
Anleitung:
    - Drücken Sie die Pfeiltasten `<Hoch>` und `<Runter>` auf der Tastatur, um den Cursor zu bewegen.
    - Drücken Sie `<Leertaste>`, um einen Eintrag umzuschalten (auswählen/abwählen).
    - Wenn eine Revision ausgewählt ist, wird die erste Zeile aktualisiert, um Ihnen anzuzeigen, wie viel Speicherplatz freigegeben wird.
    - Drücken Sie `<Enter>`, um Ihre Auswahl zu bestätigen.
    - Wenn Sie den Vorgang abbrechen und beenden möchten, können Sie den ersten Eintrag ("None of the following") auswählen. Wenn dieser Eintrag ausgewählt ist, wird der Löschvorgang abgebrochen, unabhängig davon, welche anderen Einträge ausgewählt sind. Alternativ können Sie auch `<Strg+C>` drücken, um die TUI zu verlassen.

Nachdem Sie die Revisionen ausgewählt haben, die Sie löschen möchten, und `<Enter>` gedrückt haben, wird eine letzte Bestätigungsnachricht angezeigt. Drücken Sie erneut `<Enter>`, und die Löschung wird wirksam. Wenn Sie abbrechen möchten, geben Sie `n` ein.

```txt
✗ huggingface-cli delete-cache --dir ~/.cache/huggingface/hub
? Select revisions to delete: 2 revision(s) selected.
? 2 revisions selected counting for 3.1G. Confirm deletion ? Yes
Start deletion.
Done. Deleted 1 repo(s) and 0 revision(s) for a total of 3.1G.
```

#### Ohne TUI

Wie bereits erwähnt, befindet sich der TUI-Modus derzeit in der Beta-Phase und ist optional. Es könnte sein, dass er auf Ihrem Gerät nicht funktioniert oder dass Sie ihn nicht praktisch finden. Ein anderer Ansatz besteht darin, das Flag `--disable-tui` zu verwenden. Der Vorgang ähnelt sehr dem vorherigen, da Sie aufgefordert werden, die Liste der zu löschenden Revisionen manuell zu überprüfen. Dieser manuelle Schritt findet jedoch nicht direkt im Terminal statt, sondern in einer temporären Datei, die ad hoc generiert wird und die Sie manuell bearbeiten können.

Diese Datei enthält alle erforderlichen Anweisungen im Kopfteil. Öffnen Sie sie in Ihrem bevorzugten Texteditor. Um eine Revision abzuwählen bzw. wieder auszuwählen, kommentieren Sie sie einfach mit einem `#` aus bzw. entfernen Sie das `#` wieder. Sobald die manuelle Überprüfung abgeschlossen ist und die Datei bearbeitet wurde, können Sie sie speichern. Gehen Sie zurück zu Ihrem Terminal und drücken Sie `<Enter>`. Standardmäßig wird berechnet, wie viel Speicherplatz mit der aktualisierten Revisionsliste freigegeben würde. Sie können die Datei weiter bearbeiten oder mit `"y"` bestätigen.

```sh
huggingface-cli delete-cache --disable-tui
```

Beispiel für eine Befehlsdatei:

```txt
# INSTRUCTIONS
# ------------
# This is a temporary file created by running `huggingface-cli delete-cache` with the
# `--disable-tui` option. It contains a set of revisions that can be deleted from your
# local cache directory.
#
# Please manually review the revisions you want to delete:
#   - Revision hashes can be commented out with '#'.
#   - Only non-commented revisions in this file will be deleted.
#   - Revision hashes that are removed from this file are ignored as well.
#   - If `CANCEL_DELETION` line is uncommented, the all cache deletion is cancelled and
#     no changes will be applied.
#
# Once you've manually reviewed this file, please confirm deletion in the terminal. This
# file will be automatically removed once done.
# ------------ # KILL SWITCH # ------------ # Un-comment following line to completely cancel the deletion process # CANCEL_DELETION # ------------ # REVISIONS # ------------ # Dataset chrisjay/crowd-speech-africa (761.7M, used 5 days ago) ebedcd8c55c90d39fd27126d29d8484566cd27ca # Refs: main # modified 5 days ago # Dataset oscar (3.3M, used 4 days ago) # 916f956518279c5e60c63902ebdf3ddf9fa9d629 # Refs: main # modified 4 days ago # Dataset wikiann (804.1K, used 2 weeks ago) 89d089624b6323d69dcd9e5eb2def0551887a73a # Refs: main # modified 2 weeks ago # Dataset z-uo/male-LJSpeech-italian (5.5G, used 5 days ago) # 9cfa5647b32c0a30d0adfca06bf198d82192a0d1 # Refs: main # modified 5 days ago ``` ### Cache aus Python leeren Für mehr Flexibilität können Sie auch die Methode [`~HFCacheInfo.delete_revisions`] programmatisch verwenden. Hier ist ein einfaches Beispiel. Siehe Referenz für Details. ```py >>> from huggingface_hub import scan_cache_dir >>> delete_strategy = scan_cache_dir().delete_revisions( ... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa" ... "e2983b237dccf3ab4937c97fa717319a9ca1a96d", ... "6c0e6080953db56375760c0471a8c5f2929baf11", ... ) >>> print("Will free " + delete_strategy.expected_freed_size_str) Will free 8.6G >>> delete_strategy.execute() Cache deletion done. Saved 8.6G. ``` huggingface_hub-0.31.1/docs/source/de/guides/manage-spaces.md000066400000000000000000000337211500667546600240540ustar00rootroot00000000000000 # Verwalten Ihres Spaces (Bereiches) In diesem Leitfaden werden wir sehen, wie man den Laufzeitbereich eines Space ([Geheimnisse (Secrets)](https://huggingface.co/docs/hub/spaces-overview#managing-secrets), [Hardware](https://huggingface.co/docs/hub/spaces-gpus) und Speicher (Storage)) mit `huggingface_hub` verwaltet. ## Ein einfaches Beispiel: Konfigurieren von Geheimnissen und Hardware Hier ist ein End-to-End-Beispiel, um einen Space auf dem Hub zu erstellen und einzurichten. **1. Einen Space auf dem Hub erstellen.** ```py >>> from huggingface_hub import HfApi >>> repo_id = "Wauplin/my-cool-training-space" >>> api = HfApi() # Zum Beispiel mit einem Gradio SDK >>> api.create_repo(repo_id=repo_id, repo_type="space", space_sdk="gradio") ``` **1. (bis) Duplizieren eines Space.** Das kann nützlich sein, wenn Sie auf einem bestehenden Space aufbauen möchten, anstatt von Grund auf neu zu beginnen. Es ist auch nützlich, wenn Sie die Kontrolle über die Konfiguration/Einstellungen eines öffentlichen Space haben möchten. Siehe [`duplicate_space`] für weitere Details. ```py >>> api.duplicate_space("multimodalart/dreambooth-training") ``` **2. Code mit bevorzugter Lösung hochladen.** Hier ist ein Beispiel, wie man den lokalen Ordner `src/` von Ihrem Computer in Ihren Space hochlädt: ```py >>> api.upload_folder(repo_id=repo_id, repo_type="space", folder_path="src/") ``` In diesem Schritt sollte Ihre App bereits kostenlos auf dem Hub laufen! Möglicherweise möchten Sie sie jedoch weiterhin mit Geheimnissen und aufgerüsteter Hardware konfigurieren. **3. Konfigurieren von Geheimnissen und Variablen** Ihr Space könnte einige geheime Schlüssel, Tokens oder Variablen benötigen, um zu funktionieren. Siehe [Dokumentation](https://huggingface.co/docs/hub/spaces-overview#managing-secrets) für weitere Details. Zum Beispiel ein HF-Token, um einen Bilddatensatz auf den Hub hochzuladen, sobald er aus Ihrem Space generiert wurde. 
```py
>>> api.add_space_secret(repo_id=repo_id, key="HF_TOKEN", value="hf_api_***")
>>> api.add_space_variable(repo_id=repo_id, key="MODEL_REPO_ID", value="user/repo")
```

Geheimnisse und Variablen können auch gelöscht werden:

```py
>>> api.delete_space_secret(repo_id=repo_id, key="HF_TOKEN")
>>> api.delete_space_variable(repo_id=repo_id, key="MODEL_REPO_ID")
```

Innerhalb Ihres Space sind Geheimnisse als Umgebungsvariablen verfügbar (oder Streamlit Secrets Management, wenn Streamlit verwendet wird). Keine Notwendigkeit, sie über die API abzurufen! Jede Änderung in der Konfiguration Ihres Space (Geheimnisse oder Hardware) wird einen Neustart Ihrer App auslösen.

**Bonus: Geheimnisse und Variablen beim Erstellen oder Duplizieren des Space festlegen!**

Geheimnisse und Variablen können beim Erstellen oder Duplizieren eines Space gesetzt werden:

```py
>>> api.create_repo(
...     repo_id=repo_id,
...     repo_type="space",
...     space_sdk="gradio",
...     space_secrets=[{"key": "HF_TOKEN", "value": "hf_api_***"}, ...],
...     space_variables=[{"key": "MODEL_REPO_ID", "value": "user/repo"}, ...],
... )
```

```py
>>> api.duplicate_space(
...     from_id=repo_id,
...     secrets=[{"key": "HF_TOKEN", "value": "hf_api_***"}, ...],
...     variables=[{"key": "MODEL_REPO_ID", "value": "user/repo"}, ...],
... )
```

**4. Konfigurieren von Hardware**

Standardmäßig wird Ihr Space kostenlos in einer CPU-Umgebung ausgeführt. Sie können die Hardware aktualisieren, um ihn auf GPUs laufen zu lassen. Eine Zahlungskarte oder ein Community-Grant wird benötigt, um Ihren Space zu aktualisieren. Siehe [Dokumentation](https://huggingface.co/docs/hub/spaces-gpus) für weitere Details.

```py
# Verwenden von `SpaceHardware` Enum
>>> from huggingface_hub import SpaceHardware
>>> api.request_space_hardware(repo_id=repo_id, hardware=SpaceHardware.T4_MEDIUM)

# Oder einfach einen String-Wert angeben
>>> api.request_space_hardware(repo_id=repo_id, hardware="t4-medium")
```

Hardware-Aktualisierungen erfolgen nicht sofort, da Ihr Space auf unseren Servern neu geladen werden muss. Jederzeit können Sie überprüfen, auf welcher Hardware Ihr Space läuft, um zu sehen, ob Ihre Anfrage erfüllt wurde.

```py
>>> runtime = api.get_space_runtime(repo_id=repo_id)
>>> runtime.stage
"RUNNING_BUILDING"
>>> runtime.hardware
"cpu-basic"
>>> runtime.requested_hardware
"t4-medium"
```

Sie verfügen jetzt über einen vollständig konfigurierten Space. Stellen Sie sicher, dass Sie Ihren Space wieder auf "cpu-basic" zurückstufen, wenn Sie ihn nicht mehr verwenden.

**Bonus: Hardware beim Erstellen oder Duplizieren des Space anfordern!**

Aktualisierte Hardware wird Ihrem Space automatisch zugewiesen, sobald er erstellt wurde.

```py
>>> api.create_repo(
...     repo_id=repo_id,
...     repo_type="space",
...     space_sdk="gradio",
...     space_hardware="cpu-upgrade",
...     space_storage="small",
...     space_sleep_time="7200", # 2 hours in secs
... )
```
```py
>>> api.duplicate_space(
...     from_id=repo_id,
...     hardware="cpu-upgrade",
...     storage="small",
...     sleep_time="7200", # 2 hours in secs
... )
```

**5. Pausieren und Neustarten des Spaces**

Standardmäßig wird Ihr Space nie angehalten, wenn er auf aufgewerteter Hardware läuft. Um jedoch zu vermeiden, dass Ihnen Gebühren berechnet werden, möchten Sie ihn möglicherweise anhalten, wenn Sie ihn nicht verwenden. Dies ist mit [`pause_space`] möglich. Ein pausierter Space bleibt inaktiv, bis der Besitzer des Space ihn entweder über die Benutzeroberfläche oder über die API mit [`restart_space`] neu startet.
Weitere Informationen zum Pausenmodus finden Sie in [diesem Abschnitt](https://huggingface.co/docs/hub/spaces-gpus#pause).

```py
# Pausieren des Space, um Gebühren zu vermeiden
>>> api.pause_space(repo_id=repo_id)
# (...)
# Erneut starten, wenn benötigt
>>> api.restart_space(repo_id=repo_id)
```

Eine weitere Möglichkeit besteht darin, für Ihren Space einen Timeout festzulegen. Wenn Ihr Space länger als die Timeout-Dauer inaktiv ist, wird er in den Schlafmodus versetzt. Jeder Besucher, der auf Ihren Space zugreift, wird ihn wieder starten. Sie können ein Timeout mit [`set_space_sleep_time`] festlegen. Weitere Informationen zum Schlafmodus finden Sie in [diesem Abschnitt](https://huggingface.co/docs/hub/spaces-gpus#sleep-time).

```py
# Den Space nach 1h Inaktivität in den Schlafmodus versetzen
>>> api.set_space_sleep_time(repo_id=repo_id, sleep_time=3600)
```

Hinweis: Wenn Sie eine 'cpu-basic' Hardware verwenden, können Sie keine benutzerdefinierte Schlafzeit konfigurieren. Ihr Space wird automatisch nach 48h Inaktivität pausiert.

**Bonus: Schlafzeit bei der Hardwareanforderung festlegen**

Aufgewertete Hardware wird Ihrem Space automatisch zugewiesen, sobald er erstellt wurde.

```py
>>> api.request_space_hardware(repo_id=repo_id, hardware=SpaceHardware.T4_MEDIUM, sleep_time=3600)
```

**Bonus: Schlafzeit beim Erstellen oder Duplizieren des Space festlegen!**

```py
>>> api.create_repo(
...     repo_id=repo_id,
...     repo_type="space",
...     space_sdk="gradio",
...     space_hardware="t4-medium",
...     space_sleep_time="3600",
... )
```
```py
>>> api.duplicate_space(
...     from_id=repo_id,
...     hardware="t4-medium",
...     sleep_time="3600",
... )
```

**6. Dem Space dauerhaften Speicherplatz hinzufügen**

Sie können den Speicher-Tier Ihrer Wahl auswählen, um auf Festplattenspeicher zuzugreifen, der Neustarts Ihres Space überdauert. Dies bedeutet, dass Sie von der Festplatte lesen und darauf schreiben können, wie Sie es von einer herkömmlichen Festplatte gewöhnt sind. Weitere Informationen finden Sie in der [Dokumentation](https://huggingface.co/docs/hub/spaces-storage#persistent-storage).

```py
>>> from huggingface_hub import SpaceStorage
>>> api.request_space_storage(repo_id=repo_id, storage=SpaceStorage.LARGE)
```

Sie können auch Ihren Speicher löschen und dabei alle Daten dauerhaft verlieren.

```py
>>> api.delete_space_storage(repo_id=repo_id)
```

Hinweis: Nachdem Ihnen ein Speicher-Tier zugewiesen wurde, können Sie diesen nicht mehr herabsetzen. Um dies zu tun, müssen Sie zuerst den Speicher löschen und dann den gewünschten Tier anfordern.

**Bonus: Speicher beim Erstellen oder Duplizieren des Space anfordern!**

```py
>>> api.create_repo(
...     repo_id=repo_id,
...     repo_type="space",
...     space_sdk="gradio",
...     space_storage="large",
... )
```
```py
>>> api.duplicate_space(
...     from_id=repo_id,
...     storage="large",
... )
```

## Fortgeschritten: Temporäres Space Upgrade

Spaces ermöglichen viele verschiedene Einsatzmöglichkeiten. Manchmal möchten Sie vielleicht einen Space vorübergehend auf einer bestimmten Hardware ausführen, etwas tun und ihn dann herunterfahren. In diesem Abschnitt werden wir untersuchen, wie Sie die Vorteile von Spaces nutzen können, um ein Modell auf Abruf zu finetunen. Dies ist nur eine Möglichkeit, dieses spezielle Problem zu lösen. Es sollte als Vorschlag betrachtet und an Ihren Anwendungsfall angepasst werden.

Nehmen wir an, wir haben einen Space, um ein Modell zu finetunen. Es handelt sich um eine Gradio-App, die eine Modell-Id und eine Dataset-Id als Eingabe nimmt.
Der Ablauf sieht folgendermaßen aus: 0. (Den Benutzer nach einem Modell und einem Datensatz auffordern) 1. Das Modell aus dem Hub laden. 2. Den Datensatz aus dem Hub laden. 3. Das Modell mit dem Datensatz finetunen. 4. Das neue Modell auf den Hub hochladen. Schritt 3 erfordert eine spezielle Hardware, aber Sie möchten nicht, dass Ihr Space die ganze Zeit auf einer kostenpflichtigen GPU läuft. Eine Lösung besteht darin, dynamisch Hardware für das Training anzufordern und es anschließend herunterzufahren. Da das Anfordern von Hardware Ihren Space neu startet, muss sich Ihre App irgendwie die aktuelle Aufgabe "merken", die sie ausführt. Es gibt mehrere Möglichkeiten, dies zu tun. In diesem Leitfaden sehen wir eine Lösung, bei der ein Datensatz als "Aufgabenplaner (task scheduler)" verwendet wird. ### App-Grundgerüst So würde Ihre App aussehen. Beim Start überprüfen, ob eine Aufgabe geplant ist und ob ja, führen Sie sie auf der richtigen Hardware aus. Ist die Aufgabe erledigt, setzen Sie die Hardware zurück auf den kostenlosen CPU-Plan und fordern den Benutzer auf, eine neue Aufgabe anzufordern. Ein solcher Workflow unterstützt keinen gleichzeitigen Zugriff wie normale Demos. Insbesondere wird die Schnittstelle deaktiviert, wenn das Training stattfindet. Es ist vorzuziehen, Ihr Repo auf privat zu setzen, um sicherzustellen, dass Sie der einzige Benutzer sind. ```py # Für den Space wird Ihr Token benötigt, um Hardware anzufordern: Legen Sie es als Geheimnis fest! HF_TOKEN = os.environ.get("HF_TOKEN") # Eigene repo_id des Space TRAINING_SPACE_ID = "Wauplin/dreambooth-training" from huggingface_hub import HfApi, SpaceHardware api = HfApi(token=HF_TOKEN) # Beim Start des Space überprüfen, ob eine Aufgabe geplant ist. Wenn ja, finetunen Sie das Modell. # Wenn nicht, zeigen Sie eine Schnittstelle an, um eine neue Aufgabe anzufordern. task = get_task() if task is None: # Starten der Gradio-App def gradio_fn(task): # Bei Benutzeranfrage, Aufgabe hinzufügen und Hardware anfordern add_task(task) api.request_space_hardware(repo_id=TRAINING_SPACE_ID, hardware=SpaceHardware.T4_MEDIUM) gr.Interface(fn=gradio_fn, ...).launch() else: runtime = api.get_space_runtime(repo_id=TRAINING_SPACE_ID) # Überprüfen, ob der Space mit einer GPU geladen ist. if runtime.hardware == SpaceHardware.T4_MEDIUM: # Wenn ja, finetunen des Basismodells auf den Datensatz! train_and_upload(task) # Dann die Aufgabe als "DONE / ERLEDIGT" markieren mark_as_done(task) # NICHT VERGESSEN: CPU-Hardware zurück setzen api.request_space_hardware(repo_id=TRAINING_SPACE_ID, hardware=SpaceHardware.CPU_BASIC) else: api.request_space_hardware(repo_id=TRAINING_SPACE_ID, hardware=SpaceHardware.T4_MEDIUM) ``` ### Aufgabenplaner (Task scheduler) Das Planen von Aufgaben kann auf viele Arten erfolgen. Hier ist ein Beispiel, wie es mit einer einfachen CSV gemacht werden könnte, die als Datensatz gespeichert ist. ```py # Dataset-ID, in der eine `tasks.csv` Datei die auszuführenden Aufgaben enthält. # Hier ist ein einfaches Beispiel für `tasks.csv`, das Eingaben (Basis-Modell und Datensatz) # und Status (PENDING / AUSSTEHEND oder DONE / ERLEDIGT) enthält. 
# multimodalart/sd-fine-tunable,Wauplin/concept-1,DONE # multimodalart/sd-fine-tunable,Wauplin/concept-2,PENDING TASK_DATASET_ID = "Wauplin/dreambooth-task-scheduler" def _get_csv_file(): return hf_hub_download(repo_id=TASK_DATASET_ID, filename="tasks.csv", repo_type="dataset", token=HF_TOKEN) def get_task(): with open(_get_csv_file()) as csv_file: csv_reader = csv.reader(csv_file, delimiter=',') for row in csv_reader: if row[2] == "PENDING": return row[0], row[1] # model_id, dataset_id def add_task(task): model_id, dataset_id = task with open(_get_csv_file()) as csv_file: with open(csv_file, "r") as f: tasks = f.read() api.upload_file( repo_id=repo_id, repo_type=repo_type, path_in_repo="tasks.csv", # Schnelle und einfache Möglichkeit, eine Aufgabe hinzuzufügen path_or_fileobj=(tasks + f"\n{model_id},{dataset_id},PENDING").encode() ) def mark_as_done(task): model_id, dataset_id = task with open(_get_csv_file()) as csv_file: with open(csv_file, "r") as f: tasks = f.read() api.upload_file( repo_id=repo_id, repo_type=repo_type, path_in_repo="tasks.csv", # Schnelle und einfache Möglichkeit, die Aufgabe als DONE / ERLEDIGT zu markieren path_or_fileobj=tasks.replace( f"{model_id},{dataset_id},PENDING", f"{model_id},{dataset_id},DONE" ).encode() ) ``` huggingface_hub-0.31.1/docs/source/de/guides/model-cards.md000066400000000000000000000225301500667546600235360ustar00rootroot00000000000000 # Erstellen und Teilen von Model Cards Die `huggingface_hub`-Bibliothek bietet eine Python-Schnittstelle zum Erstellen, Teilen und Aktualisieren von Model Cards. Besuchen Sie [die spezielle Dokumentationsseite](https://huggingface.co/docs/hub/models-cards) für einen tieferen Einblick in das, was Model Cards im Hub sind und wie sie unter der Haube funktionieren. [Neu (Beta)! Probieren Sie unsere experimentelle Model Card Creator App aus](https://huggingface.co/spaces/huggingface/Model_Cards_Writing_Tool) ## Eine Model Card vom Hub laden Um eine bestehende Karte vom Hub zu laden, können Sie die Funktion [`ModelCard.load`] verwenden. Hier laden wir die Karte von [`nateraw/vit-base-beans`](https://huggingface.co/nateraw/vit-base-beans). ```python from huggingface_hub import ModelCard card = ModelCard.load('nateraw/vit-base-beans') ``` Diese Karte hat einige nützliche Attribute, auf die Sie zugreifen oder die Sie nutzen möchten: - `card.data`: Gibt eine [`ModelCardData`]-Instanz mit den Metadaten der Model Card zurück. Rufen Sie `.to_dict()` auf diese Instanz auf, um die Darstellung als Wörterbuch zu erhalten. - `card.text`: Gibt den Textinhalt der Karte *ohne den Metadatenkopf* zurück. - `card.content`: Gibt den Textinhalt der Karte, *einschließlich des Metadatenkopfes*, zurück. ## Model Cards erstellen ### Aus Text Um eine Model Card aus Text zu initialisieren, übergeben Sie einfach den Textinhalt der Karte an `ModelCard` beim Initialisieren. ```python content = """ --- language: en license: mit --- # My Model Card """ card = ModelCard(content) card.data.to_dict() == {'language': 'en', 'license': 'mit'} # True ``` Eine andere Möglichkeit besteht darin, dies mit f-Strings zu tun. Im folgenden Beispiel: - Verwenden wir [`ModelCardData.to_yaml`], um die von uns definierten Metadaten in YAML umzuwandeln, damit wir sie in die Model Card einfügen können. - Zeigen wir, wie Sie eine Vorlagenvariable über Python f-Strings verwenden könnten. 
```python card_data = ModelCardData(language='en', license='mit', library='timm') example_template_var = 'nateraw' content = f""" --- { card_data.to_yaml() } --- # My Model Card This model created by [@{example_template_var}](https://github.com/{example_template_var}) """ card = ModelCard(content) print(card) ``` Das obige Beispiel würde uns eine Karte hinterlassen, die so aussieht: ``` --- language: en license: mit library: timm --- # My Model Card This model created by [@nateraw](https://github.com/nateraw) ``` ### Aus einem Jinja-Template Wenn Sie `Jinja2` installiert haben, können Sie Model Cards aus einer Jinja-Vorlagendatei erstellen. Schauen wir uns ein einfaches Beispiel an: ```python from pathlib import Path from huggingface_hub import ModelCard, ModelCardData # Definieren Sie Ihre Jinja-Vorlage template_text = """ --- {{ card_data }} --- # Model Card for MyCoolModel This model does this and that. This model was created by [@{{ author }}](https://hf.co/{{author}}). """.strip() # Schreiben Sie die Vorlage in eine Datei Path('custom_template.md').write_text(template_text) # Definieren Sie die Metadaten der Karte card_data = ModelCardData(language='en', license='mit', library_name='keras') # Erstellen Sie eine Karte aus der Vorlage und übergeben Sie dabei alle gewünschten Jinja-Vorlagenvariablen. # In unserem Fall übergeben wir "author" card = ModelCard.from_template(card_data, template_path='custom_template.md', author='nateraw') card.save('my_model_card_1.md') print(card) ``` Das resultierende Karten-Markdown sieht so aus: ``` --- language: en license: mit library_name: keras --- # Model Card for MyCoolModel This model does this and that. This model was created by [@nateraw](https://hf.co/nateraw). ``` Wenn Sie Daten in card.data aktualisieren, wird dies in der Karte selbst widergespiegelt. ``` card.data.library_name = 'timm' card.data.language = 'fr' card.data.license = 'apache-2.0' print(card) ``` Jetzt, wie Sie sehen können, wurde der Metadatenkopf aktualisiert: ``` --- language: fr license: apache-2.0 library_name: timm --- # Model Card for MyCoolModel This model does this and that. This model was created by [@nateraw](https://hf.co/nateraw). ``` Wenn Sie die Karteninformationen aktualisieren, können Sie durch Aufrufen von [`ModelCard.validate`] überprüfen, ob die Karte immer noch gültig für den Hub ist. Dies stellt sicher, dass die Karte alle Validierungsregeln erfüllt, die im Hugging Face Hub eingerichtet wurden. ### Aus dem Standard-Template Anstatt Ihr eigenes Template zu verwenden, können Sie auch das [Standard-Template](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md) verwenden, welches eine vollständig ausgestattete Model Card mit vielen Abschnitten ist, die Sie vielleicht ausfüllen möchten. Unter der Haube verwendet es [Jinja2](https://jinja.palletsprojects.com/en/3.1.x/), um eine Vorlagendatei auszufüllen. Beachten Sie, dass Sie Jinja2 installiert haben müssen, um `from_template` zu verwenden. Sie können dies mit `pip install Jinja2` tun. 
```python card_data = ModelCardData(language='en', license='mit', library_name='keras') card = ModelCard.from_template( card_data, model_id='my-cool-model', model_description="this model does this and that", developers="Nate Raw", repo="https://github.com/huggingface/huggingface_hub", ) card.save('my_model_card_2.md') print(card) ``` ## Model Cards teilen Wenn Sie mit dem Hugging Face Hub authentifiziert sind (entweder durch Verwendung von `huggingface-cli login` oder [`login`]), können Sie Karten zum Hub hinzufügen, indem Sie einfach [`ModelCard.push_to_hub`] aufrufen. Schauen wir uns an, wie das funktioniert... Zuerst erstellen wir ein neues Repo namens 'hf-hub-modelcards-pr-test' im Namensraum des authentifizierten Benutzers: ```python from huggingface_hub import whoami, create_repo user = whoami()['name'] repo_id = f'{user}/hf-hub-modelcards-pr-test' url = create_repo(repo_id, exist_ok=True) ``` Dann erstellen wir eine Karte aus der Standardvorlage (genau wie die oben definierte): ```python card_data = ModelCardData(language='en', license='mit', library_name='keras') card = ModelCard.from_template( card_data, model_id='my-cool-model', model_description="this model does this and that", developers="Nate Raw", repo="https://github.com/huggingface/huggingface_hub", ) ``` Schließlich laden wir das zum Hub hoch: ```python card.push_to_hub(repo_id) ``` Sie können die resultierende Karte [hier](https://huggingface.co/nateraw/hf-hub-modelcards-pr-test/blob/main/README.md) überprüfen. Wenn Sie eine Karte als Pull-Request hinzufügen möchten, können Sie beim Aufruf von `push_to_hub` einfach `create_pr=True` angeben: ```python card.push_to_hub(repo_id, create_pr=True) ``` Ein PR, der mit diesem Befehl erstellt wurde, kann [hier](https://huggingface.co/nateraw/hf-hub-modelcards-pr-test/discussions/3) aufgerufen werden. ### Evaluierungsergebnisse einbeziehen Um Evaluierungsergebnisse in den Metadaten `model-index` einzufügen, können Sie ein [`EvalResult`] oder eine Liste von `EvalResult` mit Ihren zugehörigen Evaluierungsergebnissen übergeben. Im Hintergrund wird der `model-index` erstellt, wenn Sie `card.data.to_dict()` aufrufen. Weitere Informationen darüber, wie dies funktioniert, finden Sie in [diesem Abschnitt der Hub-Dokumentation](https://huggingface.co/docs/hub/models-cards#evaluation-results). Beachten Sie, dass die Verwendung dieser Funktion erfordert, dass Sie das Attribut `model_name` in [`ModelCardData`] einbeziehen. 
```python card_data = ModelCardData( language='en', license='mit', model_name='my-cool-model', eval_results = EvalResult( task_type='image-classification', dataset_type='beans', dataset_name='Beans', metric_type='accuracy', metric_value=0.7 ) ) card = ModelCard.from_template(card_data) print(card.data) ``` Die resultierende `card.data` sollte so aussehen: ``` language: en license: mit model-index: - name: my-cool-model results: - task: type: image-classification dataset: name: Beans type: beans metrics: - type: accuracy value: 0.7 ``` Wenn Sie mehr als ein Evaluierungsergebnis teilen möchten, übergeben Sie einfach eine Liste von `EvalResult`: ```python card_data = ModelCardData( language='en', license='mit', model_name='my-cool-model', eval_results = [ EvalResult( task_type='image-classification', dataset_type='beans', dataset_name='Beans', metric_type='accuracy', metric_value=0.7 ), EvalResult( task_type='image-classification', dataset_type='beans', dataset_name='Beans', metric_type='f1', metric_value=0.65 ) ] ) card = ModelCard.from_template(card_data) card.data ``` Dies sollte Ihnen die folgenden `card.data` hinterlassen: ``` language: en license: mit model-index: - name: my-cool-model results: - task: type: image-classification dataset: name: Beans type: beans metrics: - type: accuracy value: 0.7 - type: f1 value: 0.65 ``` huggingface_hub-0.31.1/docs/source/de/guides/overview.md000066400000000000000000000135601500667546600232150ustar00rootroot00000000000000 # Anleitungen In diesem Abschnitt finden Sie praktische Anleitungen, die Ihnen helfen, ein bestimmtes Ziel zu erreichen. Schauen Sie sich diese Anleitungen an, um zu lernen, wie Sie huggingface_hub verwenden, um reale Probleme zu lösen: huggingface_hub-0.31.1/docs/source/de/guides/repository.md000066400000000000000000000261021500667546600235620ustar00rootroot00000000000000 # Ein Repository erstellen und verwalten Das Hugging Face Hub besteht aus einer Sammlung von Git-Repositories. [Git](https://git-scm.com/) ist ein in der Softwareentwicklung weit verbreitetes Tool, um Projekte bei der Zusammenarbeit einfach zu versionieren. Dieser Leitfaden zeigt Ihnen, wie Sie mit den Repositories auf dem Hub interagieren, insbesondere: - Ein Repository erstellen und löschen. - Zweige (Branches) und Tags verwalten. - Ihr Repository umbenennen. - Die Sichtbarkeit Ihres Repositories aktualisieren. - Eine lokale Kopie Ihres Repositories verwalten. Wenn Sie es gewohnt sind, mit Plattformen wie GitLab/GitHub/Bitbucket zu arbeiten, könnte Ihr erster Instinkt sein, die `git` CLI zu verwenden, um Ihr Repo zu klonen (`git clone`), Änderungen zu übernehmen (`git add`, `git commit`) und diese hochzuladen (`git push`). Dies ist beim Verwenden des Hugging Face Hubs gültig. Softwareentwicklung und maschinelles Lernen haben jedoch nicht dieselben Anforderungen und Arbeitsabläufe. Modell-Repositories könnten große Modellgewichtsdateien für verschiedene Frameworks und Tools beinhalten, sodass das Klonen des Repositories dazu führen kann, dass Sie große lokale Ordner mit massiven Größen pflegen. Daher kann es effizienter sein, unsere benutzerdefinierten HTTP-Methoden zu verwenden. Sie können unsere [Git vs HTTP Paradigma](../concepts/git_vs_http) Erklärungsseite für weitere Details lesen. Wenn Sie ein Repository auf dem Hub erstellen und verwalten möchten, muss Ihr Computer angemeldet sein. Wenn Sie es nicht sind, beziehen Sie sich bitte auf [diesen Abschnitt](../quick-start#login). Im Rest dieses Leitfadens gehen wir davon aus, dass Ihr Computer angemeldet ist. 
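Zur Erinnerung eine kurze Skizze: Die Anmeldung kann programmatisch mit [`login`] erfolgen (alternativ über `huggingface-cli login` im Terminal).

```py
>>> from huggingface_hub import login
>>> login()  # fragt interaktiv nach Ihrem Access Token
```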
## Erstellung und Löschung von Repos Der erste Schritt besteht darin, zu wissen, wie man Repositories erstellt und löscht. Sie können nur Repositories verwalten, die Ihnen gehören (unter Ihrem Benutzernamensraum) oder von Organisationen, in denen Sie Schreibberechtigungen haben. ### Ein Repository erstellen Erstellen Sie ein leeres Repository mit [`create_repo`] und geben Sie ihm mit dem Parameter `repo_id` einen Namen. Die `repo_id` ist Ihr Namensraum gefolgt vom Repository-Namen: `username_or_org/repo_name`. ```py >>> from huggingface_hub import create_repo >>> create_repo("lysandre/test-model") 'https://huggingface.co/lysandre/test-model' ``` Standardmäßig erstellt [`create_repo`] ein Modellrepository. Sie können jedoch den Parameter `repo_type` verwenden, um einen anderen Repository-Typ anzugeben. Wenn Sie beispielsweise ein Dataset-Repository erstellen möchten: ```py >>> from huggingface_hub import create_repo >>> create_repo("lysandre/test-dataset", repo_type="dataset") 'https://huggingface.co/datasets/lysandre/test-dataset' ``` Wenn Sie ein Repository erstellen, können Sie mit dem Parameter `private` die Sichtbarkeit Ihres Repositories festlegen. ```py >>> from huggingface_hub import create_repo >>> create_repo("lysandre/test-private", private=True) ``` Wenn Sie die Sichtbarkeit des Repositories zu einem späteren Zeitpunkt ändern möchten, können Sie die Funktion [`update_repo_settings`] verwenden. ### Ein Repository löschen Löschen Sie ein Repository mit [`delete_repo`]. Stellen Sie sicher, dass Sie ein Repository löschen möchten, da dieser Vorgang unwiderruflich ist! Geben Sie die `repo_id` des Repositories an, das Sie löschen möchten: ```py >>> delete_repo(repo_id="lysandre/my-corrupted-dataset", repo_type="dataset") ``` ### Ein Repository duplizieren (nur für Spaces) In einigen Fällen möchten Sie möglicherweise das Repo von jemand anderem kopieren, um es an Ihren Anwendungsfall anzupassen. Dies ist für Spaces mit der Methode [`duplicate_space`] möglich. Es wird das gesamte Repository dupliziert. Sie müssen jedoch noch Ihre eigenen Einstellungen konfigurieren (Hardware, Schlafzeit, Speicher, Variablen und Geheimnisse). Weitere Informationen finden Sie in unserem Leitfaden [Verwalten Ihres Spaces](./manage-spaces). ```py >>> from huggingface_hub import duplicate_space >>> duplicate_space("multimodalart/dreambooth-training", private=False) RepoUrl('https://huggingface.co/spaces/nateraw/dreambooth-training',...) ``` ## Dateien hochladen und herunterladen Jetzt, wo Sie Ihr Repository erstellt haben, möchten Sie Änderungen daran vornehmen und Dateien daraus herunterladen. Diese 2 Themen verdienen ihre eigenen Leitfäden. Bitte beziehen Sie sich auf die [Hochladen](./upload) und die [Herunterladen](./download) Leitfäden, um zu erfahren, wie Sie Ihr Repository verwenden können. ## Branches und Tags Git-Repositories verwenden oft Branches, um verschiedene Versionen eines gleichen Repositories zu speichern. Tags können auch verwendet werden, um einen bestimmten Zustand Ihres Repositories zu kennzeichnen, z. B. bei der Veröffentlichung einer Version. Allgemeiner gesagt, werden Branches und Tags als [git-Referenzen](https://git-scm.com/book/en/v2/Git-Internals-Git-References) bezeichnet. 
### Branches und Tags erstellen

Sie können neue Branches und Tags mit [`create_branch`] und [`create_tag`] erstellen:

```py
>>> from huggingface_hub import create_branch, create_tag

# Erstellen Sie einen Branch auf einem Space-Repo vom `main` Branch
>>> create_branch("Matthijs/speecht5-tts-demo", repo_type="space", branch="handle-dog-speaker")

# Erstellen Sie einen Tag auf einem Dataset-Repo vom `v0.1-release` Branch
>>> create_tag("bigcode/the-stack", repo_type="dataset", revision="v0.1-release", tag="v0.1.1", tag_message="Bump release version.")
```

Sie können die Funktionen [`delete_branch`] und [`delete_tag`] auf die gleiche Weise verwenden, um einen Branch oder einen Tag zu löschen.

### Alle Branches und Tags auflisten

Sie können auch die vorhandenen git-Referenzen eines Repositories mit [`list_repo_refs`] auflisten:

```py
>>> from huggingface_hub import list_repo_refs
>>> list_repo_refs("bigcode/the-stack", repo_type="dataset")
GitRefs(
   branches=[
         GitRefInfo(name='main', ref='refs/heads/main', target_commit='18edc1591d9ce72aa82f56c4431b3c969b210ae3'),
         GitRefInfo(name='v1.1.a1', ref='refs/heads/v1.1.a1', target_commit='f9826b862d1567f3822d3d25649b0d6d22ace714')
   ],
   converts=[],
   tags=[
         GitRefInfo(name='v1.0', ref='refs/tags/v1.0', target_commit='c37a8cd1e382064d8aced5e05543c5f7753834da')
   ]
)
```

## Repository-Einstellungen ändern

Repositories verfügen über einige Einstellungen, die Sie konfigurieren können. Die meiste Zeit möchten Sie dies manuell auf der Repo-Einstellungsseite in Ihrem Browser tun. Sie müssen Schreibzugriff auf ein Repo haben, um es zu konfigurieren (es entweder besitzen oder Teil einer Organisation sein). In diesem Abschnitt werden wir die Einstellungen sehen, die Sie auch programmgesteuert mit `huggingface_hub` konfigurieren können.

Einige Einstellungen sind spezifisch für Spaces (Hardware, Umgebungsvariablen,...). Um diese zu konfigurieren, lesen Sie bitte unseren [Verwalten Ihres Spaces](../guides/manage-spaces) Leitfaden.

### Sichtbarkeit aktualisieren

Ein Repository kann öffentlich oder privat sein. Ein privates Repository ist nur für Sie oder die Mitglieder der Organisation sichtbar, in der sich das Repository befindet. Ändern Sie ein Repository wie im Folgenden gezeigt in ein privates:

```py
>>> from huggingface_hub import update_repo_settings
>>> update_repo_settings(repo_id=repo_id, private=True)
```

### Benennen Sie Ihr Repository um

Sie können Ihr Repository auf dem Hub mit [`move_repo`] umbenennen. Mit dieser Methode können Sie das Repo auch von einem Benutzer zu einer Organisation verschieben. Dabei gibt es [einige Einschränkungen](https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo), die Sie beachten sollten. Zum Beispiel können Sie Ihr Repo nicht an einen anderen Benutzer übertragen.

```py
>>> from huggingface_hub import move_repo
>>> move_repo(from_id="Wauplin/cool-model", to_id="huggingface/cool-model")
```

## Verwalten Sie eine lokale Kopie Ihres Repositories

Alle oben beschriebenen Aktionen können mit HTTP-Anfragen durchgeführt werden. In einigen Fällen möchten Sie jedoch vielleicht eine lokale Kopie Ihres Repositories haben und damit interagieren, indem Sie die Git-Befehle verwenden, die Sie kennen.

Die [`Repository`] Klasse ermöglicht es Ihnen, mit Dateien und Repositories auf dem Hub mit Funktionen zu interagieren, die Git-Befehlen ähneln. Es ist ein Wrapper über Git und Git-LFS-Methoden, um die Git-Befehle zu verwenden, die Sie bereits kennen und lieben.
Stellen Sie vor dem Start sicher, dass Sie Git-LFS installiert haben (siehe [hier](https://git-lfs.github.com/) für Installationsanweisungen).

### Verwenden eines lokalen Repositories

Instanziieren Sie ein [`Repository`] Objekt mit einem Pfad zu einem lokalen Repository:

```py
>>> from huggingface_hub import Repository
>>> repo = Repository(local_dir="<path>/<to>/<folder>")
```

### Klonen

Der `clone_from` Parameter klont ein Repository von einer Hugging Face Repository-ID in ein lokales Verzeichnis, das durch das Argument `local_dir` angegeben wird:

```py
>>> from huggingface_hub import Repository
>>> repo = Repository(local_dir="w2v2", clone_from="facebook/wav2vec2-large-960h-lv60")
```

`clone_from` kann auch ein Repository mit einer URL klonen:

```py
>>> repo = Repository(local_dir="huggingface-hub", clone_from="https://huggingface.co/facebook/wav2vec2-large-960h-lv60")
```

Sie können den `clone_from` Parameter mit [`create_repo`] kombinieren, um ein Repository zu erstellen und zu klonen:

```py
>>> repo_url = create_repo(repo_id="repo_name")
>>> repo = Repository(local_dir="repo_local_path", clone_from=repo_url)
```

Sie können auch einen Git-Benutzernamen und eine E-Mail zu einem geklonten Repository konfigurieren, indem Sie die Parameter `git_user` und `git_email` beim Klonen eines Repositories angeben. Wenn Benutzer Änderungen in diesem Repository committen, wird Git über den Autor des Commits informiert sein.

```py
>>> repo = Repository(
...   "my-dataset",
...   clone_from="<user>/<dataset_id>",
...   token=True,
...   repo_type="dataset",
...   git_user="MyName",
...   git_email="me@cool.mail"
... )
```

### Branch

Branches sind wichtig für die Zusammenarbeit und das Experimentieren, ohne Ihre aktuellen Dateien und Codes zu beeinflussen. Wechseln Sie zwischen den Branches mit [`~Repository.git_checkout`]. Wenn Sie beispielsweise von `branch1` zu `branch2` wechseln möchten:

```py
>>> from huggingface_hub import Repository
>>> repo = Repository(local_dir="huggingface-hub", clone_from="<user>/<repo_id>", revision='branch1')
>>> repo.git_checkout("branch2")
```

### Pull

Mit [`~Repository.git_pull`] können Sie einen aktuellen lokalen Branch mit Änderungen aus einem Remote-Repository aktualisieren:

```py
>>> from huggingface_hub import Repository
>>> repo.git_pull()
```

Setzen Sie `rebase=True`, wenn Sie möchten, dass Ihre lokalen Commits nach dem Aktualisieren Ihres Zweigs mit den neuen Commits aus dem Remote erfolgen:

```py
>>> repo.git_pull(rebase=True)
```
huggingface_hub-0.31.1/docs/source/de/guides/search.md000066400000000000000000000050521500667546600226110ustar00rootroot00000000000000

# Den Hub durchsuchen

In diesem Tutorial lernen Sie, wie Sie Modelle, Datensätze und Spaces auf dem Hub mit `huggingface_hub` durchsuchen können.

## Wie listet man Repositories auf?

Die `huggingface_hub`-Bibliothek enthält einen HTTP-Client [`HfApi`], um mit dem Hub zu interagieren. Unter anderem kann er Modelle, Datensätze und Spaces auflisten, die auf dem Hub gespeichert sind:

```py
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> models = api.list_models()
```

Die Ausgabe von [`list_models`] ist ein Iterator über die auf dem Hub gespeicherten Modelle.

Ähnlich können Sie [`list_datasets`] verwenden, um Datensätze aufzulisten und [`list_spaces`], um Spaces aufzulisten.

## Wie filtert man Repositories?

Das Auflisten von Repositories ist großartig, aber jetzt möchten Sie vielleicht Ihre Suche filtern. Die List-Helfer haben mehrere Attribute wie:
- `filter`
- `author`
- `search`
- ...
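Eine kurze Skizze zur Veranschaulichung der Parameter `author` und `search` (der Autor "google" und der Suchbegriff "bert" sind hier nur zur Illustration gewählt):

```py
>>> from huggingface_hub import HfApi
>>> api = HfApi()
# Modelle des Autors "google" auflisten, deren Name zum Suchbegriff "bert" passt
>>> for model in api.list_models(author="google", search="bert", limit=5):
...     print(model.id)
```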
Zwei dieser Parameter sind intuitiv (`author` und `search`), aber was ist mit diesem `filter`? `filter` nimmt als Eingabe ein [`ModelFilter`]-Objekt (oder [`DatasetFilter`]) entgegen. Sie können es instanziieren, indem Sie angeben, welche Modelle Sie filtern möchten. Hier ist ein Beispiel, um alle Modelle auf dem Hub zu erhalten, die Bildklassifizierung durchführen, auf dem Imagenet-Datensatz trainiert wurden und mit PyTorch laufen. Das kann mit einem einzigen [`ModelFilter`] erreicht werden. Attribute werden als "logisches UND" kombiniert. ```py models = hf_api.list_models( filter=ModelFilter( task="image-classification", library="pytorch", trained_dataset="imagenet" ) ) ``` Während des Filterns können Sie auch die Modelle sortieren und nur die Top-Ergebnisse abrufen. Zum Beispiel holt das folgende Beispiel die 5 am häufigsten heruntergeladenen Datensätze auf dem Hub: ```py >>> list(list_datasets(sort="downloads", direction=-1, limit=5)) [DatasetInfo( id='argilla/databricks-dolly-15k-curated-en', author='argilla', sha='4dcd1dedbe148307a833c931b21ca456a1fc4281', last_modified=datetime.datetime(2023, 10, 2, 12, 32, 53, tzinfo=datetime.timezone.utc), private=False, downloads=8889377, (...) ``` Eine andere Möglichkeit, dies zu tun, besteht darin, die [Modelle](https://huggingface.co/models) und [Datensätze](https://huggingface.co/datasets) Seiten in Ihrem Browser zu besuchen, nach einigen Parametern zu suchen und die Werte in der URL anzusehen. huggingface_hub-0.31.1/docs/source/de/guides/upload.md000066400000000000000000001011201500667546600226210ustar00rootroot00000000000000 # Dateien auf den Hub hochladen Das Teilen Ihrer Dateien und Arbeiten ist ein wichtiger Aspekt des Hubs. Das `huggingface_hub` bietet mehrere Optionen, um Ihre Dateien auf den Hub hochzuladen. Sie können diese Funktionen unabhängig verwenden oder sie in Ihre Bibliothek integrieren, um es Ihren Benutzern zu erleichtern, mit dem Hub zu interagieren. In dieser Anleitung erfahren Sie, wie Sie Dateien hochladen: - ohne Git zu verwenden. - mit [Git LFS](https://git-lfs.github.com/) wenn die Dateien sehr groß sind. - mit dem `commit`-Context-Manager. - mit der Funktion [`~Repository.push_to_hub`]. Wenn Sie Dateien auf den Hub hochladen möchten, müssen Sie sich bei Ihrem Hugging Face-Konto anmelden: - Melden Sie sich bei Ihrem Hugging Face-Konto mit dem folgenden Befehl an: ```bash huggingface-cli login # oder mit einer Umgebungsvariable huggingface-cli login --token $HUGGINGFACE_TOKEN ``` - Alternativ können Sie sich in einem Notebook oder einem Skript programmatisch mit [`login`] anmelden: ```python >>> from huggingface_hub import login >>> login() ``` Wenn es in einem Jupyter- oder Colaboratory-Notebook ausgeführt wird, startet [`login`] ein Widget, über das Sie Ihren Hugging Face-Zugriffstoken eingeben können. Andernfalls wird eine Meldung im Terminal angezeigt. Es ist auch möglich, sich programmatisch ohne das Widget anzumelden, indem Sie den Token direkt an [`login`] übergeben. Seien Sie jedoch vorsichtig, wenn Sie Ihr Notebook teilen. Es ist am Besten, den Token aus einem sicheren Passwortspeicher zu laden, anstatt ihn in Ihrem Colaboratory-Notebook zu speichern. ## Datei hochladen Sobald Sie ein Repository mit [`create_repo`] erstellt haben, können Sie mit [`upload_file`] eine Datei in Ihr Repository hochladen. Geben Sie den Pfad der hochzuladenden Datei, den Ort, an den Sie die Datei im Repository hochladen möchten, und den Namen des Repositories an, zu dem Sie die Datei hinzufügen möchten. 
Abhängig von Ihrem Repository-Typ können Sie optional den Repository-Typ als `dataset`, `model`, oder `space` festlegen. ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.upload_file( ... path_or_fileobj="/path/to/local/folder/README.md", ... path_in_repo="README.md", ... repo_id="username/test-dataset", ... repo_type="dataset", ... ) ``` ## Ordner hochladen Verwenden Sie die Funktion [`upload_folder`], um einen lokalen Ordner in ein vorhandenes Repository hochzuladen. Geben Sie den Pfad des lokalen Ordners an, den Sie hochladen möchten, an welchem Ort Sie den Ordner im Repository hochladen möchten, und den Namen des Repositories, zu dem Sie den Ordner hinzufügen möchten. Abhängig von Ihrem Repository-Typ können Sie optional den Repository-Typ als `dataset`, `model`, oder `space` festlegen. ```py >>> from huggingface_hub import HfApi >>> api = HfApi() # Den gesamten Inhalt aus dem lokalen Ordner in den entfernten Space hoch laden. # Standardmäßig werden Dateien im Hauptverzeichnis des Repos hochgeladen >>> api.upload_folder( ... folder_path="/path/to/local/space", ... repo_id="username/my-cool-space", ... repo_type="space", ... ) ``` Verwenden Sie die Argumente `allow_patterns` und `ignore_patterns`, um anzugeben, welche Dateien hochgeladen werden sollen. Diese Parameter akzeptieren entweder ein einzelnes Muster oder eine Liste von Mustern. Muster sind Standard-Wildcards (globbing patterns) wie [hier](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm) dokumentiert. Wenn sowohl `allow_patterns` als auch `ignore_patterns` angegeben werden, gelten beide Einschränkungen. Standardmäßig werden alle Dateien aus dem Ordner hochgeladen. Jeder `.git/`-Ordner in einem Unterverzeichnis wird ignoriert. Bitte beachten Sie jedoch, dass die `.gitignore`-Datei nicht berücksichtigt wird. Dies bedeutet, dass Sie `allow_patterns` und `ignore_patterns` verwenden müssen, um anzugeben, welche Dateien stattdessen hochgeladen werden sollen. ```py >>> api.upload_folder( ... folder_path="/path/to/local/folder", ... path_in_repo="my-dataset/train", # Hochladen in einen bestimmten Ordner ... repo_id="username/test-dataset", ... repo_type="dataset", ... ignore_patterns="**/logs/*.txt", # Alle Textprotokolle ignorieren ... ) ``` Sie können auch das Argument `delete_patterns` verwenden, um Dateien anzugeben, die Sie im selben Commit aus dem Repo löschen möchten. Dies kann nützlich sein, wenn Sie einen entfernten Ordner reinigen möchten, bevor Sie Dateien darin ablegen und nicht wissen, welche Dateien bereits vorhanden sind. Im folgenden Beispiel wird der lokale Ordner `./logs` in den entfernten Ordner `/experiment/logs/` hochgeladen. Es werden nur txt-Dateien hochgeladen, aber davor werden alle vorherigen Protokolle im Repo gelöscht. All dies in einem einzigen Commit. ```py >>> api.upload_folder( ... folder_path="/path/to/local/folder/logs", ... repo_id="username/trained-model", ... path_in_repo="experiment/logs/", ... allow_patterns="*.txt", # Alle lokalen Textdateien hochladen ... delete_patterns="*.txt", # Vorher alle enfernten Textdateien löschen ... ) ``` ## Erweiterte Funktionen In den meisten Fällen benötigen Sie nicht mehr als [`upload_file`] und [`upload_folder`], um Ihre Dateien auf den Hub hochzuladen. Das `huggingface_hub` bietet jedoch fortschrittlichere Funktionen, um die Dinge einfacher zu machen. Schauen wir sie uns an! ### Nicht blockierende Uploads In einigen Fällen möchten Sie Daten hochladen, ohne Ihren Hauptthread zu blockieren. 
Dies ist besonders nützlich, um Protokolle und Artefakte hochzuladen, während Sie weiter trainieren. Um dies zu tun, können Sie das Argument `run_as_future` in beiden [`upload_file`] und [`upload_folder`] verwenden. Dies gibt ein [`concurrent.futures.Future`](https://docs.python.org/3/library/concurrent.futures.html#future-objects)-Objekt zurück, mit dem Sie den Status des Uploads überprüfen können. ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> future = api.upload_folder( # Hochladen im Hintergrund (nicht blockierende Aktion) ... repo_id="username/my-model", ... folder_path="checkpoints-001", ... run_as_future=True, ... ) >>> future Future(...) >>> future.done() False >>> future.result() # Warten bis der Upload abgeschlossen ist (blockierende Aktion) ... ``` Hintergrund-Aufgaben werden in die Warteschlange gestellt, wenn `run_as_future=True` verwendet wird. Das bedeutet, dass garantiert wird, dass die Aufgaben in der richtigen Reihenfolge ausgeführt werden. Auch wenn Hintergrundaufgaben hauptsächlich dazu dienen, Daten hochzuladen/Commits zu erstellen, können Sie jede gewünschte Methode in die Warteschlange stellen, indem Sie [`run_as_future`] verwenden. Sie können es beispielsweise verwenden, um ein Repo zu erstellen und dann Daten im Hintergrund dorthin hochzuladen. Das integrierte Argument `run_as_future` in Upload-Methoden ist lediglich ein Alias dafür. ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.run_as_future(api.create_repo, "username/my-model", exists_ok=True) Future(...) >>> api.upload_file( ... repo_id="username/my-model", ... path_in_repo="file.txt", ... path_or_fileobj=b"file content", ... run_as_future=True, ... ) Future(...) ``` ### Ordner in Teilen hochladen Mit [`upload_folder`] können Sie ganz einfach einen gesamten Ordner ins Hub hochladen. Bei großen Ordnern (Tausende von Dateien oder Hunderte von GB) kann dies jedoch immer noch herausfordernd sein. Wenn Sie einen Ordner mit vielen Dateien haben, möchten Sie ihn möglicherweise in mehreren Commits hochladen. Wenn während des Uploads ein Fehler oder ein Verbindungsproblem auftritt, müssen Sie den Vorgang nicht von Anfang an wiederholen. Um einen Ordner in mehreren Commits hochzuladen, übergeben Sie einfach `multi_commits=True` als Argument. Intern wird `huggingface_hub` die hochzuladenden/zu löschenden Dateien auflisten und sie in mehrere Commits aufteilen. Die "Strategie" (d.h. wie die Commits aufgeteilt werden) basiert auf der Anzahl und Größe der hochzuladenden Dateien. Ein PR wird im Hub geöffnet, um alle Commits zu pushen. Sobald der PR bereit ist, werden die Commits zu einem einzigen Commit zusammengefasst. Wenn der Prozess unterbrochen wird, bevor er abgeschlossen ist, können Sie Ihr Skript erneut ausführen, um den Upload fortzusetzen. Der erstellte PR wird automatisch erkannt und der Upload setzt dort fort, wo er gestoppt wurde. Es wird empfohlen, `multi_commits_verbose=True` zu übergeben, um ein besseres Verständnis für den Upload und dessen Fortschritt zu erhalten. Das untenstehende Beispiel lädt den Ordner "checkpoints" in ein Dataset in mehreren Commits hoch. Ein PR wird im Hub erstellt und automatisch zusammengeführt, sobald der Upload abgeschlossen ist. Wenn Sie möchten, dass der PR offen bleibt und Sie ihn manuell überprüfen können, übergeben Sie `create_pr=True`. ```py >>> upload_folder( ... folder_path="local/checkpoints", ... repo_id="username/my-dataset", ... repo_type="dataset", ... multi_commits=True, ... multi_commits_verbose=True, ... 
) ``` Wenn Sie die Upload-Strategie besser steuern möchten (d.h. die erstellten Commits), können Sie sich die Low-Level-Methoden [`plan_multi_commits`] und [`create_commits_on_pr`] ansehen. `multi_commits` ist noch ein experimentelles Feature. Seine API und sein Verhalten können in Zukunft ohne vorherige Ankündigung geändert werden. ### Geplante Uploads Das Hugging Face Hub erleichtert das Speichern und Versionieren von Daten. Es gibt jedoch einige Einschränkungen, wenn Sie dieselbe Datei Tausende von Malen aktualisieren möchten. Sie möchten beispielsweise Protokolle eines Trainingsprozesses oder Benutzerfeedback in einem bereitgestellten Space speichern. In diesen Fällen macht es Sinn, die Daten als Dataset im Hub hochzuladen, aber es kann schwierig sein, dies richtig zu machen. Der Hauptgrund ist, dass Sie nicht jede Aktualisierung Ihrer Daten versionieren möchten, da dies das git-Repository unbrauchbar machen würde. Die Klasse [`CommitScheduler`] bietet eine Lösung für dieses Problem. Die Idee besteht darin, einen Hintergrundjob auszuführen, der regelmäßig einen lokalen Ordner ins Hub schiebt. Nehmen Sie an, Sie haben einen Gradio Space, der als Eingabe einen Text nimmt und zwei Übersetzungen davon generiert. Der Benutzer kann dann seine bevorzugte Übersetzung auswählen. Für jeden Durchlauf möchten Sie die Eingabe, Ausgabe und Benutzerpräferenz speichern, um die Ergebnisse zu analysieren. Dies ist ein perfekter Anwendungsfall für [`CommitScheduler`]; Sie möchten Daten ins Hub speichern (potenziell Millionen von Benutzerfeedbacks), aber Sie müssen nicht in Echtzeit jede Benutzereingabe speichern. Stattdessen können Sie die Daten lokal in einer JSON-Datei speichern und sie alle 10 Minuten hochladen. Zum Beispiel: ```py >>> import json >>> import uuid >>> from pathlib import Path >>> import gradio as gr >>> from huggingface_hub import CommitScheduler # Definieren Sie die Datei, in der die Daten gespeichert werden sollen. Verwenden Sie UUID, um sicherzustellen, dass vorhandene Daten aus einem früheren Lauf nicht überschrieben werden. >>> feedback_file = Path("user_feedback/") / f"data_{uuid.uuid4()}.json" >>> feedback_folder = feedback_file.parent # Geplante regelmäßige Uploads. Das Remote-Repo und der lokale Ordner werden erstellt, wenn sie noch nicht existieren. >>> scheduler = CommitScheduler( ... repo_id="report-translation-feedback", ... repo_type="dataset", ... folder_path=feedback_folder, ... path_in_repo="data", ... every=10, ... ) # Eine einfache Gradio-Anwendung, die einen Text als Eingabe nimmt und zwei Übersetzungen generiert. Der Benutzer wählt seine bevorzugte Übersetzung aus. >>> def save_feedback(input_text:str, output_1: str, output_2:str, user_choice: int) -> None: ... """ ... Füge Eingabe/Ausgabe und Benutzerfeedback zu einer JSON-Lines-Datei hinzu und verwende ein Thread-Lock, um gleichzeitiges Schreiben von verschiedenen Benutzern zu vermeiden. ... """ ... with scheduler.lock: ... with feedback_file.open("a") as f: ... f.write(json.dumps({"input": input_text, "output_1": output_1, "output_2": output_2, "user_choice": user_choice})) ... f.write("\n") # Starte Gradio >>> with gr.Blocks() as demo: >>> ... # Definiere Gradio Demo + verwende `save_feedback` >>> demo.launch() ``` Und das war's! Benutzereingabe/-ausgaben und Feedback sind als Dataset auf dem Hub verfügbar. 
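Zur Veranschaulichung, wie sich die gesammelten Daten später wiederverwenden lassen, hier eine kleine Skizze mit der 🤗 Datasets-Bibliothek. Der Namespace `user` ist ein Platzhalter; angenommen wird das Layout aus dem obigen Beispiel, also JSON-Lines-Dateien im Ordner `data/` des Dataset-Repos.

```py
>>> from datasets import load_dataset

# JSON-Lines-Dateien aus dem Dataset-Repo laden
# (Annahme: Repo-Name und Layout wie im Beispiel oben)
>>> ds = load_dataset("user/report-translation-feedback", data_files="data/*.json", split="train")
>>> ds.column_names  # z.B. ['input', 'output_1', 'output_2', 'user_choice']
```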
Durch die Verwendung eines eindeutigen JSON-Dateinamens können Sie sicher sein, dass Sie keine Daten von einem vorherigen Lauf oder Daten von anderen Spaces/Replikas überschreiben, die gleichzeitig in dasselbe Repository pushen. Für weitere Details über den [`CommitScheduler`], hier das Wichtigste: - **append-only / Nur hinzufügen:** Es wird davon ausgegangen, dass Sie nur Inhalte zum Ordner hinzufügen. Sie dürfen nur Daten zu bestehenden Dateien hinzufügen oder neue Dateien erstellen. Das Löschen oder Überschreiben einer Datei könnte Ihr Repository beschädigen. - **git history / git Historie**: Der Scheduler wird den Ordner alle every Minuten committen. Um das Git-Repository nicht zu überladen, wird empfohlen, einen minimalen Wert von 5 Minuten festzulegen. Außerdem ist der Scheduler darauf ausgelegt, leere Commits zu vermeiden. Wenn im Ordner kein neuer Inhalt erkannt wird, wird der geplante Commit verworfen. - **errors / Fehler**: Der Scheduler läuft als Hintergrund-Thread. Er wird gestartet, wenn Sie die Klasse instanziieren, und stoppt nie. Insbesondere, wenn während des Uploads ein Fehler auftritt (z. B. Verbindungsproblem), wird der Scheduler ihn stillschweigend ignorieren und beim nächsten geplanten Commit erneut versuchen. - **thread-safety / Thread-Sicherheit**: In den meisten Fällen können Sie davon ausgehen, dass Sie eine Datei schreiben können, ohne sich um eine Lock-Datei kümmern zu müssen. Der Scheduler wird nicht abstürzen oder beschädigt werden, wenn Sie Inhalte in den Ordner schreiben, während er hochlädt. In der Praxis ist es möglich, dass bei stark ausgelasteten Apps Probleme mit der Parallelität auftreten. In diesem Fall empfehlen wir, das `scheduler.lock` Lock zu verwenden, um die Thread-Sicherheit zu gewährleisten. Das Lock wird nur gesperrt, wenn der Scheduler den Ordner auf Änderungen überprüft, nicht beim Hochladen von Daten. Sie können sicher davon ausgehen, dass dies das Benutzererlebnis in Ihrem Space nicht beeinflusst. #### Space Persistenz-Demo Das Speichern von Daten aus einem Space in einem Dataset auf dem Hub ist der Hauptanwendungsfall für den [`CommitScheduler`]. Je nach Anwendungsfall möchten Sie Ihre Daten möglicherweise anders strukturieren. Die Struktur muss robust gegenüber gleichzeitigen Benutzern und Neustarts sein, was oft das Generieren von UUIDs impliziert. Neben der Robustheit sollten Sie Daten in einem Format hochladen, das von der 🤗 Datasets-Bibliothek für die spätere Wiederverwendung gelesen werden kann. Wir haben einen [Space](https://huggingface.co/spaces/Wauplin/space_to_dataset_saver) erstellt, der zeigt, wie man verschiedene Datenformate speichert (dies muss möglicherweise für Ihre speziellen Bedürfnisse angepasst werden). #### Benutzerdefinierte Uploads [`CommitScheduler`] geht davon aus, dass Ihre Daten nur hinzugefügt werden und "wie sie sind" hochgeladen werden sollten. Sie möchten jedoch möglicherweise anpassen, wie Daten hochgeladen werden. Dies können Sie tun, indem Sie eine Klasse erstellen, die vom [`CommitScheduler`] erbt und die Methode `push_to_hub` überschreibt (fühlen Sie sich frei, sie nach Belieben zu überschreiben). Es ist garantiert, dass sie alle `every` Minuten in einem Hintergrund-Thread aufgerufen wird. Sie müssen sich keine Gedanken über Parallelität und Fehler machen, aber Sie müssen vorsichtig sein bei anderen Aspekten, wie z. B. dem Pushen von leeren Commits oder doppelten Daten. 
Im folgenden (vereinfachten) Beispiel überschreiben wir `push_to_hub`, um alle PNG-Dateien in einem einzigen Archiv zu zippen, um das Repo auf dem Hub nicht zu überladen:

```py
import tempfile
import zipfile
from pathlib import Path

from huggingface_hub import CommitScheduler


class ZipScheduler(CommitScheduler):
    def push_to_hub(self):
        # 1. Liste PNG-Dateien auf
        png_files = list(self.folder_path.glob("*.png"))
        if len(png_files) == 0:
            return None  # kehre früh zurück, wenn nichts zu committen ist

        # 2. Zippe PNG-Dateien in ein einzelnes Archiv
        with tempfile.TemporaryDirectory() as tmpdir:
            archive_path = Path(tmpdir) / "train.zip"
            with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zip:
                for png_file in png_files:
                    zip.write(filename=png_file, arcname=png_file.name)

            # 3. Lade das Archiv hoch
            self.api.upload_file(..., path_or_fileobj=archive_path)

        # 4. Lösche lokale PNG-Dateien, um späteres erneutes Hochladen zu vermeiden
        for png_file in png_files:
            png_file.unlink()
```

Wenn Sie `push_to_hub` überschreiben, haben Sie Zugriff auf die Attribute vom [`CommitScheduler`] und insbesondere:
- [`HfApi`] Client: `api`
- Ordnerparameter: `folder_path` und `path_in_repo`
- Repo-Parameter: `repo_id`, `repo_type`, `revision`
- Das Thread-Lock: `lock`

Für weitere Beispiele von benutzerdefinierten Schedulern schauen Sie sich unseren [Demo Space](https://huggingface.co/spaces/Wauplin/space_to_dataset_saver) an, der verschiedene Implementierungen je nach Ihren Anforderungen enthält.

### create_commit

Die Funktionen [`upload_file`] und [`upload_folder`] sind High-Level-APIs, die im Allgemeinen bequem zu verwenden sind. Wir empfehlen, diese Funktionen zuerst auszuprobieren, wenn Sie nicht auf einer niedrigeren Ebene arbeiten müssen. Wenn Sie jedoch auf Commit-Ebene arbeiten möchten, können Sie die Funktion [`create_commit`] direkt verwenden.

Es gibt drei von [`create_commit`] unterstützte Operationstypen:

- [`CommitOperationAdd`] lädt eine Datei in den Hub hoch. Wenn die Datei bereits existiert, werden die Dateiinhalte überschrieben. Diese Operation akzeptiert zwei Argumente:
  - `path_in_repo`: der Repository-Pfad, um eine Datei hochzuladen.
  - `path_or_fileobj`: entweder ein Pfad zu einer Datei auf Ihrem Dateisystem oder ein Datei-ähnliches Objekt. Dies ist der Inhalt der Datei, die auf den Hub hochgeladen werden soll.
- [`CommitOperationDelete`] entfernt eine Datei oder einen Ordner aus einem Repository. Diese Operation akzeptiert `path_in_repo` als Argument.
- [`CommitOperationCopy`] kopiert eine Datei innerhalb eines Repositorys. Diese Operation akzeptiert drei Argumente:
  - `src_path_in_repo`: der Repository-Pfad der zu kopierenden Datei.
  - `path_in_repo`: der Repository-Pfad, wohin die Datei kopiert werden soll.
  - `src_revision`: optional - die Revision der zu kopierenden Datei, wenn Sie eine Datei von einem anderen Branch/einer anderen Revision kopieren möchten.

Zum Beispiel, wenn Sie zwei Dateien hochladen und eine Datei in einem Hub-Repository löschen möchten:

1. Verwenden Sie die entsprechende `CommitOperation`, um eine Datei hinzuzufügen oder zu löschen und einen Ordner zu löschen:

```py
>>> from huggingface_hub import HfApi, CommitOperationAdd, CommitOperationCopy, CommitOperationDelete
>>> api = HfApi()
>>> operations = [
...     CommitOperationAdd(path_in_repo="LICENSE.md", path_or_fileobj="~/repo/LICENSE.md"),
...     CommitOperationAdd(path_in_repo="weights.h5", path_or_fileobj="~/repo/weights-final.h5"),
...     CommitOperationDelete(path_in_repo="old-weights.h5"),
...     CommitOperationDelete(path_in_repo="logs/"),
...     CommitOperationCopy(src_path_in_repo="image.png", path_in_repo="duplicate_image.png"),
... ]
```

2.
Übergeben Sie Ihre Operationen an [`create_commit`]: ```py >>> api.create_commit( ... repo_id="lysandre/test-model", ... operations=operations, ... commit_message="Hochladen meiner Modell-Gewichte und -Lizenz", ... ) ``` Zusätzlich zu [`upload_file`] und [`upload_folder`] verwenden auch die folgenden Funktionen [`create_commit`] im Hintergrund: - [`delete_file`] löscht eine einzelne Datei aus einem Repository auf dem Hub. - [`delete_folder`] löscht einen gesamten Ordner aus einem Repository auf dem Hub. - [`metadata_update`] aktualisiert die Metadaten eines Repositorys. Für detailliertere Informationen werfen Sie einen Blick auf die [`HfApi`] Referenz. ## Tipps und Tricks für große Uploads Bei der Verwaltung einer großen Datenmenge in Ihrem Repo gibt es einige Einschränkungen zu beachten. Angesichts der Zeit, die es dauert, die Daten zu streamen, kann es sehr ärgerlich sein, am Ende des Prozesses einen Upload/Push zu verlieren oder eine degradierte Erfahrung zu machen, sei es auf hf.co oder bei lokalem Arbeiten. Wir haben eine Liste von Tipps und Empfehlungen zusammengestellt, um Ihr Repo zu strukturieren. | Eigenschaft | Empfohlen | Tipps | | ---------------- | ------------------ | ------------------------------------------------------ | | Repo-Größe | - | Kontaktieren Sie uns für große Repos (TBs Daten) | | Dateien pro Repo | <100k | Daten in weniger Dateien zusammenführen | | Einträge pro Ordner | <10k | Unterverzeichnisse im Repo verwenden | | Dateigröße | <5GB | Daten in geteilte Dateien aufteilen | | Commit-Größe | <100 files* | Dateien in mehreren Commits hochladen | | Commits pro Repo | - | Mehrere Dateien pro Commit hochladen und/oder Historie zusammenführen | _* Nicht relevant bei direkter Verwendung des `git` CLI_ Bitte lesen Sie den nächsten Abschnitt, um diese Beschränkungen besser zu verstehen und zu erfahren, wie Sie damit umgehen können. ### Hub-Repository Größenbeschränkungen Was meinen wir, wenn wir von "großen Uploads" sprechen, und welche Einschränkungen sind damit verbunden? Große Uploads können sehr unterschiedlich sein, von Repositories mit einigen riesigen Dateien (z. B. Modellgewichten) bis hin zu Repositories mit Tausenden von kleinen Dateien (z. B. einem Bilddatensatz). Hinter den Kulissen verwendet der Hub Git zur Versionierung der Daten, was strukturelle Auswirkungen darauf hat, was Sie in Ihrem Repo tun können. Wenn Ihr Repo einige der im vorherigen Abschnitt erwähnten Zahlen überschreitet, **empfehlen wir Ihnen dringend, [`git-sizer`](https://github.com/github/git-sizer) zu verwenden**, das eine sehr detaillierte Dokumentation über die verschiedenen Faktoren bietet, die Ihr Erlebnis beeinflussen werden. Hier ist ein TL;DR der zu berücksichtigenden Faktoren: - **Repository-Größe**: Die Gesamtgröße der Daten, die Sie hochladen möchten. Es gibt keine feste Obergrenze für die Größe eines Hub-Repositories. Wenn Sie jedoch vorhaben, Hunderte von GBs oder sogar TBs an Daten hochzuladen, würden wir es begrüßen, wenn Sie uns dies im Voraus mitteilen könnten, damit wir Ihnen besser helfen können, falls Sie während des Prozesses Fragen haben. Sie können uns unter datasets@huggingface.co oder auf [unserem Discord](http://hf.co/join/discord) kontaktieren. - **Anzahl der Dateien**: - Für ein optimales Erlebnis empfehlen wir, die Gesamtzahl der Dateien unter 100k zu halten. Versuchen Sie, die Daten zu weniger Dateien zusammenzuführen, wenn Sie mehr haben. 
Zum Beispiel können json-Dateien zu einer einzigen jsonl-Datei zusammengeführt oder große Datensätze als Parquet-Dateien exportiert werden. - Die maximale Anzahl von Dateien pro Ordner darf 10k Dateien pro Ordner nicht überschreiten. Eine einfache Lösung besteht darin, eine Repository-Struktur zu erstellen, die Unterverzeichnisse verwendet. Ein Repo mit 1k Ordnern von `000/` bis `999/`, in dem jeweils maximal 1000 Dateien enthalten sind, reicht bereits aus. - **Dateigröße**: Bei hochzuladenden großen Dateien (z. B. Modellgewichte) empfehlen wir dringend, sie **in Blöcke von etwa 5GB aufzuteilen**. Es gibt mehrere Gründe dafür: - Das Hoch- und Herunterladen kleinerer Dateien ist sowohl für Sie als auch für andere Benutzer viel einfacher. Bei der Datenübertragung können immer Verbindungsprobleme auftreten, und kleinere Dateien vermeiden das erneute Starten von Anfang an im Falle von Fehlern. - Dateien werden den Benutzern über CloudFront bereitgestellt. Aus unserer Erfahrung werden riesige Dateien von diesem Dienst nicht zwischengespeichert, was zu einer langsameren Downloadgeschwindigkeit führt. In jedem Fall wird keine einzelne LFS-Datei >50GB sein können. D. h. 50GB ist das absolute Limit für die Einzeldateigröße. - **Anzahl der Commits**: Es gibt kein festes Limit für die Gesamtzahl der Commits in Ihrer Repo-Historie. Aus unserer Erfahrung heraus beginnt das Benutzererlebnis im Hub jedoch nach einigen Tausend Commits abzunehmen. Wir arbeiten ständig daran, den Service zu verbessern, aber man sollte immer daran denken, dass ein Git-Repository nicht als Datenbank mit vielen Schreibzugriffen gedacht ist. Wenn die Historie Ihres Repos sehr groß wird, können Sie immer alle Commits mit [`super_squash_history`] zusammenfassen, um einen Neuanfang zu erhalten. Dies ist eine nicht rückgängig zu machende Operation. - **Anzahl der Operationen pro Commit**: Auch hier gibt es keine feste Obergrenze. Wenn ein Commit im Hub hochgeladen wird, wird jede Git-Operation (Hinzufügen oder Löschen) vom Server überprüft. Wenn hundert LFS-Dateien auf einmal committed werden, wird jede Datei einzeln überprüft, um sicherzustellen, dass sie korrekt hochgeladen wurde. Beim Pushen von Daten über HTTP mit `huggingface_hub` wird ein Timeout von 60s für die Anforderung festgelegt, was bedeutet, dass, wenn der Prozess mehr Zeit in Anspruch nimmt, clientseitig ein Fehler ausgelöst wird. Es kann jedoch (in seltenen Fällen) vorkommen, dass selbst wenn das Timeout clientseitig ausgelöst wird, der Prozess serverseitig dennoch abgeschlossen wird. Dies kann manuell überprüft werden, indem man das Repo im Hub durchsucht. Um dieses Timeout zu vermeiden, empfehlen wir, pro Commit etwa 50-100 Dateien hinzuzufügen. ### Praktische Tipps Nachdem wir die technischen Aspekte gesehen haben, die Sie bei der Strukturierung Ihres Repositories berücksichtigen müssen, schauen wir uns einige praktische Tipps an, um Ihren Upload-Prozess so reibungslos wie möglich zu gestalten. - **Fangen Sie klein an**: Wir empfehlen, mit einer kleinen Datenmenge zu beginnen, um Ihr Upload-Skript zu testen. Es ist einfacher, an einem Skript zu arbeiten, wenn ein Fehler nur wenig Zeit kostet. - **Rechnen Sie mit Ausfällen**: Das Streamen großer Datenmengen ist eine Herausforderung. Sie wissen nicht, was passieren kann, aber es ist immer am besten anzunehmen, dass etwas mindestens einmal schiefgehen wird - unabhängig davon, ob es an Ihrem Gerät, Ihrer Verbindung oder unseren Servern liegt. 
Wenn Sie zum Beispiel vorhaben, eine große Anzahl von Dateien hochzuladen, ist es am besten, lokal zu verfolgen, welche Dateien Sie bereits hochgeladen haben, bevor Sie den nächsten Batch hochladen. Sie können sicher sein, dass eine LFS-Datei, die bereits committed wurde, niemals zweimal hochgeladen wird, aber es kann clientseitig trotzdem Zeit sparen, dies zu überprüfen.
- **Verwenden Sie `hf_transfer`**: Dabei handelt es sich um eine auf Rust basierende [Bibliothek](https://github.com/huggingface/hf_transfer), die dazu dient, Uploads auf Maschinen mit sehr hoher Bandbreite zu beschleunigen. Um sie zu verwenden, müssen Sie sie installieren (`pip install hf_transfer`) und sie durch Setzen der Umgebungsvariable `HF_HUB_ENABLE_HF_TRANSFER=1` aktivieren. Anschließend können Sie `huggingface_hub` wie gewohnt verwenden. Hinweis: Dies ist ein Tool für Power-User. Es ist getestet und einsatzbereit, verfügt jedoch nicht über benutzerfreundliche Funktionen wie Fortschrittsanzeigen oder erweiterte Fehlerbehandlung.

## (veraltet) Dateien mit Git LFS hochladen

Alle oben beschriebenen Methoden verwenden die Hub-API, um Dateien hochzuladen. Dies ist der empfohlene Weg, Dateien in den Hub hochzuladen. Wir bieten jedoch auch [`Repository`] an, einen Wrapper um das git-Tool, um ein lokales Repository zu verwalten.

Obwohl [`Repository`] formell nicht als veraltet gekennzeichnet ist, empfehlen wir stattdessen die Nutzung der HTTP-basierten Methoden, die oben beschrieben sind. Für weitere Details zu dieser Empfehlung werfen Sie bitte einen Blick auf diesen [Leitfaden](../concepts/git_vs_http), der die Kernunterschiede zwischen HTTP- und Git-basierten Ansätzen erklärt.

Git LFS verarbeitet automatisch Dateien, die größer als 10MB sind. Für sehr große Dateien (>5GB) müssen Sie jedoch einen benutzerdefinierten Transferagenten für Git LFS installieren:

```bash
huggingface-cli lfs-enable-largefiles
```

Sie sollten dies für jedes Repository installieren, das eine sehr große Datei enthält. Einmal installiert, können Sie Dateien hochladen, die größer als 5GB sind.

### commit Kontextmanager

Der `commit` Kontextmanager handhabt vier der gängigsten Git-Befehle: pull, add, commit und push. `git-lfs` beobachtet automatisch jede Datei, die größer als 10MB ist. Im folgenden Beispiel handhabt der `commit` Kontextmanager die folgenden Aufgaben:

1. Holt Daten aus dem `text-files` Repository.
2. Fügt eine Änderung an `file.txt` hinzu.
3. Committet die Änderung.
4. Pusht die Änderung in das `text-files` Repository.

```python
>>> import json
>>> from huggingface_hub import Repository

>>> with Repository(local_dir="text-files", clone_from="<user>/text-files").commit(commit_message="Meine erste Datei :)"):
...     with open("file.txt", "w+") as f:
...         f.write(json.dumps({"hey": 8}))
```

Hier ist ein weiteres Beispiel, wie man den `commit` Kontextmanager verwendet, um eine Datei in einem Repository zu speichern und hochzuladen:

```python
>>> import torch

>>> model = torch.nn.Transformer()

>>> with Repository("torch-model", clone_from="<user>/torch-model", token=True).commit(commit_message="Mein cooles Modell :)"):
...     torch.save(model.state_dict(), "model.pt")
```

Setzen Sie `blocking=False`, wenn Sie Ihre Commits asynchron pushen möchten. Das nicht-blockierende Verhalten ist nützlich, wenn Sie Ihr Skript weiterhin ausführen möchten, während Ihre Commits gesendet werden.
```python >>> with repo.commit(commit_message="Mein cooles Model :)", blocking=False) ``` Sie können den Status Ihres Pushs mit der Methode `command_queue` überprüfen: ```python >>> last_command = repo.command_queue[-1] >>> last_command.status ``` Beachten Sie die Tabelle mit möglichen Statuscodes: | Status | Beschreibung | | -------- | ------------------------------------ | | -1 | Der Push wird ausgeführt. | | 0 | Der Push wurde erfolgreich beendet. | | Non-zero | Ein Fehler ist aufgetreten. | Wenn `blocking=False` gesetzt ist, werden Befehle beobachtet und Ihr Skript wird erst beendet, wenn alle Pushs abgeschlossen sind, auch wenn andere Fehler in Ihrem Skript auftreten. Einige zusätzliche nützliche Befehle, um den Status eines Pushs zu überprüfen, sind: ```python # Einen Fehler inspizieren. >>> last_command.stderr # Überprüfen, ob ein Push abgeschlossen ist oder noch läuft. >>> last_command.is_done # Überprüfen, ob bei einem Push-Befehl ein Fehler aufgetreten ist. >>> last_command.failed ``` ### push_to_hub Die Klasse [`Repository`] hat eine Funktion [`~Repository.push_to_hub`], um Dateien hinzuzufügen, einen Commit zu machen und diese zu einem Repository zu pushen. Im Gegensatz zum `commit` Kontextmanager müssen Sie zuerst von einem Repository pullen, bevor Sie [`~Repository.push_to_hub`] aufrufen. Zum Beispiel, wenn Sie bereits ein Repository vom Hub geklont haben, können Sie das `repo` vom lokalen Verzeichnis initialisieren: ```python >>> from huggingface_hub import Repository >>> repo = Repository(local_dir="pfad/zur/lokalen/repo") ``` Aktualisieren Sie Ihren lokalen Klon mit [`~Repository.git_pull`] und dann pushen Sie Ihre Datei zum Hub: ```py >>> repo.git_pull() >>> repo.push_to_hub(commit_message="Committe meine geniale Datei zum Hub") ``` Wenn Sie jedoch noch nicht bereit sind, eine Datei zu pushen, können Sie [`~Repository.git_add`] und [`~Repository.git_commit`] verwenden, um nur Ihre Datei hinzuzufügen und zu committen: ```py >>> repo.git_add("path/to/file") >>> repo.git_commit(commit_message="füge meine erste Modell-Konfigurationsdatei hinzu :)") ``` Wenn Sie bereit sind, pushen Sie die Datei zu Ihrem Repository mit [`~Repository.git_push`]: ```py >>> repo.git_push() ``` huggingface_hub-0.31.1/docs/source/de/guides/webhooks_server.md000066400000000000000000000260001500667546600245470ustar00rootroot00000000000000 # Webhooks Server Webhooks sind ein Grundpfeiler für MLOps-bezogene Funktionen. Sie ermöglichen es Ihnen, auf neue Änderungen in bestimmten Repos oder auf alle Repos, die bestimmten Benutzern/Organisationen gehören, die Sie interessieren, zu hören. Dieser Leitfaden erklärt, wie Sie den `huggingface_hub` nutzen können, um einen Server zu erstellen, der auf Webhooks hört und ihn in einen Space zu implementieren. Es wird davon ausgegangen, dass Sie mit dem Konzept der Webhooks auf dem Huggingface Hub vertraut sind. Um mehr über Webhooks selbst zu erfahren, können Sie zuerst diesen [Leitfaden](https://huggingface.co/docs/hub/webhooks) lesen. Die Basis-Klasse, die wir in diesem Leitfaden verwenden werden, ist der [`WebhooksServer`]. Es handelt sich um eine Klasse, mit der sich ein Server leicht konfigurieren lässt, der Webhooks vom Huggingface Hub empfangen kann. Der Server basiert auf einer Gradio-App. Er verfügt über eine Benutzeroberfläche zur Anzeige von Anweisungen für Sie oder Ihre Benutzer und eine API zum Hören auf Webhooks. 
Um ein Beispiel eines laufenden Webhook-Servers zu sehen, werfen Sie einen Blick auf den [Spaces CI Bot](https://huggingface.co/spaces/spaces-ci-bot/webhook). Es handelt sich um einen Space, der kurzlebige Umgebungen startet, wenn ein PR in einem Space geöffnet wird. Dies ist ein [experimentelles Feature](../package_reference/environment_variables#hfhubdisableexperimentalwarning). Das bedeutet, dass wir noch daran arbeiten, die API zu verbessern. Es könnten in der Zukunft ohne vorherige Ankündigung Änderungen vorgenommen werden. Stellen Sie sicher, dass Sie die Version des `huggingface_hub` in Ihren Anforderungen festlegen. ## Einen Endpunkt erstellen Das Implementieren eines Webhook-Endpunkts ist so einfach wie das Dekorieren einer Funktion. Lassen Sie uns ein erstes Beispiel betrachten, um die Hauptkonzepte zu erklären: ```python # app.py from huggingface_hub import webhook_endpoint, WebhookPayload @webhook_endpoint async def trigger_training(payload: WebhookPayload) -> None: if payload.repo.type == "dataset" and payload.event.action == "update": # Einen Trainingsjob auslösen, wenn ein Datensatz aktualisiert wird ... ``` Speichern Sie diesen Ausschnitt in einer Datei namens `'app.py'` und führen Sie ihn mit `'python app.py'` aus. Sie sollten eine Nachricht wie diese sehen: ```text Webhook secret is not defined. This means your webhook endpoints will be open to everyone. To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: `app = WebhooksServer(webhook_secret='my_secret', ...)` For more details about webhook secrets, please refer to https://huggingface.co/docs/hub/webhooks#webhook-secret. Running on local URL: http://127.0.0.1:7860 Running on public URL: https://1fadb0f52d8bf825fc.gradio.live This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces Webhooks are correctly setup and ready to use: - POST https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training Go to https://huggingface.co/settings/webhooks to setup your webhooks. ``` Gute Arbeit! Sie haben gerade einen Webhook-Server gestartet! Lassen Sie uns genau aufschlüsseln, was passiert ist: 1. Durch das Dekorieren einer Funktion mit [`webhook_endpoint`] wurde im Hintergrund ein [`WebhooksServer`]-Objekt erstellt. Wie Sie sehen können, handelt es sich bei diesem Server um eine Gradio-App, die unter http://127.0.0.1:7860 läuft. Wenn Sie diese URL in Ihrem Browser öffnen, sehen Sie eine Landing Page mit Anweisungen zu den registrierten Webhooks. 2. Eine Gradio-App ist im Kern ein FastAPI-Server. Eine neue POST-Route `/webhooks/trigger_training` wurde hinzugefügt. Dies ist die Route, die auf Webhooks hört und die Funktion `trigger_training` ausführt, wenn sie ausgelöst wird. FastAPI wird das Payload automatisch parsen und es der Funktion als [`WebhookPayload`]-Objekt übergeben. Dies ist ein `pydantisches` Objekt, das alle Informationen über das Ereignis enthält, das den Webhook ausgelöst hat. 3. Die Gradio-App hat auch einen Tunnel geöffnet, um Anfragen aus dem Internet zu empfangen. Das Interessante daran ist: Sie können einen Webhook auf https://huggingface.co/settings/webhooks konfigurieren, der auf Ihren lokalen Rechner zeigt. Dies ist nützlich zum Debuggen Ihres Webhook-Servers und zum schnellen Iterieren, bevor Sie ihn in einem Space bereitstellen. 4. Schließlich teilen Ihnen die Logs auch mit, dass Ihr Server derzeit nicht durch ein Geheimnis gesichert ist. 
Dies ist für das lokale Debuggen nicht problematisch, sollte aber für später berücksichtigt werden. Standardmäßig wird der Server am Ende Ihres Skripts gestartet. Wenn Sie es in einem Notizbuch ausführen, können Sie den Server manuell starten, indem Sie `decorated_function.run()` aufrufen. Da ein einzigartiger Server verwendet wird, müssen Sie den Server nur einmal starten, auch wenn Sie mehrere Endpunkte haben. ## Konfigurieren eines Webhook Jetzt, da Sie einen Webhook-Server am Laufen haben, möchten Sie einen Webhook konfigurieren, um Nachrichten zu empfangen. Gehen Sie zu https://huggingface.co/settings/webhooks, klicken Sie auf "Add a new webhook" und konfigurieren Sie Ihren Webhook. Legen Sie die Ziel-Repositories fest, die Sie beobachten möchten, und die Webhook-URL, hier `https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training`.
Und das war's! Sie können den Webhook jetzt auslösen, indem Sie das Ziel-Repository aktualisieren (z.B. einen Commit pushen). Überprüfen Sie den Aktivitäts-Tab Ihres Webhooks, um die ausgelösten Ereignisse zu sehen. Jetzt, wo Sie eine funktionierende Einrichtung haben, können Sie sie testen und schnell iterieren. Wenn Sie Ihren Code ändern und den Server neu starten, könnte sich Ihre öffentliche URL ändern. Stellen Sie sicher, dass Sie die Webhook-Konfiguration im Hub bei Bedarf aktualisieren. ## Bereitstellung in einem Space Jetzt, da Sie einen funktionierenden Webhook-Server haben, ist das Ziel, ihn in einem Space bereitzustellen. Gehen Sie zu https://huggingface.co/new-space, um einen Space zu erstellen. Geben Sie ihm einen Namen, wählen Sie das Gradio SDK und klicken Sie auf "Create Space". Laden Sie Ihren Code in den Space in einer Datei namens `app.py` hoch. Ihr Space wird automatisch gestartet! Für weitere Informationen zu Spaces lesen Sie bitte diesen [Leitfaden](https://huggingface.co/docs/hub/spaces-overview). Ihr Webhook-Server läuft nun auf einem öffentlichen Space. In den meisten Fällen möchten Sie ihn mit einem Geheimnis absichern. Gehen Sie zu Ihren Space-Einstellungen > Abschnitt "Repository secrets" > "Add a secret". Setzen Sie die Umgebungsvariable `WEBHOOK_SECRET` auf den von Ihnen gewählten Wert. Gehen Sie zurück zu den [Webhook-Einstellungen](https://huggingface.co/settings/webhooks) und setzen Sie das Geheimnis in der Webhook-Konfiguration. Jetzt werden von Ihrem Server nur Anfragen mit dem korrekten Geheimnis akzeptiert. Und das war's! Ihr Space ist nun bereit, Webhooks vom Hub zu empfangen. Bitte beachten Sie, dass wenn Sie den Space auf einer kostenlosen 'cpu-basic' Hardware ausführen, er nach 48 Stunden Inaktivität heruntergefahren wird. Wenn Sie einen permanenten Space benötigen, sollten Sie in Erwägung ziehen, auf eine [upgraded hardware](https://huggingface.co/docs/hub/spaces-gpus#hardware-specs) umzustellen. ## Erweiterte Nutzung Der obenstehende Leitfaden erklärte den schnellsten Weg, einen [`WebhooksServer`] einzurichten. In diesem Abschnitt werden wir sehen, wie man ihn weiter anpassen kann. ### Mehrere Endpunkte Sie können mehrere Endpunkte auf demselben Server registrieren. Beispielsweise möchten Sie vielleicht einen Endpunkt haben, um einen Trainingsjob auszulösen und einen anderen, um eine Modellevaluierung auszulösen. Dies können Sie tun, indem Sie mehrere `@webhook_endpoint`-Dekorateure hinzufügen: ```python # app.py from huggingface_hub import webhook_endpoint, WebhookPayload @webhook_endpoint async def trigger_training(payload: WebhookPayload) -> None: if payload.repo.type == "dataset" and payload.event.action == "update": # Einen Trainingsjob auslösen, wenn ein Datensatz aktualisiert wird ... @webhook_endpoint async def trigger_evaluation(payload: WebhookPayload) -> None: if payload.repo.type == "model" and payload.event.action == "update": # Einen Evaluierungsauftrag auslösen, wenn ein Modell aktualisiert wird ... ``` Dies wird zwei Endpunkte erstellen: ```text (...) Webhooks are correctly setup and ready to use: - POST https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training - POST https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_evaluation ``` ### Benutzerdefinierter Server Um mehr Flexibilität zu erhalten, können Sie auch direkt ein [`WebhooksServer`] Objekt erstellen. Dies ist nützlich, wenn Sie die Startseite Ihres Servers anpassen möchten. 
Sie können dies tun, indem Sie eine [Gradio UI](https://gradio.app/docs/#blocks) übergeben, die die Standard-UI überschreibt. Zum Beispiel können Sie Anweisungen für Ihre Benutzer hinzufügen oder ein Formular zur manuellen Auslösung der Webhooks hinzufügen. Bei der Erstellung eines [`WebhooksServer`] können Sie mit dem Dekorateur [`~WebhooksServer.add_webhook`] neue Webhooks registrieren. Hier ist ein vollständiges Beispiel: ```python import gradio as gr from fastapi import Request from huggingface_hub import WebhooksServer, WebhookPayload # 1. Benutzerdefinierte UI definieren with gr.Blocks() as ui: ... # 2. Erstellen eines WebhooksServer mit benutzerdefinierter UI und Geheimnis app = WebhooksServer(ui=ui, webhook_secret="my_secret_key") # 3. Webhook mit explizitem Namen registrieren @app.add_webhook("/say_hello") async def hello(payload: WebhookPayload): return {"message": "hello"} # 4. Webhook mit implizitem Namen registrierene @app.add_webhook async def goodbye(payload: WebhookPayload): return {"message": "goodbye"} # 5. Server starten (optional) app.run() ``` 1. Wir definieren eine benutzerdefinierte UI mit Gradio-Blöcken. Diese UI wird auf der Startseite des Servers angezeigt. 2. Wir erstellen ein [`WebhooksServer`]-Objekt mit einer benutzerdefinierten UI und einem Geheimnis. Das Geheimnis ist optional und kann mit der `WEBHOOK_SECRET` Umgebungsvariable gesetzt werden. 3. Wir registrieren einen Webhook mit einem expliziten Namen. Dies wird einen Endpunkt unter `/webhooks/say_hello` erstellen. 4. Wir registrieren einen Webhook mit einem impliziten Namen. Dies wird einen Endpunkt unter `/webhooks/goodbye` erstellen. 5. Wir starten den Server. Dies ist optional, da Ihr Server automatisch am Ende des Skripts gestartet wird. huggingface_hub-0.31.1/docs/source/de/index.md000066400000000000000000000077061500667546600212030ustar00rootroot00000000000000 # 🤗 Hub client bibliothek Die `huggingface_hub` Bibliothek ermöglicht die Interaktion mit dem [Hugging Face Hub](https://hf.co), einer Plattform für maschinelles Lernen, die für Entwickler und Mitwirkende konzipiert ist. Hier können Sie vorab trainierte Modelle und Datensätze entdecken, mit zahlreichen Apps für maschinelles Lernen experimentieren und eigene Modelle sowie Datensätze mit der Community teilen. Die `huggingface_hub` Bibliothek macht es einfach, all das in Python umzusetzen. In der [Kurzanleitung](quick-start) der `huggingface_hub` Bibliothek erfahren Sie wie Sie Dateien vom Hub herunterladen, Repositories erstellen und Inhalte auf den Hub hochladen können. Weiterführend können Sie sich über Verwaltung von Repositories, die Interaktion in Diskussionen und den Zugriff auf die Inferenz-API auf dem 🤗 Hub informieren. ## Beitragen Alle Beiträge zum `huggingface_hub` sind willkommen und werden gleichermaßen geschätzt! 🤗 Neben dem Hinzufügen neuer Features oder dem Beheben von Problemen können Sie auch zur Verbesserung der Dokumentation beitragen, indem Sie ihre Richtigkeit und Aktualität gewährleisten. Sie können auch bei der Lösung von Fragen mithelfen oder neue Features vorschlagen, um die Bibliothek weiterzuentwickeln. Schauen Sie in die [Beitragsrichtlinien](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md), um mehr zu erfahren über das Melden von Problemen, das Vorschlagen von Features, das Einreichen eines Pull Requests, und das Testen Ihrer Code-Einsendungen, um sicherzustellen dass alles so wie erwartet funktioniert. 
Mitwirkende halten sich bitte an unseren [Verhaltenskodex](https://github.com/huggingface/huggingface_hub/blob/main/CODE_OF_CONDUCT.md), um eine inklusive und einladende Umgebung zur Zusammenarbeit für Alle zu gewährleisten. huggingface_hub-0.31.1/docs/source/de/installation.md000066400000000000000000000171521500667546600225710ustar00rootroot00000000000000 # Installation Bevor Sie beginnen, müssen Sie Ihre Umgebung vorbereiten, indem Sie die entsprechenden Pakete installieren. `huggingface_hub` wurde für **Python 3.8+** getestet. ## Installation mit pip Es wird dringend empfohlen, `huggingface_hub` in einer [virtuellen Umgebung](https://docs.python.org/3/library/venv.html) zu installieren. Wenn Sie mit virtuellen Umgebungen in Python nicht vertraut sind, werfen Sie einen Blick auf diesen [Leitfaden](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/). Eine virtuelle Umgebung erleichtert die Verwaltung verschiedener Projekte und verhindert Kompatibilitätsprobleme zwischen Abhängigkeiten. Beginnen Sie damit, eine virtuelle Umgebung in Ihrem Projektverzeichnis zu erstellen: ```bash python -m venv .env ``` Aktivieren Sie die virtuelle Umgebung. Unter Linux und macOS: ```bash source .env/bin/activate ``` Aktivieren der virtuellen Umgebung unter Windows: ```bash .env/Scripts/activate ``` Jetzt können Sie `huggingface_hub` aus dem [PyPi-Register](https://pypi.org/project/huggingface-hub/) installieren: ```bash pip install --upgrade huggingface_hub ``` Überprüfen Sie nach Abschluss, ob die [Installation korrekt funktioniert](#installation-berprfen). ### Installieren optionaler Abhängigkeiten Einige Abhängigkeiten von `huggingface_hub` sind [optional](https://setuptools.pypa.io/en/latest/userguide/dependency_management.html#optional-dependencies), da sie nicht notwendig sind, um die Kernfunktionen von `huggingface_hub` auszuführen. Allerdings könnten einige Funktionen von `huggingface_hub` nicht verfügbar sein, wenn die optionalen Abhängigkeiten nicht installiert sind. Sie können optionale Abhängigkeiten über `pip` installieren: ```bash # Abhängigkeiten für spezifische TensorFlow-Funktionen installieren # /!\ Achtung: dies entspricht nicht `pip install tensorflow` pip install 'huggingface_hub[tensorflow]' # Abhängigkeiten sowohl für torch-spezifische als auch für CLI-spezifische Funktionen installieren. pip install 'huggingface_hub[cli,torch]' ``` Hier ist die Liste der optionalen Abhängigkeiten in huggingface_hub: - `cli`: bietet eine komfortablere CLI-Schnittstelle für huggingface_hub. - `fastai`, `torch`, `tensorflow`: Abhängigkeiten, um framework-spezifische Funktionen auszuführen. - `dev`: Abhängigkeiten, um zur Bibliothek beizutragen. Enthält `testing` (um Tests auszuführen), `typing` (um den Type Checker auszuführen) und `quality` (um Linters auszuführen). ### Installieren von der Quelle In einigen Fällen kann es sinnvoll sein, `huggingface_hub` direkt von der Quelle zu installieren. Dies ermöglicht es Ihnen, die aktuellste `main`-Version anstelle der neuesten stabilen Version zu verwenden. Die `main`-Version ist nützlich, um immer auf dem neuesten Stand der Entwicklungen zu bleiben, zum Beispiel wenn ein Fehler seit der letzten offiziellen Veröffentlichung behoben wurde, aber noch keine neue Version herausgegeben wurde. Das bedeutet jedoch, dass die `main`-Version nicht immer stabil sein könnte. 
Wir bemühen uns, die Hauptversion funktionsfähig zu halten, und die meisten Probleme werden in der Regel innerhalb von einigen Stunden oder einem Tag gelöst. Wenn Sie auf ein Problem stoßen, eröffnen Sie bitte ein "Issue", damit wir es noch schneller beheben können! ```bash pip install git+https://github.com/huggingface/huggingface_hub ``` Bei der Installation von der Quelle können Sie auch einen bestimmten Zweig angeben. Dies ist nützlich, wenn Sie ein neues Feature oder einen neuen Fehlerbehebung testen möchten, der noch nicht zusammengeführt wurde: ```bash pip install git+https://github.com/huggingface/huggingface_hub@my-feature-branch ``` Überprüfen Sie nach Abschluss, ob die [Installation korrekt funktioniert](#installation-berprfen). ### Editierbare Installation Die Installation von der Quelle ermöglicht Ihnen eine [editierbare Installation](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs). Dies ist eine fortgeschrittenere Installation, wenn Sie zur Entwicklung von `huggingface_hub` beitragen und Änderungen im Code testen möchten. Sie müssen eine lokale Kopie von `huggingface_hub` auf Ihrem Computer klonen. ```bash # Zuerst die Repository lokal klonen git clone https://github.com/huggingface/huggingface_hub.git # Dann mit dem -e Flag installieren cd huggingface_hub pip install -e . ``` Diese Befehle verknüpfen den Ordner, in den Sie das Repository geklont haben, mit Ihren Python-Bibliothekspfaden. Python wird nun zusätzlich zu den normalen Bibliothekspfaden im geklonten Ordner suchen. Wenn Ihre Python-Pakete normalerweise in `./.venv/lib/python3.13/site-packages/` installiert sind, wird Python auch den geklonten Ordner `./huggingface_hub/` durchsuchen. ## Installieren mit conda Wenn Sie damit vertrauter sind, können Sie `huggingface_hub` über den [conda-forge-Kanal](https://anaconda.org/conda-forge/huggingface_hub) installieren: ```bash conda install -c conda-forge huggingface_hub ``` Überprüfen Sie nach Abschluss, ob die [Installation korrekt funktioniert](#installation-berprfen). ## Installation überprüfen Nach der Installation überprüfen Sie, ob `huggingface_hub` richtig funktioniert, indem Sie den folgenden Befehl ausführen: ```bash python -c "from huggingface_hub import model_info; print(model_info('gpt2'))" ``` Dieser Befehl ruft Informationen vom Hub über das [gpt2](https://huggingface.co/gpt2)-Modell ab. Die Ausgabe sollte so aussehen: ```text Model Name: gpt2 Tags: ['pytorch', 'tf', 'jax', 'tflite', 'rust', 'safetensors', 'gpt2', 'text-generation', 'en', 'doi:10.57967/hf/0039', 'transformers', 'exbert', 'license:mit', 'has_space'] Task: text-generation ``` ## Windows-Einschränkungen Mit unserem Ziel, gutes ML überall zu demokratisieren, haben wir `huggingface_hub` als plattformübergreifende Bibliothek entwickelt, insbesondere um sowohl auf Unix-basierten als auch auf Windows-Systemen korrekt zu funktionieren. Es gibt jedoch einige Fälle, in denen `huggingface_hub` unter Windows gewisse Einschränkungen hat. Hier ist eine ausführliche Liste der bekannten Probleme. Bitte informieren Sie uns, wenn Sie auf ein nicht dokumentiertes Problem stoßen, indem Sie ein [Issue auf Github eröffnen](https://github.com/huggingface/huggingface_hub/issues/new/choose). - Das Cache-System von `huggingface_hub` verwendet Symlinks, um Dateien, die vom Hub heruntergeladen wurden, effizient zu cachen. Unter Windows müssen Sie den Entwicklermodus aktivieren oder Ihr Skript als Admin ausführen, um Symlinks zu aktivieren. 
Wenn sie nicht aktiviert sind, funktioniert das Cache-System immer noch, aber nicht optimiert. Bitte lesen Sie den Abschnitt über [Cache-Einschränkungen](./guides/manage-cache#limitations) für weitere Details. - Dateipfade auf dem Hub können Sonderzeichen enthalten (z.B. `"pfad/zu?/meiner/datei"`). Windows ist bei [Sonderzeichen](https://learn.microsoft.com/en-us/windows/win32/intl/character-sets-used-in-file-names) restriktiver, wodurch es unmöglich ist, diese Dateien unter Windows herunterzuladen. Hoffentlich ist dies ein seltener Fall. Bitte wenden Sie sich an den Repo-Eigentümer, wenn Sie denken, dass dies ein Fehler ist, oder an uns, um eine Lösung zu finden. ## Nächste Schritte Sobald `huggingface_hub`` richtig auf Ihrem Computer installiert ist, möchten Sie vielleicht [Umgebungsvariablen konfigurieren](package_reference/environment_variables) oder [einen unserer Leitfäden durchgehen](guides/overview), um loszulegen. huggingface_hub-0.31.1/docs/source/de/quick-start.md000066400000000000000000000151051500667546600223330ustar00rootroot00000000000000 # Kurzanleitung Der [Hugging Face Hub](https://huggingface.co/) ist die erste Anlaufstelle für das Teilen von Maschinenlernmodellen, Demos, Datensätzen und Metriken. Die `huggingface_hub`-Bibliothek hilft Ihnen, mit dem Hub zu interagieren, ohne Ihre Entwicklungs-Umgebung zu verlassen. Sie können Repositories einfach erstellen und verwalten, Dateien herunterladen und hochladen und nützliche Model- und Datensatz-Metadaten vom Hub abrufen. ## Installation Um loszulegen, installieren Sie die `huggingface_hub`-Bibliothek: ```bash pip install --upgrade huggingface_hub ``` Für weitere Details schauen Sie sich bitte den [Installationsleitfaden](installation) an. ## Dateien herunterladen Repositories auf dem Hub sind mit git versioniert, und Benutzer können eine einzelne Datei oder das gesamte Repository herunterladen. Sie können die Funktion [`hf_hub_download`] verwenden, um Dateien herunterzuladen. Diese Funktion lädt eine Datei herunter und speichert sie im Cache auf Ihrer lokalen Festplatte. Das nächste Mal, wenn Sie diese Datei benötigen, wird sie aus Ihrem Cache geladen, sodass Sie sie nicht erneut herunterladen müssen. Sie benötigen die Repository-ID und den Dateinamen der Datei, die Sie herunterladen möchten. Zum Beispiel, um die Konfigurationsdatei des [Pegasus](https://huggingface.co/google/pegasus-xsum) Modells herunterzuladen: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download(repo_id="google/pegasus-xsum", filename="config.json") ``` Um eine bestimmte Version der Datei herunterzuladen, verwenden Sie den `revision`-Parameter, um den Namen der Branch, des Tags oder des Commit-Hashes anzugeben. Wenn Sie sich für den Commit-Hash entscheiden, muss es der vollständige Hash anstelle des kürzeren 7-Zeichen-Commit-Hashes sein: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download( ... repo_id="google/pegasus-xsum", ... filename="config.json", ... revision="4d33b01d79672f27f001f6abade33f22d993b151" ... ) ``` Für weitere Details und Optionen siehe die API-Referenz für [`hf_hub_download`]. ## Anmeldung In vielen Fällen müssen Sie mit einem Hugging Face-Konto angemeldet sein, um mit dem Hub zu interagieren: private Repos herunterladen, Dateien hochladen, PRs erstellen,... 
[Erstellen Sie ein Konto](https://huggingface.co/join), wenn Sie noch keines haben, und melden Sie sich dann an, um Ihr ["User Access Token"](https://huggingface.co/docs/hub/security-tokens) von Ihrer [Einstellungsseite](https://huggingface.co/settings/tokens) zu erhalten. Das "User Access Token" wird verwendet, um Ihre Identität gegenüber dem Hub zu authentifizieren. Sobald Sie Ihr "User Access Token" haben, führen Sie den folgenden Befehl in Ihrem Terminal aus: ```bash huggingface-cli login # or using an environment variable huggingface-cli login --token $HUGGINGFACE_TOKEN ``` Alternativ können Sie sich auch programmatisch in einem Notebook oder einem Skript mit [`login`] anmelden: ```py >>> from huggingface_hub import login >>> login() ``` Es ist auch möglich, sich programmatisch anzumelden, ohne aufgefordert zu werden, Ihr Token einzugeben, indem Sie das Token direkt an [`login`] weitergeben, wie z.B. `login(token="hf_xxx")`. Seien Sie vorsichtig, wenn Sie Ihren Quellcode teilen. Es ist eine bewährte Methode, das Token aus einem sicheren Tresor/Vault zu laden, anstatt es explizit in Ihrer Codebasis/Notebook zu speichern. Sie können nur auf 1 Konto gleichzeitig angemeldet sein. Wenn Sie Ihren Computer mit einem neuen Konto anmelden, werden Sie vom vorherigen abgemeldet. Mit dem Befehl `huggingface-cli whoami` stellen Sie sicher, dass Sie immer wissen, welches Konto Sie gerade verwenden. Wenn Sie mehrere Konten im selben Skript verwalten möchten, können Sie Ihr Token bereitstellen, wenn Sie jede Methode aufrufen. Dies ist auch nützlich, wenn Sie kein Token auf Ihrem Computer speichern möchten. Sobald Sie angemeldet sind, werden alle Anfragen an den Hub - auch Methoden, die nicht unbedingt eine Authentifizierung erfordern - standardmäßig Ihr Zugriffstoken verwenden. Wenn Sie die implizite Verwendung Ihres Tokens deaktivieren möchten, sollten Sie die Umgebungsvariable `HF_HUB_DISABLE_IMPLICIT_TOKEN` setzen. ## Eine Repository erstellen Nachdem Sie sich registriert und angemeldet haben, können Sie mit der Funktion [`create_repo`] ein Repository erstellen: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.create_repo(repo_id="super-cool-model") ``` If you want your repository to be private, then: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.create_repo(repo_id="super-cool-model", private=True) ``` Private Repositories sind für niemanden außer Ihnen selbst sichtbar. Um eine Repository zu erstellen oder Inhalte auf den Hub zu pushen, müssen Sie ein "User Access Token" bereitstellen, das die Schreibberechtigung (`write`) hat. Sie können die Berechtigung auswählen, wenn Sie das Token auf Ihrer [Einstellungsseite](https://huggingface.co/settings/tokens) erstellen. ## Dateien hochladen Verwenden Sie die [`upload_file`]-Funktion, um eine Datei zu Ihrem neu erstellten Repository hinzuzufügen. Sie müssen dabei das Folgende angeben: 1. Den Pfad der hochzuladenden Datei. 2. Den Pfad der Datei im Repository. 3. Die Repository-ID, zu der Sie die Datei hinzufügen möchten. ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.upload_file( ... path_or_fileobj="/home/lysandre/dummy-test/README.md", ... path_in_repo="README.md", ... repo_id="lysandre/test-model", ... ) ``` Um mehr als eine Datei gleichzeitig hochzuladen, werfen Sie bitte einen Blick auf den [Upload](./guides/upload)-Leitfaden, der Ihnen verschiedene Methoden zum Hochladen von Dateien vorstellt (mit oder ohne git). 
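Eine dieser Methoden ist [`upload_folder`], mit der sich ein ganzer lokaler Ordner in einem einzigen Commit hochladen lässt. Hier ein minimaler Entwurf (Ordnerpfad und Repository-ID sind nur Platzhalter):

```py
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> api.upload_folder(
...     folder_path="/home/lysandre/dummy-test",  # lokaler Ordner (Platzhalter)
...     repo_id="lysandre/test-model",            # Ziel-Repository (Platzhalter)
... )
```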
## Nächste Schritte Die `huggingface_hub`-Bibliothek bietet den Benutzern eine einfache Möglichkeit, mittels Python mit dem Hub zu interagieren. Um mehr darüber zu erfahren, wie Sie Ihre Dateien und Repositories auf dem Hub verwalten können, empfehlen wir, unsere [How-to-Leitfäden](./guides/overview) zu lesen: - [Verwalten Sie Ihre Repository](./guides/repository). - Dateien vom Hub [herunterladen](./guides/download). - Dateien auf den Hub [hochladen](./guides/upload). - [Durchsuchen Sie den Hub](./guides/search) nach dem gewünschten Modell oder Datensatz. - [Greifen Sie auf die Inferenz-API zu](./guides/inference), um schnelle Inferenzen durchzuführen. huggingface_hub-0.31.1/docs/source/en/000077500000000000000000000000001500667546600175525ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/en/_redirects.yml000066400000000000000000000013631500667546600224230ustar00rootroot00000000000000# Move "how-to" pages to the guides/ folder how-to-cache: guides/manage-cache how-to-discussions-and-pull-requests: guides/community how-to-downstream: guides/download how-to-inference: guides/inference how-to-manage: guides/repository how-to-model-cards: guides/model-cards how-to-upstream: guides/upload search-the-hub: guides/search guides/manage_spaces: guides/manage-spaces package_reference/inference_api: package_reference/inference_client package_reference/login: package_reference/authentication # Alias for hf-transfer description hf_transfer: package_reference/environment_variables#hfhubenablehftransfer # Alias for auth authentication: quick-start#authentication # Rename webhooks_server to webhooks guides/webhooks_server: guides/webhooks huggingface_hub-0.31.1/docs/source/en/_toctree.yml000066400000000000000000000053321500667546600221040ustar00rootroot00000000000000- title: 'Get started' sections: - local: index title: Home - local: quick-start title: Quickstart - local: installation title: Installation - title: 'How-to guides' sections: - local: guides/overview title: Overview - local: guides/download title: Download files - local: guides/upload title: Upload files - local: guides/cli title: Use the CLI - local: guides/hf_file_system title: HfFileSystem - local: guides/repository title: Repository - local: guides/search title: Search - local: guides/inference title: Inference - local: guides/inference_endpoints title: Inference Endpoints - local: guides/community title: Community Tab - local: guides/collections title: Collections - local: guides/manage-cache title: Cache - local: guides/model-cards title: Model Cards - local: guides/manage-spaces title: Manage your Space - local: guides/integrations title: Integrate a library - local: guides/webhooks title: Webhooks - title: 'Conceptual guides' sections: - local: concepts/git_vs_http title: Git vs HTTP paradigm - title: 'Reference' sections: - local: package_reference/overview title: Overview - local: package_reference/authentication title: Authentication - local: package_reference/environment_variables title: Environment variables - local: package_reference/repository title: Managing local and online repositories - local: package_reference/hf_api title: Hugging Face Hub API - local: package_reference/file_download title: Downloading files - local: package_reference/mixins title: Mixins & serialization methods - local: package_reference/inference_types title: Inference Types - local: package_reference/inference_client title: Inference Client - local: package_reference/inference_endpoints title: Inference Endpoints - local: 
package_reference/hf_file_system title: HfFileSystem - local: package_reference/utilities title: Utilities - local: package_reference/community title: Discussions and Pull Requests - local: package_reference/cache title: Cache-system reference - local: package_reference/cards title: Repo Cards and Repo Card Data - local: package_reference/space_runtime title: Space runtime - local: package_reference/collections title: Collections - local: package_reference/tensorboard title: TensorBoard logger - local: package_reference/webhooks_server title: Webhooks server - local: package_reference/serialization title: Serialization huggingface_hub-0.31.1/docs/source/en/concepts/000077500000000000000000000000001500667546600213705ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/en/concepts/git_vs_http.md000066400000000000000000000076131500667546600242530ustar00rootroot00000000000000 # Git vs HTTP paradigm The `huggingface_hub` library is a library for interacting with the Hugging Face Hub, which is a collection of git-based repositories (models, datasets or Spaces). There are two main ways to access the Hub using `huggingface_hub`. The first approach, the so-called "git-based" approach, is led by the [`Repository`] class. This method uses a wrapper around the `git` command with additional functions specifically designed to interact with the Hub. The second option, called the "HTTP-based" approach, involves making HTTP requests using the [`HfApi`] client. Let's examine the pros and cons of each approach. ## Repository: the historical git-based approach At first, `huggingface_hub` was mostly built around the [`Repository`] class. It provides Python wrappers for common `git` commands such as `"git add"`, `"git commit"`, `"git push"`, `"git tag"`, `"git checkout"`, etc. The library also helps with setting credentials and tracking large files, which are often used in machine learning repositories. Additionally, the library allows you to execute its methods in the background, making it useful for uploading data during training. The main advantage of using a [`Repository`] is that it allows you to maintain a local copy of the entire repository on your machine. This can also be a disadvantage as it requires you to constantly update and maintain this local copy. This is similar to traditional software development where each developer maintains their own local copy and pushes changes when working on a feature. However, in the context of machine learning, this may not always be necessary as users may only need to download weights for inference or convert weights from one format to another without the need to clone the entire repository. [`Repository`] is now deprecated in favor of the http-based alternatives. Given its large adoption in legacy code, the complete removal of [`Repository`] will only happen in release `v1.0`. ## HfApi: a flexible and convenient HTTP client The [`HfApi`] class was developed to provide an alternative to local git repositories, which can be cumbersome to maintain, especially when dealing with large models or datasets. The [`HfApi`] class offers the same functionality as git-based approaches, such as downloading and pushing files and creating branches and tags, but without the need for a local folder that needs to be kept in sync. 
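For illustration, here is a minimal sketch of such an HTTP-based workflow (the repo id, file name, branch and tag below are placeholder values):

```py
from huggingface_hub import HfApi

api = HfApi()

# Create the repo if it doesn't exist yet, then push a single file to it.
# Everything happens over HTTP: nothing is cloned or kept in sync locally.
api.create_repo(repo_id="username/my-model", exist_ok=True)
api.upload_file(
    path_or_fileobj="weights.safetensors",
    path_in_repo="weights.safetensors",
    repo_id="username/my-model",
)

# Git-style operations such as branches and tags are plain method calls too.
api.create_branch(repo_id="username/my-model", branch="experiment")
api.create_tag(repo_id="username/my-model", tag="v0.1")
```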
In addition to the functionalities already provided by `git`, the [`HfApi`] class offers additional features, such as the ability to manage repos, download files using caching for efficient reuse, search the Hub for repos and metadata, access community features such as discussions, PRs, and comments, and configure Spaces hardware and secrets. ## What should I use ? And when ? Overall, the **HTTP-based approach is the recommended way to use** `huggingface_hub` in all cases. [`HfApi`] allows to pull and push changes, work with PRs, tags and branches, interact with discussions and much more. Since the `0.16` release, the http-based methods can also run in the background, which was the last major advantage of the [`Repository`] class. However, not all git commands are available through [`HfApi`]. Some may never be implemented, but we are always trying to improve and close the gap. If you don't see your use case covered, please open [an issue on Github](https://github.com/huggingface/huggingface_hub)! We welcome feedback to help build the 🤗 ecosystem with and for our users. This preference of the http-based [`HfApi`] over the git-based [`Repository`] does not mean that git versioning will disappear from the Hugging Face Hub anytime soon. It will always be possible to use `git` commands locally in workflows where it makes sense. huggingface_hub-0.31.1/docs/source/en/guides/000077500000000000000000000000001500667546600210325ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/en/guides/cli.md000066400000000000000000000637451500667546600221420ustar00rootroot00000000000000 # Command Line Interface (CLI) The `huggingface_hub` Python package comes with a built-in CLI called `huggingface-cli`. This tool allows you to interact with the Hugging Face Hub directly from a terminal. For example, you can login to your account, create a repository, upload and download files, etc. It also comes with handy features to configure your machine or manage your cache. In this guide, we will have a look at the main features of the CLI and how to use them. ## Getting started First of all, let's install the CLI: ``` >>> pip install -U "huggingface_hub[cli]" ``` In the snippet above, we also installed the `[cli]` extra dependencies to make the user experience better, especially when using the `delete-cache` command. Once installed, you can check that the CLI is correctly setup: ``` >>> huggingface-cli --help usage: huggingface-cli [] positional arguments: {env,login,whoami,logout,repo,upload,download,lfs-enable-largefiles,lfs-multipart-upload,scan-cache,delete-cache,tag} huggingface-cli command helpers env Print information about the environment. login Log in using a token from huggingface.co/settings/tokens whoami Find out which huggingface.co account you are logged in as. logout Log out repo {create} Commands to interact with your huggingface.co repos. upload Upload a file or a folder to a repo on the Hub download Download files from the Hub lfs-enable-largefiles Configure your repository to enable upload of files > 5GB. scan-cache Scan cache directory. delete-cache Delete revisions from the cache directory. tag (create, list, delete) tags for a repo in the hub options: -h, --help show this help message and exit ``` If the CLI is correctly installed, you should see a list of all the options available in the CLI. If you get an error message such as `command not found: huggingface-cli`, please refer to the [Installation](../installation) guide. 
The `--help` option is very convenient for getting more details about a command. You can use it anytime to list all available options and their details. For example, `huggingface-cli upload --help` provides more information on how to upload files using the CLI. ### Alternative install #### Using pkgx [Pkgx](https://pkgx.sh) is a blazingly fast cross platform package manager that runs anything. You can install huggingface-cli using pkgx as follows: ```bash >>> pkgx install huggingface-cli ``` Or you can run huggingface-cli directly: ```bash >>> pkgx huggingface-cli --help ``` Check out the pkgx huggingface page [here](https://pkgx.dev/pkgs/huggingface.co/) for more details. #### Using Homebrew You can also install the CLI using [Homebrew](https://brew.sh/): ```bash >>> brew install huggingface-cli ``` Check out the Homebrew huggingface page [here](https://formulae.brew.sh/formula/huggingface-cli) for more details. ## huggingface-cli login In many cases, you must be logged in to a Hugging Face account to interact with the Hub (download private repos, upload files, create PRs, etc.). To do so, you need a [User Access Token](https://huggingface.co/docs/hub/security-tokens) from your [Settings page](https://huggingface.co/settings/tokens). The User Access Token is used to authenticate your identity to the Hub. Make sure to set a token with write access if you want to upload or modify content. Once you have your token, run the following command in your terminal: ```bash >>> huggingface-cli login ``` This command will prompt you for a token. Copy-paste yours and press *Enter*. Then, you'll be asked if the token should also be saved as a git credential. Press *Enter* again (default to yes) if you plan to use `git` locally. Finally, it will call the Hub to check that your token is valid and save it locally. ``` _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_| _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_| To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens . Enter your token (input will not be visible): Add token as git credential? (Y/n) Token is valid (permission: write). Your token has been saved in your configured git credential helpers (store). Your token has been saved to /home/wauplin/.cache/huggingface/token Login successful ``` Alternatively, if you want to log-in without being prompted, you can pass the token directly from the command line. To be more secure, we recommend passing your token as an environment variable to avoid pasting it in your command history. ```bash # Or using an environment variable >>> huggingface-cli login --token $HF_TOKEN --add-to-git-credential Token is valid (permission: write). The token `token_name` has been saved to /home/wauplin/.cache/huggingface/stored_tokens Your token has been saved in your configured git credential helpers (store). Your token has been saved to /home/wauplin/.cache/huggingface/token Login successful The current active token is: `token_name` ``` For more details about authentication, check out [this section](../quick-start#authentication). ## huggingface-cli whoami If you want to know if you are logged in, you can use `huggingface-cli whoami`. 
This command doesn't have any options and simply prints your username and the organizations you are a part of on the Hub:

```bash
huggingface-cli whoami
Wauplin
orgs: huggingface,eu-test,OAuthTesters,hf-accelerate,HFSmolCluster
```

If you are not logged in, an error message will be printed.

## huggingface-cli logout

This command logs you out. In practice, it will delete all tokens stored on your machine. If you want to remove a specific token, you can specify the token name as an argument.

This command will not log you out if you are logged in using the `HF_TOKEN` environment variable (see [reference](../package_reference/environment_variables#hftoken)). If that is the case, you must unset the environment variable in your machine configuration.

## huggingface-cli download

Use the `huggingface-cli download` command to download files from the Hub directly. Internally, it uses the same [`hf_hub_download`] and [`snapshot_download`] helpers described in the [Download](./download) guide and prints the returned path to the terminal. In the examples below, we will walk through the most common use cases. For a full list of available options, you can run:

```bash
huggingface-cli download --help
```

### Download a single file

To download a single file from a repo, simply provide the repo_id and filename as follows:

```bash
>>> huggingface-cli download gpt2 config.json
downloading https://huggingface.co/gpt2/resolve/main/config.json to /home/wauplin/.cache/huggingface/hub/tmpwrq8dm5o
(…)ingface.co/gpt2/resolve/main/config.json: 100%|██████████████████████████████████| 665/665 [00:00<00:00, 2.49MB/s]
/home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json
```

The command will always print on the last line the path to the file on your local machine.

To download a file located in a subdirectory of the repo, you should provide the path of the file in the repo in posix format like this:

```bash
>>> huggingface-cli download HiDream-ai/HiDream-I1-Full text_encoder/model.safetensors
```

### Download an entire repository

In some cases, you just want to download all the files from a repository. This can be done by just specifying the repo id:

```bash
>>> huggingface-cli download HuggingFaceH4/zephyr-7b-beta
Fetching 23 files:   0%|                    | 0/23 [00:00<?, ?it/s]
...
```

### Download multiple files

You can also download multiple files from a repository with a single command. This can be done in two ways. If you already have a precise list of the files you want to download, you can simply provide them sequentially:

```bash
>>> huggingface-cli download gpt2 config.json model.safetensors
Fetching 2 files:   0%|                    | 0/2 [00:00<?, ?it/s]
...
```

The other approach is to provide patterns to filter which files you want to download using `--include` and `--exclude`. For example, if you want to download all safetensors files from [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), except the files in FP16 precision:

```bash
>>> huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0 --include "*.safetensors" --exclude "*.fp16.*"
Fetching 8 files:   0%|                    | 0/8 [00:00<?, ?it/s]
...
```

### Download a dataset or a Space

The examples above show how to download from a model repository. To download a dataset or a Space, use the `--repo-type` option:

```bash
# https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k
>>> huggingface-cli download HuggingFaceH4/ultrachat_200k --repo-type dataset

# https://huggingface.co/spaces/HuggingFaceH4/zephyr-chat
>>> huggingface-cli download HuggingFaceH4/zephyr-chat --repo-type space

...
```

### Download a specific revision

The examples above show how to download from the latest commit on the main branch. To download from a specific revision (commit hash, branch name or tag), use the `--revision` option:

```bash
>>> huggingface-cli download bigcode/the-stack --repo-type dataset --revision v1.1
...
```

### Download to a local folder

The recommended (and default) way to download files from the Hub is to use the cache-system. However, in some cases you want to download files and move them to a specific folder. This is useful to get a workflow closer to what git commands offer. You can do that using the `--local-dir` option.

A `.cache/huggingface/` folder is created at the root of your local directory containing metadata about the downloaded files.
This prevents re-downloading files if they're already up-to-date. If the metadata has changed, then the new file version is downloaded. This makes the `local-dir` optimized for pulling only the latest changes. For more details on how downloading to a local file works, check out the [download](./download#download-files-to-a-local-folder) guide. ```bash >>> huggingface-cli download adept/fuyu-8b model-00001-of-00002.safetensors --local-dir fuyu ... fuyu/model-00001-of-00002.safetensors ``` ### Specify cache directory If not using `--local-dir`, all files will be downloaded by default to the cache directory defined by the `HF_HOME` [environment variable](../package_reference/environment_variables#hfhome). You can specify a custom cache using `--cache-dir`: ```bash >>> huggingface-cli download adept/fuyu-8b --cache-dir ./path/to/cache ... ./path/to/cache/models--adept--fuyu-8b/snapshots/ddcacbcf5fdf9cc59ff01f6be6d6662624d9c745 ``` ### Specify a token To access private or gated repositories, you must use a token. By default, the token saved locally (using `huggingface-cli login`) will be used. If you want to authenticate explicitly, use the `--token` option: ```bash >>> huggingface-cli download gpt2 config.json --token=hf_**** /home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json ``` ### Quiet mode By default, the `huggingface-cli download` command will be verbose. It will print details such as warning messages, information about the downloaded files, and progress bars. If you want to silence all of this, use the `--quiet` option. Only the last line (i.e. the path to the downloaded files) is printed. This can prove useful if you want to pass the output to another command in a script. ```bash >>> huggingface-cli download gpt2 --quiet /home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10 ``` ### Download timeout On machines with slow connections, you might encounter timeout issues like this one: ```bash `requests.exceptions.ReadTimeout: (ReadTimeoutError("HTTPSConnectionPool(host='cdn-lfs-us-1.huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: a33d910c-84c6-4514-8362-c705e2039d38)')` ``` To mitigate this issue, you can set the `HF_HUB_DOWNLOAD_TIMEOUT` environment variable to a higher value (default is 10): ```bash export HF_HUB_DOWNLOAD_TIMEOUT=30 ``` For more details, check out the [environment variables reference](../package_reference/environment_variables#hfhubdownloadtimeout). And rerun your download command. ## huggingface-cli upload Use the `huggingface-cli upload` command to upload files to the Hub directly. Internally, it uses the same [`upload_file`] and [`upload_folder`] helpers described in the [Upload](./upload) guide. In the examples below, we will walk through the most common use cases. For a full list of available options, you can run: ```bash >>> huggingface-cli upload --help ``` ### Upload an entire folder The default usage for this command is: ```bash # Usage: huggingface-cli upload [repo_id] [local_path] [path_in_repo] ``` To upload the current directory at the root of the repo, use: ```bash >>> huggingface-cli upload my-cool-model . . https://huggingface.co/Wauplin/my-cool-model/tree/main/ ``` If the repo doesn't exist yet, it will be created automatically. You can also upload a specific folder: ```bash >>> huggingface-cli upload my-cool-model ./models . 
https://huggingface.co/Wauplin/my-cool-model/tree/main/ ``` Finally, you can upload a folder to a specific destination on the repo: ```bash >>> huggingface-cli upload my-cool-model ./path/to/curated/data /data/train https://huggingface.co/Wauplin/my-cool-model/tree/main/data/train ``` ### Upload a single file You can also upload a single file by setting `local_path` to point to a file on your machine. If that's the case, `path_in_repo` is optional and will default to the name of your local file: ```bash >>> huggingface-cli upload Wauplin/my-cool-model ./models/model.safetensors https://huggingface.co/Wauplin/my-cool-model/blob/main/model.safetensors ``` If you want to upload a single file to a specific directory, set `path_in_repo` accordingly: ```bash >>> huggingface-cli upload Wauplin/my-cool-model ./models/model.safetensors /vae/model.safetensors https://huggingface.co/Wauplin/my-cool-model/blob/main/vae/model.safetensors ``` ### Upload multiple files To upload multiple files from a folder at once without uploading the entire folder, use the `--include` and `--exclude` patterns. It can also be combined with the `--delete` option to delete files on the repo while uploading new ones. In the example below, we sync the local Space by deleting remote files and uploading all files except the ones in `/logs`: ```bash # Sync local Space with Hub (upload new files except from logs/, delete removed files) >>> huggingface-cli upload Wauplin/space-example --repo-type=space --exclude="/logs/*" --delete="*" --commit-message="Sync local Space with Hub" ... ``` ### Upload to a dataset or Space To upload to a dataset or a Space, use the `--repo-type` option: ```bash >>> huggingface-cli upload Wauplin/my-cool-dataset ./data /train --repo-type=dataset ... ``` ### Upload to an organization To upload content to a repo owned by an organization instead of a personal repo, you must explicitly specify it in the `repo_id`: ```bash >>> huggingface-cli upload MyCoolOrganization/my-cool-model . . https://huggingface.co/MyCoolOrganization/my-cool-model/tree/main/ ``` ### Upload to a specific revision By default, files are uploaded to the `main` branch. If you want to upload files to another branch or reference, use the `--revision` option: ```bash # Upload files to a PR >>> huggingface-cli upload bigcode/the-stack . . --repo-type dataset --revision refs/pr/104 ... ``` **Note:** if `revision` does not exist and `--create-pr` is not set, a branch will be created automatically from the `main` branch. ### Upload and create a PR If you don't have the permission to push to a repo, you must open a PR and let the authors know about the changes you want to make. This can be done by setting the `--create-pr` option: ```bash # Create a PR and upload the files to it >>> huggingface-cli upload bigcode/the-stack . . --repo-type dataset --revision refs/pr/104 https://huggingface.co/datasets/bigcode/the-stack/blob/refs%2Fpr%2F104/ ``` ### Upload at regular intervals In some cases, you might want to push regular updates to a repo. For example, this is useful if you're training a model and you want to upload the logs folder every 10 minutes. You can do this using the `--every` option: ```bash # Upload new logs every 10 minutes huggingface-cli upload training-model logs/ --every=10 ``` ### Specify a commit message Use the `--commit-message` and `--commit-description` to set a custom message and description for your commit instead of the default one ```bash >>> huggingface-cli upload Wauplin/my-cool-model ./models . 
--commit-message="Epoch 34/50" --commit-description="Val accuracy: 68%. Check tensorboard for more details." ... https://huggingface.co/Wauplin/my-cool-model/tree/main ``` ### Specify a token To upload files, you must use a token. By default, the token saved locally (using `huggingface-cli login`) will be used. If you want to authenticate explicitly, use the `--token` option: ```bash >>> huggingface-cli upload Wauplin/my-cool-model ./models . --token=hf_**** ... https://huggingface.co/Wauplin/my-cool-model/tree/main ``` ### Quiet mode By default, the `huggingface-cli upload` command will be verbose. It will print details such as warning messages, information about the uploaded files, and progress bars. If you want to silence all of this, use the `--quiet` option. Only the last line (i.e. the URL to the uploaded files) is printed. This can prove useful if you want to pass the output to another command in a script. ```bash >>> huggingface-cli upload Wauplin/my-cool-model ./models . --quiet https://huggingface.co/Wauplin/my-cool-model/tree/main ``` ## huggingface-cli repo-files If you want to delete files from a Hugging Face repository, use the `huggingface-cli repo-files` command. ### Delete files The `huggingface-cli repo-files delete` sub-command allows you to delete files from a repository. Here are some usage examples. Delete a folder : ```bash >>> huggingface-cli repo-files Wauplin/my-cool-model delete folder/ Files correctly deleted from repo. Commit: https://huggingface.co/Wauplin/my-cool-mo... ``` Delete multiple files: ```bash >>> huggingface-cli repo-files Wauplin/my-cool-model delete file.txt folder/pytorch_model.bin Files correctly deleted from repo. Commit: https://huggingface.co/Wauplin/my-cool-mo... ``` Use Unix-style wildcards to delete sets of files: ```bash >>> huggingface-cli repo-files Wauplin/my-cool-model delete "*.txt" "folder/*.bin" Files correctly deleted from repo. Commit: https://huggingface.co/Wauplin/my-cool-mo... ``` ### Specify a token To delete files from a repo you must be authenticated and authorized. By default, the token saved locally (using `huggingface-cli login`) will be used. If you want to authenticate explicitly, use the `--token` option: ```bash >>> huggingface-cli repo-files --token=hf_**** Wauplin/my-cool-model delete file.txt ``` ## huggingface-cli scan-cache Scanning your cache directory is useful if you want to know which repos you have downloaded and how much space it takes on your disk. 
You can do that by running `huggingface-cli scan-cache`: ```bash >>> huggingface-cli scan-cache REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH --------------------------- --------- ------------ -------- ------------- ------------- ------------------- ------------------------------------------------------------------------- glue dataset 116.3K 15 4 days ago 4 days ago 2.4.0, main, 1.17.0 /home/wauplin/.cache/huggingface/hub/datasets--glue google/fleurs dataset 64.9M 6 1 week ago 1 week ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs Jean-Baptiste/camembert-ner model 441.0M 7 2 weeks ago 16 hours ago main /home/wauplin/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner bert-base-cased model 1.9G 13 1 week ago 2 years ago /home/wauplin/.cache/huggingface/hub/models--bert-base-cased t5-base model 10.1K 3 3 months ago 3 months ago main /home/wauplin/.cache/huggingface/hub/models--t5-base t5-small model 970.7M 11 3 days ago 3 days ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/models--t5-small Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G. Got 1 warning(s) while scanning. Use -vvv to print details. ``` For more details about how to scan your cache directory, please refer to the [Manage your cache](./manage-cache#scan-cache-from-the-terminal) guide. ## huggingface-cli delete-cache `huggingface-cli delete-cache` is a tool that helps you delete parts of your cache that you don't use anymore. This is useful for saving and freeing disk space. To learn more about using this command, please refer to the [Manage your cache](./manage-cache#clean-cache-from-the-terminal) guide. ## huggingface-cli tag The `huggingface-cli tag` command allows you to tag, untag, and list tags for repositories. ### Tag a model To tag a repo, you need to provide the `repo_id` and the `tag` name: ```bash >>> huggingface-cli tag Wauplin/my-cool-model v1.0 You are about to create tag v1.0 on model Wauplin/my-cool-model Tag v1.0 created on Wauplin/my-cool-model ``` ### Tag a model at a specific revision If you want to tag a specific revision, you can use the `--revision` option. By default, the tag will be created on the `main` branch: ```bash >>> huggingface-cli tag Wauplin/my-cool-model v1.0 --revision refs/pr/104 You are about to create tag v1.0 on model Wauplin/my-cool-model Tag v1.0 created on Wauplin/my-cool-model ``` ### Tag a dataset or a Space If you want to tag a dataset or Space, you must specify the `--repo-type` option: ```bash >>> huggingface-cli tag bigcode/the-stack v1.0 --repo-type dataset You are about to create tag v1.0 on dataset bigcode/the-stack Tag v1.0 created on bigcode/the-stack ``` ### List tags To list all tags for a repository, use the `-l` or `--list` option: ```bash >>> huggingface-cli tag Wauplin/gradio-space-ci -l --repo-type space Tags for space Wauplin/gradio-space-ci: 0.2.2 0.2.1 0.2.0 0.1.2 0.0.2 0.0.1 ``` ### Delete a tag To delete a tag, use the `-d` or `--delete` option: ```bash >>> huggingface-cli tag -d Wauplin/my-cool-model v1.0 You are about to delete tag v1.0 on model Wauplin/my-cool-model Proceed? [Y/n] y Tag v1.0 deleted on Wauplin/my-cool-model ``` You can also pass `-y` to skip the confirmation step. ## huggingface-cli env The `huggingface-cli env` command prints details about your machine setup. This is useful when you open an issue on [GitHub](https://github.com/huggingface/huggingface_hub) to help the maintainers investigate your problem. 
```bash >>> huggingface-cli env Copy-and-paste the text below in your GitHub issue. - huggingface_hub version: 0.19.0.dev0 - Platform: Linux-6.2.0-36-generic-x86_64-with-glibc2.35 - Python version: 3.10.12 - Running in iPython ?: No - Running in notebook ?: No - Running in Google Colab ?: No - Token path ?: /home/wauplin/.cache/huggingface/token - Has saved token ?: True - Who am I ?: Wauplin - Configured git credential helpers: store - FastAI: N/A - Tensorflow: 2.11.0 - Torch: 1.12.1 - Jinja2: 3.1.2 - Graphviz: 0.20.1 - Pydot: 1.4.2 - Pillow: 9.2.0 - hf_transfer: 0.1.3 - gradio: 4.0.2 - tensorboard: 2.6 - numpy: 1.23.2 - pydantic: 2.4.2 - aiohttp: 3.8.4 - ENDPOINT: https://huggingface.co - HF_HUB_CACHE: /home/wauplin/.cache/huggingface/hub - HF_ASSETS_CACHE: /home/wauplin/.cache/huggingface/assets - HF_TOKEN_PATH: /home/wauplin/.cache/huggingface/token - HF_HUB_OFFLINE: False - HF_HUB_DISABLE_TELEMETRY: False - HF_HUB_DISABLE_PROGRESS_BARS: None - HF_HUB_DISABLE_SYMLINKS_WARNING: False - HF_HUB_DISABLE_EXPERIMENTAL_WARNING: False - HF_HUB_DISABLE_IMPLICIT_TOKEN: False - HF_HUB_ENABLE_HF_TRANSFER: False - HF_HUB_ETAG_TIMEOUT: 10 - HF_HUB_DOWNLOAD_TIMEOUT: 10 ``` huggingface_hub-0.31.1/docs/source/en/guides/collections.md000066400000000000000000000230251500667546600236740ustar00rootroot00000000000000 # Collections A collection is a group of related items on the Hub (models, datasets, Spaces, papers) that are organized together on the same page. Collections are useful for creating your own portfolio, bookmarking content in categories, or presenting a curated list of items you want to share. Check out this [guide](https://huggingface.co/docs/hub/collections) to understand in more detail what collections are and how they look on the Hub. You can directly manage collections in the browser, but in this guide, we will focus on how to manage them programmatically. ## Fetch a collection Use [`get_collection`] to fetch your collections or any public ones. You must have the collection's *slug* to retrieve a collection. A slug is an identifier for a collection based on the title and a unique ID. You can find the slug in the URL of the collection page.
Let's fetch the collection with the slug `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`:

```py
>>> from huggingface_hub import get_collection
>>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026")
>>> collection
Collection(
  slug='TheBloke/recent-models-64f9a55bb3115b4f513ec026',
  title='Recent models',
  owner='TheBloke',
  items=[...],
  last_updated=datetime.datetime(2023, 10, 2, 22, 56, 48, 632000, tzinfo=datetime.timezone.utc),
  position=1,
  private=False,
  theme='green',
  upvotes=90,
  description="Models I've recently quantized. Please note that currently this list has to be updated manually, and therefore is not guaranteed to be up-to-date."
)
>>> collection.items[0]
CollectionItem(
  item_object_id='651446103cd773a050bf64c2',
  item_id='TheBloke/U-Amethyst-20B-AWQ',
  item_type='model',
  position=88,
  note=None
)
```

The [`Collection`] object returned by [`get_collection`] contains:
- high-level metadata: `slug`, `owner`, `title`, `description`, etc.
- a list of [`CollectionItem`] objects; each item represents a model, a dataset, a Space, or a paper.

All collection items are guaranteed to have:
- a unique `item_object_id`: this is the id of the collection item in the database
- an `item_id`: this is the id on the Hub of the underlying item (model, dataset, Space, paper); it is not necessarily unique, and only the `item_id`/`item_type` pair is unique
- an `item_type`: model, dataset, Space, paper
- the `position` of the item in the collection, which can be updated to reorganize your collection (see [`update_collection_item`] below)

A `note` can also be attached to the item. This is useful to add additional information about the item (a comment, a link to a blog post, etc.). If an item doesn't have a note, the attribute is simply `None`.

In addition to these base attributes, returned items can have additional attributes depending on their type: `author`, `private`, `lastModified`, `gated`, `title`, `likes`, `upvotes`, etc. None of these attributes are guaranteed to be returned.

## List collections

We can also retrieve collections using [`list_collections`]. Collections can be filtered using some parameters. Let's list all the collections from the user [`teknium`](https://huggingface.co/teknium).

```py
>>> from huggingface_hub import list_collections

>>> collections = list_collections(owner="teknium")
```

This returns an iterable of `Collection` objects. We can iterate over them to print, for example, the number of upvotes for each collection.

```py
>>> for collection in collections:
...   print("Number of upvotes:", collection.upvotes)
Number of upvotes: 1
Number of upvotes: 5
```

When listing collections, the item list per collection is truncated to 4 items maximum. To retrieve all items from a collection, you must use [`get_collection`].

It is possible to do more advanced filtering. Let's get all collections containing the model [TheBloke/OpenHermes-2.5-Mistral-7B-GGUF](https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF), sorted by trending, and limit the count to 5.

```py
>>> collections = list_collections(item="models/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF", sort="trending", limit=5)
>>> for collection in collections:
... 
print(collection.slug) teknium/quantized-models-6544690bb978e0b0f7328748 AmeerH/function-calling-65560a2565d7a6ef568527af PostArchitekt/7bz-65479bb8c194936469697d8c gnomealone/need-to-test-652007226c6ce4cdacf9c233 Crataco/favorite-7b-models-651944072b4fffcb41f8b568 ``` Parameter `sort` must be one of `"last_modified"`, `"trending"` or `"upvotes"`. Parameter `item` accepts any particular item. For example: * `"models/teknium/OpenHermes-2.5-Mistral-7B"` * `"spaces/julien-c/open-gpt-rhyming-robot"` * `"datasets/squad"` * `"papers/2311.12983"` For more details, please check out [`list_collections`] reference. ## Create a new collection Now that we know how to get a [`Collection`], let's create our own! Use [`create_collection`] with a title and description. To create a collection on an organization page, pass `namespace="my-cool-org"` when creating the collection. Finally, you can also create private collections by passing `private=True`. ```py >>> from huggingface_hub import create_collection >>> collection = create_collection( ... title="ICCV 2023", ... description="Portfolio of models, papers and demos I presented at ICCV 2023", ... ) ``` It will return a [`Collection`] object with the high-level metadata (title, description, owner, etc.) and an empty list of items. You will now be able to refer to this collection using its `slug`. ```py >>> collection.slug 'owner/iccv-2023-15e23b46cb98efca45' >>> collection.title "ICCV 2023" >>> collection.owner "username" >>> collection.url 'https://huggingface.co/collections/owner/iccv-2023-15e23b46cb98efca45' ``` ## Manage items in a collection Now that we have a [`Collection`], we want to add items to it and organize them. ### Add items Items have to be added one by one using [`add_collection_item`]. You only need to know the `collection_slug`, `item_id` and `item_type`. Optionally, you can also add a `note` to the item (500 characters maximum). ```py >>> from huggingface_hub import create_collection, add_collection_item >>> collection = create_collection(title="OS Week Highlights - Sept 18 - 24", namespace="osanseviero") >>> collection.slug "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" >>> add_collection_item(collection.slug, item_id="coqui/xtts", item_type="space") >>> add_collection_item( ... collection.slug, ... item_id="warp-ai/wuerstchen", ... item_type="model", ... note="Würstchen is a new fast and efficient high resolution text-to-image architecture and model" ... ) >>> add_collection_item(collection.slug, item_id="lmsys/lmsys-chat-1m", item_type="dataset") >>> add_collection_item(collection.slug, item_id="warp-ai/wuerstchen", item_type="space") # same item_id, different item_type ``` If an item already exists in a collection (same `item_id`/`item_type` pair), an HTTP 409 error will be raised. You can choose to ignore this error by setting `exists_ok=True`. ### Add a note to an existing item You can modify an existing item to add or modify the note attached to it using [`update_collection_item`]. Let's reuse the example above: ```py >>> from huggingface_hub import get_collection, update_collection_item # Fetch collection with newly added items >>> collection_slug = "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" >>> collection = get_collection(collection_slug) # Add note the `lmsys-chat-1m` dataset >>> update_collection_item( ... collection_slug=collection_slug, ... item_object_id=collection.items[2].item_object_id, ... 
note="This dataset contains one million real-world conversations with 25 state-of-the-art LLMs.", ... ) ``` ### Reorder items Items in a collection are ordered. The order is determined by the `position` attribute of each item. By default, items are ordered by appending new items at the end of the collection. You can update the order using [`update_collection_item`] the same way you would add a note. Let's reuse our example above: ```py >>> from huggingface_hub import get_collection, update_collection_item # Fetch collection >>> collection_slug = "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" >>> collection = get_collection(collection_slug) # Reorder to place the two `Wuerstchen` items together >>> update_collection_item( ... collection_slug=collection_slug, ... item_object_id=collection.items[3].item_object_id, ... position=2, ... ) ``` ### Remove items Finally, you can also remove an item using [`delete_collection_item`]. ```py >>> from huggingface_hub import get_collection, update_collection_item # Fetch collection >>> collection_slug = "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" >>> collection = get_collection(collection_slug) # Remove `coqui/xtts` Space from the list >>> delete_collection_item(collection_slug=collection_slug, item_object_id=collection.items[0].item_object_id) ``` ## Delete collection A collection can be deleted using [`delete_collection`]. This is a non-revertible action. A deleted collection cannot be restored. ```py >>> from huggingface_hub import delete_collection >>> collection = delete_collection("username/useless-collection-64f9a55bb3115b4f513ec026", missing_ok=True) ``` huggingface_hub-0.31.1/docs/source/en/guides/community.md000066400000000000000000000135111500667546600234010ustar00rootroot00000000000000 # Interact with Discussions and Pull Requests The `huggingface_hub` library provides a Python interface to interact with Pull Requests and Discussions on the Hub. Visit [the dedicated documentation page](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) for a deeper view of what Discussions and Pull Requests on the Hub are, and how they work under the hood. ## Retrieve Discussions and Pull Requests from the Hub The `HfApi` class allows you to retrieve Discussions and Pull Requests on a given repo: ```python >>> from huggingface_hub import get_repo_discussions >>> for discussion in get_repo_discussions(repo_id="bigscience/bloom"): ... print(f"{discussion.num} - {discussion.title}, pr: {discussion.is_pull_request}") # 11 - Add Flax weights, pr: True # 10 - Update README.md, pr: True # 9 - Training languages in the model card, pr: True # 8 - Update tokenizer_config.json, pr: True # 7 - Slurm training script, pr: False [...] ``` `HfApi.get_repo_discussions` supports filtering by author, type (Pull Request or Discussion) and status (`open` or `closed`): ```python >>> from huggingface_hub import get_repo_discussions >>> for discussion in get_repo_discussions( ... repo_id="bigscience/bloom", ... author="ArthurZ", ... discussion_type="pull_request", ... discussion_status="open", ... ): ... print(f"{discussion.num} - {discussion.title} by {discussion.author}, pr: {discussion.is_pull_request}") # 19 - Add Flax weights by ArthurZ, pr: True ``` `HfApi.get_repo_discussions` returns a [generator](https://docs.python.org/3.7/howto/functional.html#generators) that yields [`Discussion`] objects. 
To get all the Discussions in a single list, run: ```python >>> from huggingface_hub import get_repo_discussions >>> discussions_list = list(get_repo_discussions(repo_id="bert-base-uncased")) ``` The [`Discussion`] object returned by [`HfApi.get_repo_discussions`] contains high-level overview of the Discussion or Pull Request. You can also get more detailed information using [`HfApi.get_discussion_details`]: ```python >>> from huggingface_hub import get_discussion_details >>> get_discussion_details( ... repo_id="bigscience/bloom-1b3", ... discussion_num=2 ... ) DiscussionWithDetails( num=2, author='cakiki', title='Update VRAM memory for the V100s', status='open', is_pull_request=True, events=[ DiscussionComment(type='comment', author='cakiki', ...), DiscussionCommit(type='commit', author='cakiki', summary='Update VRAM memory for the V100s', oid='1256f9d9a33fa8887e1c1bf0e09b4713da96773a', ...), ], conflicting_files=[], target_branch='refs/heads/main', merge_commit_oid=None, diff='diff --git a/README.md b/README.md\nindex a6ae3b9294edf8d0eda0d67c7780a10241242a7e..3a1814f212bc3f0d3cc8f74bdbd316de4ae7b9e3 100644\n--- a/README.md\n+++ b/README.md\n@@ -132,7 +132,7 [...]', ) ``` [`HfApi.get_discussion_details`] returns a [`DiscussionWithDetails`] object, which is a subclass of [`Discussion`] with more detailed information about the Discussion or Pull Request. Information includes all the comments, status changes, and renames of the Discussion via [`DiscussionWithDetails.events`]. In case of a Pull Request, you can retrieve the raw git diff with [`DiscussionWithDetails.diff`]. All the commits of the Pull Requests are listed in [`DiscussionWithDetails.events`]. ## Create and edit a Discussion or Pull Request programmatically The [`HfApi`] class also offers ways to create and edit Discussions and Pull Requests. You will need an [access token](https://huggingface.co/docs/hub/security-tokens) to create and edit Discussions or Pull Requests. The simplest way to propose changes on a repo on the Hub is via the [`create_commit`] API: just set the `create_pr` parameter to `True`. This parameter is also available on other methods that wrap [`create_commit`]: * [`upload_file`] * [`upload_folder`] * [`delete_file`] * [`delete_folder`] * [`metadata_update`] ```python >>> from huggingface_hub import metadata_update >>> metadata_update( ... repo_id="username/repo_name", ... metadata={"tags": ["computer-vision", "awesome-model"]}, ... create_pr=True, ... ) ``` You can also use [`HfApi.create_discussion`] (respectively [`HfApi.create_pull_request`]) to create a Discussion (respectively a Pull Request) on a repo. Opening a Pull Request this way can be useful if you need to work on changes locally. Pull Requests opened this way will be in `"draft"` mode. ```python >>> from huggingface_hub import create_discussion, create_pull_request >>> create_discussion( ... repo_id="username/repo-name", ... title="Hi from the huggingface_hub library!", ... token="", ... ) DiscussionWithDetails(...) >>> create_pull_request( ... repo_id="username/repo-name", ... title="Hi from the huggingface_hub library!", ... token="", ... ) DiscussionWithDetails(..., is_pull_request=True) ``` Managing Pull Requests and Discussions can be done entirely with the [`HfApi`] class. 
For example: * [`comment_discussion`] to add comments * [`edit_discussion_comment`] to edit comments * [`rename_discussion`] to rename a Discussion or Pull Request * [`change_discussion_status`] to open or close a Discussion / Pull Request * [`merge_pull_request`] to merge a Pull Request Visit the [`HfApi`] documentation page for an exhaustive reference of all available methods. ## Push changes to a Pull Request *Coming soon !* ## See also For a more detailed reference, visit the [Discussions and Pull Requests](../package_reference/community) and the [hf_api](../package_reference/hf_api) documentation page. huggingface_hub-0.31.1/docs/source/en/guides/download.md000066400000000000000000000257051500667546600231740ustar00rootroot00000000000000 # Download files from the Hub The `huggingface_hub` library provides functions to download files from the repositories stored on the Hub. You can use these functions independently or integrate them into your own library, making it more convenient for your users to interact with the Hub. This guide will show you how to: * Download and cache a single file. * Download and cache an entire repository. * Download files to a local folder. ## Download a single file The [`hf_hub_download`] function is the main function for downloading files from the Hub. It downloads the remote file, caches it on disk (in a version-aware way), and returns its local file path. The returned filepath is a pointer to the HF local cache. Therefore, it is important to not modify the file to avoid having a corrupted cache. If you are interested in getting to know more about how files are cached, please refer to our [caching guide](./manage-cache). ### From latest version Select the file to download using the `repo_id`, `repo_type` and `filename` parameters. By default, the file will be considered as being part of a `model` repo. ```python >>> from huggingface_hub import hf_hub_download >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json") '/root/.cache/huggingface/hub/models--lysandre--arxiv-nlp/snapshots/894a9adde21d9a3e3843e6d5aeaaf01875c7fade/config.json' # Download from a dataset >>> hf_hub_download(repo_id="google/fleurs", filename="fleurs.py", repo_type="dataset") '/root/.cache/huggingface/hub/datasets--google--fleurs/snapshots/199e4ae37915137c555b1765c01477c216287d34/fleurs.py' ``` ### From specific version By default, the latest version from the `main` branch is downloaded. However, in some cases you want to download a file at a particular version (e.g. from a specific branch, a PR, a tag or a commit hash). To do so, use the `revision` parameter: ```python # Download from the `v1.0` tag >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="v1.0") # Download from the `test-branch` branch >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="test-branch") # Download from Pull Request #3 >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="refs/pr/3") # Download from a specific commit hash >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="877b84a8f93f2d619faa2a6e514a32beef88ab0a") ``` **Note:** When using the commit hash, it must be the full-length hash instead of a 7-character commit hash. ### Construct a download URL In case you want to construct the URL used to download a file from a repo, you can use [`hf_hub_url`] which returns a URL. Note that it is used internally by [`hf_hub_download`]. 
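For example, a minimal sketch of building such a URL (the printed URLs illustrate the `.../resolve/...` format used by the Hub):

```python
>>> from huggingface_hub import hf_hub_url

>>> # URL pointing to the latest version of the file on the main branch
>>> hf_hub_url(repo_id="lysandre/arxiv-nlp", filename="config.json")
'https://huggingface.co/lysandre/arxiv-nlp/resolve/main/config.json'

>>> # A specific revision (branch, tag or commit hash) can also be requested
>>> hf_hub_url(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="v1.0")
'https://huggingface.co/lysandre/arxiv-nlp/resolve/v1.0/config.json'
```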
## Download an entire repository [`snapshot_download`] downloads an entire repository at a given revision. It uses internally [`hf_hub_download`] which means all downloaded files are also cached on your local disk. Downloads are made concurrently to speed-up the process. To download a whole repository, just pass the `repo_id` and `repo_type`: ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="lysandre/arxiv-nlp") '/home/lysandre/.cache/huggingface/hub/models--lysandre--arxiv-nlp/snapshots/894a9adde21d9a3e3843e6d5aeaaf01875c7fade' # Or from a dataset >>> snapshot_download(repo_id="google/fleurs", repo_type="dataset") '/home/lysandre/.cache/huggingface/hub/datasets--google--fleurs/snapshots/199e4ae37915137c555b1765c01477c216287d34' ``` [`snapshot_download`] downloads the latest revision by default. If you want a specific repository revision, use the `revision` parameter: ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="lysandre/arxiv-nlp", revision="refs/pr/1") ``` ### Filter files to download [`snapshot_download`] provides an easy way to download a repository. However, you don't always want to download the entire content of a repository. For example, you might want to prevent downloading all `.bin` files if you know you'll only use the `.safetensors` weights. You can do that using `allow_patterns` and `ignore_patterns` parameters. These parameters accept either a single pattern or a list of patterns. Patterns are Standard Wildcards (globbing patterns) as documented [here](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm). The pattern matching is based on [`fnmatch`](https://docs.python.org/3/library/fnmatch.html). For example, you can use `allow_patterns` to only download JSON configuration files: ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="lysandre/arxiv-nlp", allow_patterns="*.json") ``` On the other hand, `ignore_patterns` can exclude certain files from being downloaded. The following example ignores the `.msgpack` and `.h5` file extensions: ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="lysandre/arxiv-nlp", ignore_patterns=["*.msgpack", "*.h5"]) ``` Finally, you can combine both to precisely filter your download. Here is an example to download all json and markdown files except `vocab.json`. ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="gpt2", allow_patterns=["*.md", "*.json"], ignore_patterns="vocab.json") ``` ## Download file(s) to a local folder By default, we recommend using the [cache system](./manage-cache) to download files from the Hub. You can specify a custom cache location using the `cache_dir` parameter in [`hf_hub_download`] and [`snapshot_download`], or by setting the [`HF_HOME`](../package_reference/environment_variables#hf_home) environment variable. However, if you need to download files to a specific folder, you can pass a `local_dir` parameter to the download function. This is useful to get a workflow closer to what the `git` command offers. The downloaded files will maintain their original file structure within the specified folder. For example, if `filename="data/train.csv"` and `local_dir="path/to/folder"`, the resulting filepath will be `"path/to/folder/data/train.csv"`. A `.cache/huggingface/` folder is created at the root of your local directory containing metadata about the downloaded files. 
This prevents re-downloading files if they're already up-to-date. If the metadata has changed, then the new file version is downloaded. This makes the `local_dir` optimized for pulling only the latest changes. After completing the download, you can safely remove the `.cache/huggingface/` folder if you no longer need it. However, be aware that re-running your script without this folder may result in longer recovery times, as metadata will be lost. Rest assured that your local data will remain intact and unaffected. Don't worry about the `.cache/huggingface/` folder when committing changes to the Hub! This folder is automatically ignored by both `git` and [`upload_folder`]. ## Download from the CLI You can use the `huggingface-cli download` command from the terminal to directly download files from the Hub. Internally, it uses the same [`hf_hub_download`] and [`snapshot_download`] helpers described above and prints the returned path to the terminal. ```bash >>> huggingface-cli download gpt2 config.json /home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json ``` You can download multiple files at once which displays a progress bar and returns the snapshot path in which the files are located: ```bash >>> huggingface-cli download gpt2 config.json model.safetensors Fetching 2 files: 100%|████████████████████████████████████████████| 2/2 [00:00<00:00, 23831.27it/s] /home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10 ``` For more details about the CLI download command, please refer to the [CLI guide](./cli#huggingface-cli-download). ## Faster downloads There are two options to speed up downloads. Both involve installing a Python package written in Rust. * `hf_xet` is newer and uses the Xet storage backend for upload/download. It is available in production, but is in the process of being rolled out to all users, so join the [waitlist](https://huggingface.co/join/xet) to get onboarded soon! * `hf_transfer` is a power-tool to download and upload to our LFS storage backend (note: this is less future-proof than Xet). It is thoroughly tested and has been in production for a long time, but it has some limitations. ### hf_xet Take advantage of faster downloads through `hf_xet`, the Python binding to the [`xet-core`](https://github.com/huggingface/xet-core) library that enables chunk-based deduplication for faster downloads and uploads. `hf_xet` integrates seamlessly with `huggingface_hub`, but uses the Rust `xet-core` library and Xet storage instead of LFS. `hf_xet` uses the Xet storage system, which breaks files down into immutable chunks, storing collections of these chunks (called blocks or xorbs) remotely and retrieving them to reassemble the file when requested. When downloading, after confirming the user is authorized to access the files, `hf_xet` will query the Xet content-addressable service (CAS) with the LFS SHA256 hash for this file to receive the reconstruction metadata (ranges within xorbs) to assemble these files, along with presigned URLs to download the xorbs directly. Then `hf_xet` will efficiently download the xorb ranges necessary and will write out the files on disk. `hf_xet` uses a local disk cache to only download chunks once, learn more in the [Chunk-based caching(Xet)](./manage-cache#chunk-based-caching-xet) section. 
To enable it, specify the `hf_xet` package when installing `huggingface_hub`: ```bash pip install -U "huggingface_hub[hf_xet]" ``` Note: `hf_xet` will only be utilized when the files being downloaded are being stored with Xet Storage. All other `huggingface_hub` APIs will continue to work without any modification. To learn more about the benefits of Xet storage and `hf_xet`, refer to this [section](https://huggingface.co/docs/hub/storage-backends). ### hf_transfer If you are running on a machine with high bandwidth, you can increase your download speed with [`hf_transfer`](https://github.com/huggingface/hf_transfer), a Rust-based library developed to speed up file transfers with the Hub. To enable it: 1. Specify the `hf_transfer` extra when installing `huggingface_hub` (e.g. `pip install huggingface_hub[hf_transfer]`). 2. Set `HF_HUB_ENABLE_HF_TRANSFER=1` as an environment variable. `hf_transfer` is a power user tool! It is tested and production-ready, but it lacks user-friendly features like advanced error handling or proxies. For more details, please take a look at this [section](https://huggingface.co/docs/huggingface_hub/hf_transfer). huggingface_hub-0.31.1/docs/source/en/guides/hf_file_system.md000066400000000000000000000123751500667546600243640ustar00rootroot00000000000000 # Interact with the Hub through the Filesystem API In addition to the [`HfApi`], the `huggingface_hub` library provides [`HfFileSystem`], a pythonic [fsspec-compatible](https://filesystem-spec.readthedocs.io/en/latest/) file interface to the Hugging Face Hub. The [`HfFileSystem`] builds on top of the [`HfApi`] and offers typical filesystem style operations like `cp`, `mv`, `ls`, `du`, `glob`, `get_file`, and `put_file`. [`HfFileSystem`] provides fsspec compatibility, which is useful for libraries that require it (e.g., reading Hugging Face datasets directly with `pandas`). However, it introduces additional overhead due to this compatibility layer. For better performance and reliability, it's recommended to use [`HfApi`] methods when possible. ## Usage ```python >>> from huggingface_hub import HfFileSystem >>> fs = HfFileSystem() >>> # List all files in a directory >>> fs.ls("datasets/my-username/my-dataset-repo/data", detail=False) ['datasets/my-username/my-dataset-repo/data/train.csv', 'datasets/my-username/my-dataset-repo/data/test.csv'] >>> # List all ".csv" files in a repo >>> fs.glob("datasets/my-username/my-dataset-repo/**/*.csv") ['datasets/my-username/my-dataset-repo/data/train.csv', 'datasets/my-username/my-dataset-repo/data/test.csv'] >>> # Read a remote file >>> with fs.open("datasets/my-username/my-dataset-repo/data/train.csv", "r") as f: ... train_data = f.readlines() >>> # Read the content of a remote file as a string >>> train_data = fs.read_text("datasets/my-username/my-dataset-repo/data/train.csv", revision="dev") >>> # Write a remote file >>> with fs.open("datasets/my-username/my-dataset-repo/data/validation.csv", "w") as f: ... f.write("text,label") ... f.write("Fantastic movie!,good") ``` The optional `revision` argument can be passed to run an operation from a specific commit such as a branch, tag name, or a commit hash. Unlike Python's built-in `open`, `fsspec`'s `open` defaults to binary mode, `"rb"`. This means you must explicitly set mode as `"r"` for reading and `"w"` for writing in text mode. Appending to a file (modes `"a"` and `"ab"`) is not supported yet. 
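The other filesystem-style operations listed at the top of this guide (`cp`, `mv`, `du`, `get_file`, `put_file`, ...) behave like their `fsspec` counterparts. Here is a minimal sketch, assuming the same illustrative dataset repo as above and write access to it (the sizes shown are made up):

```python
>>> from huggingface_hub import HfFileSystem
>>> fs = HfFileSystem()

>>> # Check that a remote file exists and get the total size of a folder (in bytes)
>>> fs.exists("datasets/my-username/my-dataset-repo/data/train.csv")
True
>>> fs.du("datasets/my-username/my-dataset-repo/data")
42300

>>> # Copy a remote file to the local disk, and upload a local file to the repo
>>> fs.get_file("datasets/my-username/my-dataset-repo/data/train.csv", "train.csv")
>>> fs.put_file("extra.csv", "datasets/my-username/my-dataset-repo/data/extra.csv")
```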
## Integrations

The [`HfFileSystem`] can be used with any library that integrates `fsspec`, provided the URL follows the scheme:

```
hf://[<repo_type_prefix>]<repo_id>[@<revision>]/<path/in/repo>
```
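For example, a couple of paths following this scheme (the repo names are illustrative):

```
hf://datasets/my-username/my-dataset-repo/data/train.csv    # file in a dataset repo, default revision
hf://my-username/my-model-repo@dev/README.md                # file in a model repo, on the "dev" branch
```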
The `repo_type_prefix` is `datasets/` for datasets, `spaces/` for spaces, and models don't need a prefix in the URL. Some interesting integrations where [`HfFileSystem`] simplifies interacting with the Hub are listed below: * Reading/writing a [Pandas](https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#reading-writing-remote-files) DataFrame from/to a Hub repository: ```python >>> import pandas as pd >>> # Read a remote CSV file into a dataframe >>> df = pd.read_csv("hf://datasets/my-username/my-dataset-repo/train.csv") >>> # Write a dataframe to a remote CSV file >>> df.to_csv("hf://datasets/my-username/my-dataset-repo/test.csv") ``` The same workflow can also be used for [Dask](https://docs.dask.org/en/stable/how-to/connect-to-remote-data.html) and [Polars](https://pola-rs.github.io/polars/py-polars/html/reference/io.html) DataFrames. * Querying (remote) Hub files with [DuckDB](https://duckdb.org/docs/guides/python/filesystems): ```python >>> from huggingface_hub import HfFileSystem >>> import duckdb >>> fs = HfFileSystem() >>> duckdb.register_filesystem(fs) >>> # Query a remote file and get the result back as a dataframe >>> fs_query_file = "hf://datasets/my-username/my-dataset-repo/data_dir/data.parquet" >>> df = duckdb.query(f"SELECT * FROM '{fs_query_file}' LIMIT 10").df() ``` * Using the Hub as an array store with [Zarr](https://zarr.readthedocs.io/en/stable/tutorial.html#io-with-fsspec): ```python >>> import numpy as np >>> import zarr >>> embeddings = np.random.randn(50000, 1000).astype("float32") >>> # Write an array to a repo >>> with zarr.open_group("hf://my-username/my-model-repo/array-store", mode="w") as root: ... foo = root.create_group("embeddings") ... foobar = foo.zeros('experiment_0', shape=(50000, 1000), chunks=(10000, 1000), dtype='f4') ... foobar[:] = embeddings >>> # Read an array from a repo >>> with zarr.open_group("hf://my-username/my-model-repo/array-store", mode="r") as root: ... first_row = root["embeddings/experiment_0"][0] ``` ## Authentication In many cases, you must be logged in with a Hugging Face account to interact with the Hub. Refer to the [Authentication](../quick-start#authentication) section of the documentation to learn more about authentication methods on the Hub. It is also possible to log in programmatically by passing your `token` as an argument to [`HfFileSystem`]: ```python >>> from huggingface_hub import HfFileSystem >>> fs = HfFileSystem(token=token) ``` If you log in this way, be careful not to accidentally leak the token when sharing your source code! huggingface_hub-0.31.1/docs/source/en/guides/inference.md000066400000000000000000000574371500667546600233320ustar00rootroot00000000000000 # Run Inference on servers Inference is the process of using a trained model to make predictions on new data. Because this process can be compute-intensive, running on a dedicated or external service can be an interesting option. The `huggingface_hub` library provides a unified interface to run inference across multiple services for models hosted on the Hugging Face Hub: 1. [HF Inference API](https://huggingface.co/docs/api-inference/index): a serverless solution that allows you to run model inference on Hugging Face's infrastructure for free. This service is a fast way to get started, test different models, and prototype AI products. 2. [Third-party providers](#supported-providers-and-tasks): various serverless solution provided by external providers (Together, Sambanova, etc.). 
These providers offer production-ready APIs on a pay-as-you-go model. This is the fastest way to integrate AI in your products with a maintenance-free and scalable solution. Refer to the [Supported providers and tasks](#supported-providers-and-tasks) section for a list of supported providers. 3. [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index): a product to easily deploy models to production. Inference is run by Hugging Face in a dedicated, fully managed infrastructure on a cloud provider of your choice. These services can all be called from the [`InferenceClient`] object. It acts as a replacement for the legacy [`InferenceApi`] client, adding specific support for tasks and third-party providers. Learn how to migrate to the new client in the [Legacy InferenceAPI client](#legacy-inferenceapi-client) section. [`InferenceClient`] is a Python client making HTTP calls to our APIs. If you want to make the HTTP calls directly using your preferred tool (curl, postman,...), please refer to the [Inference API](https://huggingface.co/docs/api-inference/index) or to the [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index) documentation pages. For web development, a [JS client](https://huggingface.co/docs/huggingface.js/inference/README) has been released. If you are interested in game development, you might have a look at our [C# project](https://github.com/huggingface/unity-api). ## Getting started Let's get started with a text-to-image task: ```python >>> from huggingface_hub import InferenceClient # Example with an external provider (e.g. replicate) >>> replicate_client = InferenceClient( provider="replicate", api_key="my_replicate_api_key", ) >>> replicate_image = replicate_client.text_to_image( "A flying car crossing a futuristic cityscape.", model="black-forest-labs/FLUX.1-schnell", ) >>> replicate_image.save("flying_car.png") ``` In the example above, we initialized an [`InferenceClient`] with a third-party provider, [Replicate](https://replicate.com/). When using a provider, you must specify the model you want to use. The model id must be the id of the model on the Hugging Face Hub, not the id of the model from the third-party provider. In our example, we generated an image from a text prompt. The returned value is a `PIL.Image` object that can be saved to a file. For more details, check out the [`~InferenceClient.text_to_image`] documentation. Let's now see an example using the [`~InferenceClient.chat_completion`] API. 
This task uses an LLM to generate a response from a list of messages: ```python >>> from huggingface_hub import InferenceClient >>> messages = [ { "role": "user", "content": "What is the capital of France?", } ] >>> client = InferenceClient( provider="together", model="meta-llama/Meta-Llama-3-8B-Instruct", api_key="my_together_api_key", ) >>> client.chat_completion(messages, max_tokens=100) ChatCompletionOutput( choices=[ ChatCompletionOutputComplete( finish_reason="eos_token", index=0, message=ChatCompletionOutputMessage( role="assistant", content="The capital of France is Paris.", name=None, tool_calls=None ), logprobs=None, ) ], created=1719907176, id="", model="meta-llama/Meta-Llama-3-8B-Instruct", object="text_completion", system_fingerprint="2.0.4-sha-f426a33", usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25), ) ``` In the example above, we used a third-party provider ([Together AI](https://www.together.ai/)) and specified which model we want to use (`"meta-llama/Meta-Llama-3-8B-Instruct"`). We then gave a list of messages to complete (here, a single question) and passed an additional parameter to the API (`max_token=100`). The output is a `ChatCompletionOutput` object that follows the OpenAI specification. The generated content can be accessed with `output.choices[0].message.content`. For more details, check out the [`~InferenceClient.chat_completion`] documentation. The API is designed to be simple. Not all parameters and options are available or described for the end user. Check out [this page](https://huggingface.co/docs/api-inference/detailed_parameters) if you are interested in learning more about all the parameters available for each task. ### Using a specific provider If you want to use a specific provider, you can specify it when initializing the client. The default value is "auto" which will select the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. Refer to the [Supported providers and tasks](#supported-providers-and-tasks) section for a list of supported providers. ```python >>> from huggingface_hub import InferenceClient >>> client = InferenceClient(provider="replicate", api_key="my_replicate_api_key") ``` ### Using a specific model What if you want to use a specific model? You can specify it either as a parameter or directly at an instance level: ```python >>> from huggingface_hub import InferenceClient # Initialize client for a specific model >>> client = InferenceClient(provider="together", model="meta-llama/Llama-3.1-8B-Instruct") >>> client.text_to_image(...) # Or use a generic client but pass your model as an argument >>> client = InferenceClient(provider="together") >>> client.text_to_image(..., model="meta-llama/Llama-3.1-8B-Instruct") ``` When using the "hf-inference" provider, each task comes with a recommended model from the 1M+ models available on the Hub. However, this recommendation can change over time, so it's best to explicitly set a model once you've decided which one to use. For third-party providers, you must always specify a model that is compatible with that provider. Visit the [Models](https://huggingface.co/models?inference=warm) page on the Hub to explore models available through Inference Providers. ### Using a specific URL The examples we saw above use either the Hugging Face Inference API or third-party providers. While these prove to be very useful for prototyping and testing things quickly. 
Once you're ready to deploy your model to production, you'll need to use a dedicated infrastructure. That's where [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index) comes into play. It allows you to deploy any model and expose it as a private API. Once deployed, you'll get a URL that you can connect to using exactly the same code as before, changing only the `model` parameter: ```python >>> from huggingface_hub import InferenceClient >>> client = InferenceClient(model="https://uu149rez6gw9ehej.eu-west-1.aws.endpoints.huggingface.cloud/deepfloyd-if") # or >>> client = InferenceClient() >>> client.text_to_image(..., model="https://uu149rez6gw9ehej.eu-west-1.aws.endpoints.huggingface.cloud/deepfloyd-if") ``` Note that you cannot specify both a URL and a provider - they are mutually exclusive. URLs are used to connect directly to deployed endpoints. ### Authentication Authentication can be done in two ways: **Routed through Hugging Face** : Use Hugging Face as a proxy to access third-party providers. The calls will be routed through Hugging Face's infrastructure using our provider keys, and the usage will be billed directly to your Hugging Face account. You can authenticate using a [User Access Token](https://huggingface.co/docs/hub/security-tokens). You can provide your Hugging Face token directly using the `api_key` parameter: ```python >>> client = InferenceClient( provider="replicate", api_key="hf_****" # Your HF token ) ``` If you *don't* pass an `api_key`, the client will attempt to find and use a token stored locally on your machine. This typically happens if you've previously logged in. See the [Authentication Guide](https://huggingface.co/docs/huggingface_hub/quick-start#authentication) for details on login. ```python >>> client = InferenceClient( provider="replicate", token="hf_****" # Your HF token ) ``` **Direct access to provider**: Use your own API key to interact directly with the provider's service: ```python >>> client = InferenceClient( provider="replicate", api_key="r8_****" # Your Replicate API key ) ``` For more details, refer to the [Inference Providers pricing documentation](https://huggingface.co/docs/inference-providers/pricing#routed-requests-vs-direct-calls). ## Supported providers and tasks [`InferenceClient`]'s goal is to provide the easiest interface to run inference on Hugging Face models, on any provider. It has a simple API that supports the most common tasks. 
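As a short sketch, here is what calling one of these task methods looks like; the model id below is only an illustration, and with no provider specified the default `"auto"` selection described above is used:

```python
>>> from huggingface_hub import InferenceClient

>>> client = InferenceClient()
>>> result = client.text_classification(
...     "I love using the huggingface_hub library!",
...     model="distilbert-base-uncased-finetuned-sst-2-english",
... )
>>> result[0].label, result[0].score  # e.g. ('POSITIVE', 0.99...)
```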
Here is a table showing which providers support which tasks: | Task | Black Forest Labs | Cerebras | Cohere | fal-ai | Fireworks AI | HF Inference | Hyperbolic | Nebius AI Studio | Novita AI | Replicate | Sambanova | Together | | --------------------------------------------------- | ----------------- | -------- | ------ | ------ | ------------ | ------------ | ---------- | ---------------- | --------- | --------- | --------- | -------- | | [`~InferenceClient.audio_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.audio_to_audio`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.automatic_speech_recognition`] | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.chat_completion`] | ❌ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | [`~InferenceClient.document_question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.feature_extraction`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | [`~InferenceClient.fill_mask`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.image_segmentation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.image_to_text`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.object_detection`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.sentence_similarity`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.summarization`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.table_question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.text_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.text_generation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | [`~InferenceClient.text_to_image`] | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | | [`~InferenceClient.text_to_speech`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | | [`~InferenceClient.text_to_video`] | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | | [`~InferenceClient.tabular_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.tabular_regression`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.token_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.translation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.visual_question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.zero_shot_image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.zero_shot_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | Check out the [Tasks](https://huggingface.co/tasks) page to learn more about each task. ## OpenAI compatibility The `chat_completion` task follows [OpenAI's Python client](https://github.com/openai/openai-python) syntax. What does it mean for you? It means that if you are used to play with `OpenAI`'s APIs you will be able to switch to `huggingface_hub.InferenceClient` to work with open-source models by updating just 2 line of code! 
```diff
- from openai import OpenAI
+ from huggingface_hub import InferenceClient

- client = OpenAI(
+ client = InferenceClient(
    base_url=...,
    api_key=...,
)

output = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Count to 10"},
    ],
    stream=True,
    max_tokens=1024,
)

for chunk in output:
    print(chunk.choices[0].delta.content)
```

And that's it! The only required changes are to replace `from openai import OpenAI` with `from huggingface_hub import InferenceClient` and `client = OpenAI(...)` with `client = InferenceClient(...)`. You can choose any LLM model from the Hugging Face Hub by passing its model id as the `model` parameter. [Here is a list](https://huggingface.co/models?pipeline_tag=text-generation&other=conversational,text-generation-inference&sort=trending) of supported models. For authentication, you should pass a valid [User Access Token](https://huggingface.co/settings/tokens) as `api_key` or authenticate using `huggingface_hub` (see the [authentication guide](https://huggingface.co/docs/huggingface_hub/quick-start#authentication)).

All input parameters and output format are strictly the same. In particular, you can pass `stream=True` to receive tokens as they are generated. You can also use the [`AsyncInferenceClient`] to run inference using `asyncio`:

```diff
import asyncio
- from openai import AsyncOpenAI
+ from huggingface_hub import AsyncInferenceClient

- client = AsyncOpenAI()
+ client = AsyncInferenceClient()

async def main():
    stream = await client.chat.completions.create(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        messages=[{"role": "user", "content": "Say this is a test"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())
```

You might wonder why you should use [`InferenceClient`] instead of OpenAI's client. There are a few reasons:

1. [`InferenceClient`] is configured for Hugging Face services. You don't need to provide a `base_url` to run models on the serverless Inference API. You also don't need to provide a `token` or `api_key` if your machine is already correctly logged in.
2. [`InferenceClient`] is tailored for both Text-Generation-Inference (TGI) and `transformers` frameworks, meaning you are assured it will always be on par with the latest updates.
3. [`InferenceClient`] is integrated with our Inference Endpoints service, making it easier to launch an Inference Endpoint, check its status and run inference on it. Check out the [Inference Endpoints](./inference_endpoints.md) guide for more details.

`InferenceClient.chat.completions.create` is simply an alias for `InferenceClient.chat_completion`. Check out the package reference of [`~InferenceClient.chat_completion`] for more details. The `base_url` and `api_key` parameters used when instantiating the client are aliases for `model` and `token`, respectively. These aliases have been defined to reduce friction when switching from `OpenAI` to `InferenceClient`.

## Async client

An async version of the client is also provided, based on `asyncio` and `aiohttp`. You can either install `aiohttp` directly or use the `[inference]` extra:

```sh
pip install --upgrade huggingface_hub[inference]
# or
# pip install aiohttp
```

After installation, all async API endpoints are available via [`AsyncInferenceClient`]. Its initialization and APIs are strictly the same as the sync-only version.

```py
# Code must be run in an asyncio concurrent context.
# $ python -m asyncio
>>> from huggingface_hub import AsyncInferenceClient
>>> client = AsyncInferenceClient()

>>> image = await client.text_to_image("An astronaut riding a horse on the moon.")
>>> image.save("astronaut.png")

>>> async for token in await client.text_generation("The Huggingface Hub is", stream=True):
...     print(token, end="")
a platform for sharing and discussing ML-related content.
```

For more information about the `asyncio` module, please refer to the [official documentation](https://docs.python.org/3/library/asyncio.html).

## Advanced tips

In the above section, we saw the main aspects of [`InferenceClient`]. Let's dive into some more advanced tips.

### Billing

As an HF user, you get monthly credits to run inference through various providers on the Hub. The amount of credits you get depends on your type of account (Free, PRO or Enterprise Hub). You get charged for every inference request, depending on the provider's pricing table. By default, the requests are billed to your personal account. However, it is possible to set the billing so that requests are charged to an organization you are part of by simply passing `bill_to="<your_org_name>"` to `InferenceClient`. For this to work, your organization must be subscribed to Enterprise Hub. For more details about billing, check out [this guide](https://huggingface.co/docs/api-inference/pricing#features-using-inference-providers).

```py
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient(provider="fal-ai", bill_to="openai")
>>> image = client.text_to_image(
...     "A majestic lion in a fantasy forest",
...     model="black-forest-labs/FLUX.1-schnell",
... )
>>> image.save("lion.png")
```

Note that it is NOT possible to charge another user or an organization you are not part of. If you want to grant someone else some credits, you must create a joint organization with them.

### Timeout

Inference calls can take a significant amount of time. By default, [`InferenceClient`] will wait "indefinitely" until the inference completes. If you want more control in your workflow, you can set the `timeout` parameter to a specific value in seconds. If the timeout delay expires, an [`InferenceTimeoutError`] is raised, which you can catch in your code:

```python
>>> from huggingface_hub import InferenceClient, InferenceTimeoutError
>>> client = InferenceClient(timeout=30)
>>> try:
...     client.text_to_image(...)
... except InferenceTimeoutError:
...     print("Inference timed out after 30s.")
```

### Binary inputs

Some tasks require binary inputs, for example, when dealing with images or audio files. In this case, [`InferenceClient`] tries to be as permissive as possible and accepts different types:

- raw `bytes`
- a file-like object, opened as binary (`with open("audio.flac", "rb") as f: ...`)
- a path (`str` or `Path`) pointing to a local file
- a URL (`str`) pointing to a remote file (e.g. `https://...`). In this case, the file will be downloaded locally before sending it to the Inference API.

```py
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()
>>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
[{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...]
```
huggingface_hub-0.31.1/docs/source/en/guides/inference_endpoints.md000066400000000000000000000334231500667546600254020ustar00rootroot00000000000000# Inference Endpoints

Inference Endpoints provides a secure production solution to easily deploy any `transformers`, `sentence-transformers`, and `diffusers` models on a dedicated and autoscaling infrastructure managed by Hugging Face. An Inference Endpoint is built from a model from the [Hub](https://huggingface.co/models). In this guide, we will learn how to programmatically manage Inference Endpoints with `huggingface_hub`. For more information about the Inference Endpoints product itself, check out its [official documentation](https://huggingface.co/docs/inference-endpoints/index).

This guide assumes `huggingface_hub` is correctly installed and that your machine is logged in. Check out the [Quick Start guide](https://huggingface.co/docs/huggingface_hub/quick-start#quickstart) if that's not the case yet. The minimal version supporting the Inference Endpoints API is `v0.19.0`.

**New:** it is now possible to deploy an Inference Endpoint from the [HF model catalog](https://endpoints.huggingface.co/catalog) with a simple API call. The catalog is a carefully curated list of models that can be deployed with optimized settings. You don't need to configure anything, we take care of all the heavy lifting for you! All models and settings are guaranteed to have been tested to provide the best cost/performance balance. [`create_inference_endpoint_from_catalog`] works the same as [`create_inference_endpoint`], with far fewer parameters to pass. You can use [`list_inference_catalog`] to programmatically retrieve the catalog. Note that this is still an experimental feature. Let us know what you think if you use it!

## Create an Inference Endpoint

The first step is to create an Inference Endpoint using [`create_inference_endpoint`]:

```py
>>> from huggingface_hub import create_inference_endpoint

>>> endpoint = create_inference_endpoint(
...     "my-endpoint-name",
...     repository="gpt2",
...     framework="pytorch",
...     task="text-generation",
...     accelerator="cpu",
...     vendor="aws",
...     region="us-east-1",
...     type="protected",
...     instance_size="x2",
...     instance_type="intel-icl"
... )
```

In this example, we created a `protected` Inference Endpoint named `"my-endpoint-name"`, to serve [gpt2](https://huggingface.co/gpt2) for `text-generation`. A `protected` Inference Endpoint means your token is required to access the API. We also need to provide additional information to configure the hardware requirements, such as vendor, region, accelerator, instance type, and size. You can check out the list of available resources [here](https://api.endpoints.huggingface.cloud/#/v2%3A%3Aprovider/list_vendors). Alternatively, you can create an Inference Endpoint manually using the [Web interface](https://ui.endpoints.huggingface.co/new) for convenience. Refer to this [guide](https://huggingface.co/docs/inference-endpoints/guides/advanced) for details on advanced settings and their usage.

The value returned by [`create_inference_endpoint`] is an [`InferenceEndpoint`] object:

```py
>>> endpoint
InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='pending', url=None)
```

It's a dataclass that holds information about the endpoint. You can access important attributes such as `name`, `repository`, `status`, `task`, `created_at`, `updated_at`, etc. If you need it, you can also access the raw response from the server with `endpoint.raw`.
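For example, once you have the `endpoint` object, you can inspect it directly in a Python session (a minimal sketch; the values shown are illustrative and correspond to the attributes listed above):

```py
>>> endpoint.name
'my-endpoint-name'
>>> endpoint.repository
'gpt2'
>>> endpoint.task
'text-generation'
>>> endpoint.raw  # raw server response, useful if you need a field not exposed as an attribute
```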
Once your Inference Endpoint is created, you can find it on your [personal dashboard](https://ui.endpoints.huggingface.co/).

![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/huggingface_hub/inference_endpoints_created.png)

#### Using a custom image

By default, the Inference Endpoint is built from a docker image provided by Hugging Face. However, it is possible to specify any docker image using the `custom_image` parameter. A common use case is to run LLMs using the [text-generation-inference](https://github.com/huggingface/text-generation-inference) framework. This can be done like this:

```python
# Start an Inference Endpoint running Zephyr-7b-beta on TGI
>>> from huggingface_hub import create_inference_endpoint

>>> endpoint = create_inference_endpoint(
...     "aws-zephyr-7b-beta-0486",
...     repository="HuggingFaceH4/zephyr-7b-beta",
...     framework="pytorch",
...     task="text-generation",
...     accelerator="gpu",
...     vendor="aws",
...     region="us-east-1",
...     type="protected",
...     instance_size="x1",
...     instance_type="nvidia-a10g",
...     custom_image={
...         "health_route": "/health",
...         "env": {
...             "MAX_BATCH_PREFILL_TOKENS": "2048",
...             "MAX_INPUT_LENGTH": "1024",
...             "MAX_TOTAL_TOKENS": "1512",
...             "MODEL_ID": "/repository"
...         },
...         "url": "ghcr.io/huggingface/text-generation-inference:1.1.0",
...     },
... )
```

The value to pass as `custom_image` is a dictionary containing a URL to the docker container and the configuration to run it. For more details about it, check out the [Swagger documentation](https://api.endpoints.huggingface.cloud/#/v2%3A%3Aendpoint/create_endpoint).

### Get or list existing Inference Endpoints

In some cases, you might need to manage Inference Endpoints you created previously. If you know the name, you can fetch it using [`get_inference_endpoint`], which returns an [`InferenceEndpoint`] object. Alternatively, you can use [`list_inference_endpoints`] to retrieve a list of all Inference Endpoints. Both methods accept an optional `namespace` parameter. You can set the `namespace` to any organization you are a part of. Otherwise, it defaults to your username.

```py
>>> from huggingface_hub import get_inference_endpoint, list_inference_endpoints

# Get one
>>> get_inference_endpoint("my-endpoint-name")
InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='pending', url=None)

# List all endpoints from an organization
>>> list_inference_endpoints(namespace="huggingface")
[InferenceEndpoint(name='aws-starchat-beta', namespace='huggingface', repository='HuggingFaceH4/starchat-beta', status='paused', url=None), ...]

# List all endpoints from all organizations the user belongs to
>>> list_inference_endpoints(namespace="*")
[InferenceEndpoint(name='aws-starchat-beta', namespace='huggingface', repository='HuggingFaceH4/starchat-beta', status='paused', url=None), ...]
```

## Check deployment status

In the rest of this guide, we will assume that we have an [`InferenceEndpoint`] object called `endpoint`. You might have noticed that the endpoint has a `status` attribute of type [`InferenceEndpointStatus`].
When the Inference Endpoint is deployed and accessible, the status should be `"running"` and the `url` attribute is set:

```py
>>> endpoint
InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='running', url='https://jpj7k2q4j805b727.us-east-1.aws.endpoints.huggingface.cloud')
```

Before reaching a `"running"` state, the Inference Endpoint typically goes through an `"initializing"` or `"pending"` phase. You can fetch the new state of the endpoint by running [`~InferenceEndpoint.fetch`]. Like every other method from [`InferenceEndpoint`] that makes a request to the server, the internal attributes of `endpoint` are mutated in place:

```py
>>> endpoint.fetch()
InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='pending', url=None)
```

Instead of fetching the Inference Endpoint status while waiting for it to run, you can directly call [`~InferenceEndpoint.wait`]. This helper takes as input a `timeout` and a `fetch_every` parameter (in seconds) and will block the thread until the Inference Endpoint is deployed. Default values are respectively `None` (no timeout) and `5` seconds.

```py
# Pending endpoint
>>> endpoint
InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='pending', url=None)

# Wait 10s => raises an InferenceEndpointTimeoutError
>>> endpoint.wait(timeout=10)
    raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.")
huggingface_hub._inference_endpoints.InferenceEndpointTimeoutError: Timeout while waiting for Inference Endpoint to be deployed.

# Wait more
>>> endpoint.wait()
InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='running', url='https://jpj7k2q4j805b727.us-east-1.aws.endpoints.huggingface.cloud')
```

If `timeout` is set and the Inference Endpoint takes too long to load, an [`InferenceEndpointTimeoutError`] is raised.

## Run inference

Once your Inference Endpoint is up and running, you can finally run inference on it!

[`InferenceEndpoint`] has two properties, `client` and `async_client`, returning respectively an [`InferenceClient`] and an [`AsyncInferenceClient`] object.

```py
# Run text_generation task:
>>> endpoint.client.text_generation("I am")
' not a fan of the idea of a "big-budget" movie. I think it\'s a'

# Or in an asyncio context:
>>> await endpoint.async_client.text_generation("I am")
```

If the Inference Endpoint is not running, an [`InferenceEndpointError`] exception is raised:

```py
>>> endpoint.client
huggingface_hub._inference_endpoints.InferenceEndpointError: Cannot create a client for this Inference Endpoint as it is not yet deployed. Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again.
```

For more details about how to use the [`InferenceClient`], check out the [Inference guide](../guides/inference).

## Manage lifecycle

Now that we saw how to create an Inference Endpoint and run inference on it, let's see how to manage its lifecycle.

In this section, we will see methods like [`~InferenceEndpoint.pause`], [`~InferenceEndpoint.resume`], [`~InferenceEndpoint.scale_to_zero`], [`~InferenceEndpoint.update`] and [`~InferenceEndpoint.delete`]. All of those methods are aliases added to [`InferenceEndpoint`] for convenience.
If you prefer, you can also use the generic methods defined in `HfApi`: [`pause_inference_endpoint`], [`resume_inference_endpoint`], [`scale_to_zero_inference_endpoint`], [`update_inference_endpoint`], and [`delete_inference_endpoint`].

### Pause or scale to zero

To reduce costs when your Inference Endpoint is not in use, you can choose to either pause it using [`~InferenceEndpoint.pause`] or scale it to zero using [`~InferenceEndpoint.scale_to_zero`].

An Inference Endpoint that is *paused* or *scaled to zero* doesn't cost anything. The difference between those two is that a *paused* endpoint needs to be explicitly *resumed* using [`~InferenceEndpoint.resume`]. In contrast, a *scaled to zero* endpoint will automatically start if an inference call is made to it, with an additional cold start delay. An Inference Endpoint can also be configured to scale to zero automatically after a certain period of inactivity.

```py
# Pause and resume endpoint
>>> endpoint.pause()
InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='paused', url=None)
>>> endpoint.resume()
InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='pending', url=None)
>>> endpoint.wait().client.text_generation(...)
...

# Scale to zero
>>> endpoint.scale_to_zero()
InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='scaledToZero', url='https://jpj7k2q4j805b727.us-east-1.aws.endpoints.huggingface.cloud')
# Endpoint is not 'running' but still has a URL and will restart on first call.
```

### Update model or hardware requirements

In some cases, you might also want to update your Inference Endpoint without creating a new one. You can either update the hosted model or the hardware requirements to run the model. You can do this using [`~InferenceEndpoint.update`]:

```py
# Change target model
>>> endpoint.update(repository="gpt2-large")
InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2-large', status='pending', url=None)

# Update number of replicas
>>> endpoint.update(min_replica=2, max_replica=6)
InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2-large', status='pending', url=None)

# Update to larger instance
>>> endpoint.update(accelerator="cpu", instance_size="x4", instance_type="intel-icl")
InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2-large', status='pending', url=None)
```

### Delete the endpoint

Finally, if you won't use the Inference Endpoint anymore, you can simply call [`~InferenceEndpoint.delete()`]. This is a non-reversible action that will completely remove the endpoint, including its configuration, logs and usage metrics. You cannot restore a deleted Inference Endpoint.

## An end-to-end example

A typical use case of Inference Endpoints is to process a batch of jobs at once to limit the infrastructure costs. You can automate this process using what we saw in this guide:

```py
>>> import asyncio
>>> from huggingface_hub import create_inference_endpoint

# Start endpoint + wait until initialized
>>> endpoint = create_inference_endpoint(name="batch-endpoint",...).wait()

# Run inference
>>> client = endpoint.client
>>> results = [client.text_generation(...) for job in jobs]

# Or with asyncio
>>> async_client = endpoint.async_client
>>> results = await asyncio.gather(*[async_client.text_generation(...) for job in jobs])
# Pause endpoint
>>> endpoint.pause()
```

Or if your Inference Endpoint already exists and is paused:

```py
>>> import asyncio
>>> from huggingface_hub import get_inference_endpoint

# Get endpoint + wait until initialized
>>> endpoint = get_inference_endpoint("batch-endpoint").resume().wait()

# Run inference
>>> async_client = endpoint.async_client
>>> results = await asyncio.gather(*[async_client.text_generation(...) for job in jobs])

# Pause endpoint
>>> endpoint.pause()
```
huggingface_hub-0.31.1/docs/source/en/guides/integrations.md000066400000000000000000000511701500667546600240660ustar00rootroot00000000000000

# Integrate any ML framework with the Hub

The Hugging Face Hub makes hosting and sharing models with the community easy. It supports [dozens of libraries](https://huggingface.co/docs/hub/models-libraries) in the Open Source ecosystem. We are always working on expanding this support to push collaborative Machine Learning forward. The `huggingface_hub` library plays a key role in this process, allowing any Python script to easily push and load files.

There are four main ways to integrate a library with the Hub:

1. **Push to Hub:** implement a method to upload a model to the Hub. This includes the model weights, as well as [the model card](https://huggingface.co/docs/huggingface_hub/how-to-model-cards) and any other relevant information or data necessary to run the model (for example, training logs). This method is often called `push_to_hub()`.
2. **Download from Hub:** implement a method to load a model from the Hub. The method should download the model configuration/weights and load the model. This method is often called `from_pretrained` or `load_from_hub()`.
3. **Inference API:** use our servers to run inference on models supported by your library for free.
4. **Widgets:** display a widget on the landing page of your models on the Hub. It allows users to quickly try a model from the browser.

In this guide, we will focus on the first two topics. We will present the two main approaches you can use to integrate a library, with their advantages and drawbacks. Everything is summarized at the end of the guide to help you choose between the two. Please keep in mind that these are only guidelines that you are free to adapt to your requirements.

If you are interested in Inference and Widgets, you can follow [this guide](https://huggingface.co/docs/hub/models-adding-libraries#set-up-the-inference-api). In both cases, you can reach out to us if you are integrating a library with the Hub and want to be listed [in our docs](https://huggingface.co/docs/hub/models-libraries).

## A flexible approach: helpers

The first approach to integrate a library with the Hub is to actually implement the `push_to_hub` and `from_pretrained` methods by yourself. This gives you full flexibility on which files you need to upload/download and how to handle inputs specific to your framework. You can refer to the two [upload files](./upload) and [download files](./download) guides to learn more about how to do that. This is, for example, how the FastAI integration is implemented (see [`push_to_hub_fastai`] and [`from_pretrained_fastai`]).

Implementation can differ between libraries, but the workflow is often similar.
### from_pretrained

This is what a `from_pretrained` method usually looks like:

```python
def from_pretrained(model_id: str) -> MyModelClass:
    # Download model from Hub
    cached_model = hf_hub_download(
        repo_id=model_id,
        filename="model.pkl",
        library_name="fastai",
        library_version=get_fastai_version(),
    )

    # Load model
    return load_model(cached_model)
```

### push_to_hub

The `push_to_hub` method often requires a bit more complexity to handle repo creation, generating the model card and saving weights. A common approach is to save all of these files in a temporary folder, upload it and then delete it.

```python
def push_to_hub(model: MyModelClass, repo_name: str) -> None:
    api = HfApi()

    # Create repo if not existing yet and get the associated repo_id
    repo_id = api.create_repo(repo_name, exist_ok=True)

    # Save all files in a temporary directory and push them in a single commit
    with TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)

        # Save weights
        save_model(model, tmpdir / "model.safetensors")

        # Generate model card
        card = generate_model_card(model)
        (tmpdir / "README.md").write_text(card)

        # Save logs
        # Save figures
        # Save evaluation metrics
        # ...

        # Push to hub
        return api.upload_folder(repo_id=repo_id, folder_path=tmpdir)
```

This is of course only an example. If you are interested in more complex manipulations (delete remote files, upload weights on the fly, persist weights locally, etc.) please refer to the [upload files](./upload) guide.

### Limitations

While being flexible, this approach has some drawbacks, especially in terms of maintenance. Hugging Face users are often used to additional features when working with `huggingface_hub`. For example, when loading files from the Hub, it is common to offer parameters like:

- `token`: to download from a private repo
- `revision`: to download from a specific branch
- `cache_dir`: to cache files in a specific directory
- `force_download`/`local_files_only`: to reuse the cache or not
- `proxies`: configure HTTP session

When pushing models, similar parameters are supported:

- `commit_message`: custom commit message
- `private`: create a private repo if missing
- `create_pr`: create a PR instead of pushing to `main`
- `branch`: push to a branch instead of the `main` branch
- `allow_patterns`/`ignore_patterns`: filter which files to upload
- `token`
- ...

All of these parameters can be added to the implementations we saw above and passed to the `huggingface_hub` methods. However, if a parameter changes or a new feature is added, you will need to update your package. Supporting those parameters also means more documentation to maintain on your side. To see how to mitigate these limitations, let's jump to our next section **class inheritance**.

## A more complex approach: class inheritance

As we saw above, there are two main methods to include in your library to integrate it with the Hub: upload files (`push_to_hub`) and download files (`from_pretrained`). You can implement those methods by yourself but it comes with caveats. To tackle this, `huggingface_hub` provides a tool that uses class inheritance. Let's see how it works!

In a lot of cases, a library already implements its model using a Python class. The class contains the properties of the model and methods to load, run, train, and evaluate it. Our approach is to extend this class to include upload and download features using mixins. A [Mixin](https://stackoverflow.com/a/547714) is a class that is meant to extend an existing class with a set of specific features using multiple inheritance.
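To make the idea concrete, here is what the mixin pattern looks like in plain Python, independently of `huggingface_hub` (a minimal, generic sketch; all class and method names below are made up for illustration):

```python
import json


class SaveLoadMixin:
    """Mixin adding generic save/load methods to any class that implements `to_dict`/`from_dict`."""

    def save(self, path: str) -> None:
        with open(path, "w") as f:
            json.dump(self.to_dict(), f)

    @classmethod
    def load(cls, path: str):
        with open(path) as f:
            return cls.from_dict(json.load(f))


class MyModel(SaveLoadMixin):  # a real model class would typically inherit from other bases too
    def __init__(self, value: int = 0):
        self.value = value

    def to_dict(self) -> dict:
        return {"value": self.value}

    @classmethod
    def from_dict(cls, data: dict) -> "MyModel":
        return cls(**data)


model = MyModel(value=42)
model.save("model.json")               # capability provided by the mixin
restored = MyModel.load("model.json")  # so is this one
```

The mixin itself knows nothing about the concrete model: it only relies on the small interface (`to_dict`/`from_dict`) that the class agrees to provide, which is the same contract-based design used in the next section.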
`huggingface_hub` provides its own mixin, the [`ModelHubMixin`]. The key here is to understand its behavior and how to customize it. The [`ModelHubMixin`] class implements 3 *public* methods (`push_to_hub`, `save_pretrained` and `from_pretrained`). Those are the methods that your users will call to load/save models with your library. [`ModelHubMixin`] also defines 2 *private* methods (`_save_pretrained` and `_from_pretrained`). Those are the ones you must implement. So to integrate your library, you should: 1. Make your Model class inherit from [`ModelHubMixin`]. 2. Implement the private methods: - [`~ModelHubMixin._save_pretrained`]: method taking as input a path to a directory and saving the model to it. You must write all the logic to dump your model in this method: model card, model weights, configuration files, training logs, and figures. Any relevant information for this model must be handled by this method. [Model Cards](https://huggingface.co/docs/hub/model-cards) are particularly important to describe your model. Check out [our implementation guide](./model-cards) for more details. - [`~ModelHubMixin._from_pretrained`]: **class method** taking as input a `model_id` and returning an instantiated model. The method must download the relevant files and load them. 3. You are done! The advantage of using [`ModelHubMixin`] is that once you take care of the serialization/loading of the files, you are ready to go. You don't need to worry about stuff like repo creation, commits, PRs, or revisions. The [`ModelHubMixin`] also ensures public methods are documented and type annotated, and you'll be able to view your model's download count on the Hub. All of this is handled by the [`ModelHubMixin`] and available to your users. ### A concrete example: PyTorch A good example of what we saw above is [`PyTorchModelHubMixin`], our integration for the PyTorch framework. This is a ready-to-use integration. #### How to use it? Here is how any user can load/save a PyTorch model from/to the Hub: ```python >>> import torch >>> import torch.nn as nn >>> from huggingface_hub import PyTorchModelHubMixin # Define your Pytorch model exactly the same way you are used to >>> class MyModel( ... nn.Module, ... PyTorchModelHubMixin, # multiple inheritance ... library_name="keras-nlp", ... tags=["keras"], ... repo_url="https://github.com/keras-team/keras-nlp", ... docs_url="https://keras.io/keras_nlp/", ... # ^ optional metadata to generate model card ... ): ... def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4): ... super().__init__() ... self.param = nn.Parameter(torch.rand(hidden_size, vocab_size)) ... self.linear = nn.Linear(output_size, vocab_size) ... def forward(self, x): ... return self.linear(x + self.param) # 1. Create model >>> model = MyModel(hidden_size=128) # Config is automatically created based on input + default values >>> model.param.shape[0] 128 # 2. (optional) Save model to local directory >>> model.save_pretrained("path/to/my-awesome-model") # 3. Push model weights to the Hub >>> model.push_to_hub("my-awesome-model") # 4. 
Initialize model from the Hub => config has been preserved >>> model = MyModel.from_pretrained("username/my-awesome-model") >>> model.param.shape[0] 128 # Model card has been correctly populated >>> from huggingface_hub import ModelCard >>> card = ModelCard.load("username/my-awesome-model") >>> card.data.tags ["keras", "pytorch_model_hub_mixin", "model_hub_mixin"] >>> card.data.library_name "keras-nlp" ``` #### Implementation The implementation is actually very straightforward, and the full implementation can be found [here](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/hub_mixin.py). 1. First, inherit your class from `ModelHubMixin`: ```python from huggingface_hub import ModelHubMixin class PyTorchModelHubMixin(ModelHubMixin): (...) ``` 2. Implement the `_save_pretrained` method: ```py from huggingface_hub import ModelHubMixin class PyTorchModelHubMixin(ModelHubMixin): (...) def _save_pretrained(self, save_directory: Path) -> None: """Save weights from a Pytorch model to a local directory.""" save_model_as_safetensor(self.module, str(save_directory / SAFETENSORS_SINGLE_FILE)) ``` 3. Implement the `_from_pretrained` method: ```python class PyTorchModelHubMixin(ModelHubMixin): (...) @classmethod # Must be a classmethod! def _from_pretrained( cls, *, model_id: str, revision: str, cache_dir: str, force_download: bool, proxies: Optional[Dict], resume_download: bool, local_files_only: bool, token: Union[str, bool, None], map_location: str = "cpu", # additional argument strict: bool = False, # additional argument **model_kwargs, ): """Load Pytorch pretrained weights and return the loaded model.""" model = cls(**model_kwargs) if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE) return cls._load_as_safetensor(model, model_file, map_location, strict) model_file = hf_hub_download( repo_id=model_id, filename=SAFETENSORS_SINGLE_FILE, revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) return cls._load_as_safetensor(model, model_file, map_location, strict) ``` And that's it! Your library now enables users to upload and download files to and from the Hub. ### Advanced usage In the section above, we quickly discussed how the [`ModelHubMixin`] works. In this section, we will see some of its more advanced features to improve your library integration with the Hugging Face Hub. #### Model card [`ModelHubMixin`] generates the model card for you. Model cards are files that accompany the models and provide important information about them. Under the hood, model cards are simple Markdown files with additional metadata. Model cards are essential for discoverability, reproducibility, and sharing! Check out the [Model Cards guide](https://huggingface.co/docs/hub/model-cards) for more details. Generating model cards semi-automatically is a good way to ensure that all models pushed with your library will share common metadata: `library_name`, `tags`, `license`, `pipeline_tag`, etc. This makes all models backed by your library easily searchable on the Hub and provides some resource links for users landing on the Hub. 
You can define the metadata directly when inheriting from [`ModelHubMixin`]:

```py
class UniDepthV1(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="unidepth",
    repo_url="https://github.com/lpiccinelli-eth/UniDepth",
    docs_url=...,
    pipeline_tag="depth-estimation",
    license="cc-by-nc-4.0",
    tags=["monocular-metric-depth-estimation", "arxiv:1234.56789"]
):
    ...
```

By default, a generic model card will be generated with the info you've provided (example: [pyp1/VoiceCraft_giga830M](https://huggingface.co/pyp1/VoiceCraft_giga830M)). But you can define your own model card template as well! In the example below, all models pushed with the `VoiceCraft` class will automatically include a citation section and license details. For more details on how to define a model card template, please check the [Model Cards guide](./model-cards).

```py
MODEL_CARD_TEMPLATE = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---

This is a VoiceCraft model. For more details, please check out the official Github repo: https://github.com/jasonppy/VoiceCraft. This model is shared under an Attribution-NonCommercial-ShareAlike 4.0 International license.

## Citation

@article{peng2024voicecraft,
  author    = {Peng, Puyuan and Huang, Po-Yao and Li, Daniel and Mohamed, Abdelrahman and Harwath, David},
  title     = {VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild},
  journal   = {arXiv},
  year      = {2024},
}
"""

class VoiceCraft(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="voicecraft",
    model_card_template=MODEL_CARD_TEMPLATE,
    ...
):
    ...
```

Finally, if you want to extend the model card generation process with dynamic values, you can override the [`~ModelHubMixin.generate_model_card`] method:

```py
from huggingface_hub import ModelCard, PyTorchModelHubMixin

class UniDepthV1(nn.Module, PyTorchModelHubMixin, ...):
    (...)

    def generate_model_card(self, *args, **kwargs) -> ModelCard:
        card = super().generate_model_card(*args, **kwargs)
        card.data.metrics = ...  # add metrics to the metadata
        card.text += ...  # append section to the modelcard
        return card
```

#### Config

[`ModelHubMixin`] handles the model configuration for you. It automatically checks the input values when you instantiate the model and serializes them in a `config.json` file. This provides 2 benefits:

1. Users will be able to reload the model with the exact same parameters as you.
2. Having a `config.json` file automatically enables analytics on the Hub (i.e. the "downloads" count).

But how does it work in practice? Several rules make the process as smooth as possible from a user perspective:

- if your `__init__` method expects a `config` input, it will be automatically saved in the repo as `config.json`.
- if the `config` input parameter is annotated with a dataclass type (e.g. `config: Optional[MyConfigClass] = None`), then the `config` value will be correctly deserialized for you.
- all values passed at initialization will also be stored in the config file. This means you don't necessarily have to expect a `config` input to benefit from it.

Example:

```py
class MyModel(ModelHubMixin):
    def __init__(self, value: str, size: int = 3):
        self.value = value
        self.size = size

    (...)  # implement _save_pretrained / _from_pretrained

model = MyModel(value="my_value")
model.save_pretrained(...)

# config.json contains passed and default values
{"value": "my_value", "size": 3}
```

But what if a value cannot be serialized as JSON?
By default, the value will be ignored when saving the config file. However, in some cases your library already expects a custom object as input that cannot be serialized, and you don't want to update your internal logic to change its type. No worries! You can pass custom encoders/decoders for any type when inheriting from [`ModelHubMixin`]. This is a bit more work but ensures your internal logic is untouched when integrating your library with the Hub.

Here is a concrete example where a class expects an `argparse.Namespace` config as input:

```py
class VoiceCraft(nn.Module):
    def __init__(self, args):
        self.pattern = args.pattern
        self.hidden_size = args.hidden_size
        ...
```

One solution can be to update the `__init__` signature to `def __init__(self, pattern: str, hidden_size: int)` and update all snippets that instantiate your class. This is a perfectly valid way to fix it but it might break downstream applications using your library.

Another solution is to provide a simple encoder/decoder to convert `argparse.Namespace` to a dictionary.

```py
from argparse import Namespace

class VoiceCraft(
    nn.Module,
    PyTorchModelHubMixin,  # inherit from mixin
    coders={
        Namespace: (
            lambda x: vars(x),  # Encoder: how to convert a `Namespace` to a valid jsonable value?
            lambda data: Namespace(**data),  # Decoder: how to reconstruct a `Namespace` from a dictionary?
        )
    }
):
    def __init__(self, args: Namespace):  # annotate `args`
        self.pattern = args.pattern
        self.hidden_size = args.hidden_size
        ...
```

In the snippet above, neither the internal logic nor the `__init__` signature of the class changed. This means all existing code snippets for your library will continue to work. To achieve this, we had to:

1. Inherit from the mixin (`PyTorchModelHubMixin` in this case).
2. Pass a `coders` parameter in the inheritance. This is a dictionary where keys are custom types you want to process. Values are a tuple `(encoder, decoder)`.
   - The encoder expects an object of the specified type as input and returns a jsonable value. This will be used when saving a model with `save_pretrained`.
   - The decoder expects raw data (typically a dictionary) as input and reconstructs the initial object. This will be used when loading the model with `from_pretrained`.
3. Add a type annotation to the `__init__` signature. This is important to let the mixin know which type is expected by the class and, therefore, which decoder to use.

For the sake of simplicity, the encoder/decoder functions in the example above are not robust. For a concrete implementation, you would most likely have to handle corner cases properly.

## Quick comparison

Let's quickly sum up the two approaches we saw with their advantages and drawbacks. The table below is only indicative. Your framework might have some specificities that you need to address. This guide is only here to give guidelines and ideas on how to handle integration. In any case, feel free to contact us if you have any questions!

| Integration | Using helpers | Using [`ModelHubMixin`] |
|:---:|:---:|:---:|
| User experience | `model = load_from_hub(...)`<br>`push_to_hub(model, ...)` | `model = MyModel.from_pretrained(...)`<br>`model.push_to_hub(...)` |
| Flexibility | Very flexible.<br>You fully control the implementation. | Less flexible.<br>Your framework must have a model class. |
| Maintenance | More maintenance to add support for configuration, and new features. Might also require fixing issues reported by users. | Less maintenance as most of the interactions with the Hub are implemented in `huggingface_hub`. |
| Documentation / Type annotation | To be written manually. | Partially handled by `huggingface_hub`. |
| Download counter | To be handled manually. | Enabled by default if class has a `config` attribute. |
| Model card | To be handled manually. | Generated by default with library_name, tags, etc. |
huggingface_hub-0.31.1/docs/source/en/guides/manage-cache.md000066400000000000000000000770051500667546600236560ustar00rootroot00000000000000

# Understand caching

`huggingface_hub` utilizes the local disk as two caches, which avoid re-downloading items. The first cache is a file-based cache, which caches individual files downloaded from the Hub and ensures that the same file is not downloaded again when a repo gets updated. The second cache is a chunk cache, where each chunk represents a byte range from a file and ensures that chunks that are shared across files are only downloaded once.

## File-based caching

The Hugging Face Hub cache-system is designed to be the central cache shared across libraries that depend on the Hub. It has been updated in v0.8.0 to prevent re-downloading the same files between revisions.

The caching system is designed as follows:

```
<CACHE_DIR>
├─ <MODELS>
├─ <DATASETS>
├─ <SPACES>
```

The default `<CACHE_DIR>` is `~/.cache/huggingface/hub`. However, it is customizable with the `cache_dir` argument on all methods, or by specifying either `HF_HOME` or `HF_HUB_CACHE` environment variable.

Models, datasets and spaces share a common root. Each of these repositories contains the repository type, the namespace (organization or username) if it exists and the repository name:

```
├─ models--julien-c--EsperBERTo-small
├─ models--lysandrejik--arxiv-nlp
├─ models--bert-base-cased
├─ datasets--glue
├─ datasets--huggingface--DataMeasurementsFiles
├─ spaces--dalle-mini--dalle-mini
```

It is within these folders that all files will now be downloaded from the Hub.

Caching ensures that a file isn't downloaded twice if it already exists and wasn't updated; but if it was updated, and you're asking for the latest file, then it will download the latest file (while keeping the previous file intact in case you need it again).

In order to achieve this, all folders contain the same skeleton:

```
├─ datasets--glue
│  ├─ refs
│  ├─ blobs
│  ├─ snapshots
...
```

Each folder is designed to contain the following:

### Refs

The `refs` folder contains files which indicate the latest revision of the given reference. For example, if we have previously fetched a file from the `main` branch of a repository, the `refs` folder will contain a file named `main`, which will itself contain the commit identifier of the current head.

If the latest commit of `main` has `aaaaaa` as identifier, then it will contain `aaaaaa`.

If that same branch gets updated with a new commit that has `bbbbbb` as identifier, then re-downloading a file from that reference will update the `refs/main` file to contain `bbbbbb`.

### Blobs

The `blobs` folder contains the actual files that we have downloaded. The name of each file is its hash.

### Snapshots

The `snapshots` folder contains symlinks to the blobs mentioned above. It is itself made up of several folders: one per known revision!
In the explanation above, we had initially fetched a file from the `aaaaaa` revision, before fetching a file from the `bbbbbb` revision. In this situation, we would now have two folders in the `snapshots` folder: `aaaaaa` and `bbbbbb`.

In each of these folders live symlinks that have the names of the files that we have downloaded. For example, if we had downloaded the `README.md` file at revision `aaaaaa`, we would have the following path:

```
<CACHE_DIR>/<REPO_NAME>/snapshots/aaaaaa/README.md
```

That `README.md` file is actually a symlink linking to the blob that has the hash of the file.

By creating the skeleton this way we open the mechanism to file sharing: if the same file was fetched in revision `bbbbbb`, it would have the same hash and the file would not need to be re-downloaded.

### .no_exist (advanced)

In addition to the `blobs`, `refs` and `snapshots` folders, you might also find a `.no_exist` folder in your cache. This folder keeps track of files that you've tried to download once but don't exist on the Hub. Its structure is the same as the `snapshots` folder with 1 subfolder per known revision:

```
<CACHE_DIR>/<REPO_NAME>/.no_exist/aaaaaa/config_that_does_not_exist.json
```

Unlike the `snapshots` folder, files are simple empty files (no symlinks). In this example, the file `"config_that_does_not_exist.json"` does not exist on the Hub for the revision `"aaaaaa"`. As it only stores empty files, this folder is negligible in terms of disk usage.

So now you might wonder, why is this information even relevant? In some cases, a framework tries to load optional files for a model. Saving the non-existence of optional files makes it faster to load a model as it saves 1 HTTP call per possible optional file. This is for example the case in `transformers` where each tokenizer can support additional files. The first time you load the tokenizer on your machine, it will cache which optional files exist (and which don't) to make the loading time faster for the next initializations.

To test if a file is cached locally (without making any HTTP request), you can use the [`try_to_load_from_cache`] helper. It will either return the filepath (if exists and cached), the object `_CACHED_NO_EXIST` (if non-existence is cached) or `None` (if we don't know).

```python
from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST

filepath = try_to_load_from_cache()
if isinstance(filepath, str):
    # file exists and is cached
    ...
elif filepath is _CACHED_NO_EXIST:
    # non-existence of file is cached
    ...
else:
    # file is not cached
    ...
```

### In practice

In practice, your cache should look like the following tree:

```text
    [  96]  .
    └── [ 160]  models--julien-c--EsperBERTo-small
        ├── [ 160]  blobs
        │   ├── [321M]  403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
        │   ├── [ 398]  7cb18dc9bafbfcf74629a4b760af1b160957a83e
        │   └── [1.4K]  d7edf6bd2a681fb0175f7735299831ee1b22b812
        ├── [  96]  refs
        │   └── [  40]  main
        └── [ 128]  snapshots
            ├── [ 128]  2439f60ef33a0d46d85da5001d52aeda5b00ce9f
            │   ├── [  52]  README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
            │   └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
            └── [ 128]  bbc77c8132af1cc5cf678da3f1ddf2de43606d48
                ├── [  52]  README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
                └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
```

### Limitations

In order to have an efficient cache-system, `huggingface_hub` uses symlinks. However, symlinks are not supported on all machines.
This is a known limitation, especially on Windows. When this is the case, `huggingface_hub` does not use the `blobs/` directory but directly stores the files in the `snapshots/` directory instead. This workaround allows users to download and cache files from the Hub exactly the same way. Tools to inspect and delete the cache (see below) are also supported. However, the cache-system is less efficient as a single file might be downloaded several times if multiple revisions of the same repo are downloaded.

If you want to benefit from the symlink-based cache-system on a Windows machine, you either need to [activate Developer Mode](https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development) or to run Python as an administrator.

When symlinks are not supported, a warning message is displayed to the user to alert them they are using a degraded version of the cache-system. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable to true.

## Chunk-based caching (Xet)

To provide more efficient file transfers, `hf_xet` adds a `xet` directory to the existing `huggingface_hub` cache, creating an additional caching layer to enable chunk-based deduplication. This cache holds chunks, which are immutable byte ranges from files (up to 64KB) that are created using content-defined chunking. For more information on the Xet Storage system, see this [section](https://huggingface.co/docs/hub/storage-backends).

The `xet` directory, located at `~/.cache/huggingface/xet` by default, contains two caches, utilized for uploads and downloads, with the following structure:

```bash
├─ chunk_cache
├─ shard_cache
```

The `xet` cache, like the rest of `hf_xet`, is fully integrated with `huggingface_hub`. If you use the existing APIs for interacting with cached assets, there is no need to update your workflow. The `xet` cache is built as an optimization layer on top of the existing `hf_xet` chunk-based deduplication and `huggingface_hub` cache system.

The `chunk_cache` directory contains cached data chunks that are used to speed up downloads while the `shard_cache` directory contains cached shards that are utilized on the upload path.

### `chunk_cache`

This cache is used on the download path. The cache directory structure is based on a base-64 encoded hash from the content-addressed store (CAS) that backs each Xet-enabled repository. A CAS hash serves as the key to lookup the offsets of where the data is stored.

At the topmost level, the first two letters of the base 64 encoded CAS hash are used to create a subdirectory in the `chunk_cache` (keys that share these first two letters are grouped here). The inner levels are comprised of subdirectories with the full key as the directory name. At the base are the cache items which are ranges of blocks that contain the cached chunks.

```bash
├─ xet
│  ├─ chunk_cache
│  │  ├─ A1
│  │  │  ├─ A1GerURLUcISVivdseeoY1PnYifYkOaCCJ7V5Q9fjgxkZWZhdWx0
│  │  │  │  ├─ AAAAAAEAAAA5DQAAAAAAAIhRLjDI3SS5jYs4ysNKZiJy9XFI8CN7Ww0UyEA9KPD9
│  │  │  │  ├─ AQAAAAIAAABzngAAAAAAAPNqPjd5Zby5aBvabF7Z1itCx0ryMwoCnuQcDwq79jlB
```

When requesting a file, the first thing `hf_xet` does is communicate with Xet storage's content addressed store (CAS) for reconstruction information. The reconstruction information contains information about the CAS keys required to download the file in its entirety.

Before executing the requests for the CAS keys, the `chunk_cache` is consulted.
If a key in the cache matches a CAS key, then there is no reason to issue a request for that content. `hf_xet` uses the chunks stored in the directory instead.

As the `chunk_cache` is purely an optimization, not a guarantee, `hf_xet` utilizes a computationally efficient eviction policy. When the `chunk_cache` is full (see `Limits and Limitations` below), `hf_xet` implements a random eviction policy when selecting an eviction candidate. This significantly reduces the overhead of managing a robust caching system (e.g., LRU) while still providing most of the benefits of caching chunks.

### `shard_cache`

This cache is used when uploading content to the Hub. The directory is flat, comprising only shard files, each using an ID for the shard name.

```sh
├─ xet
│  ├─ shard_cache
│  │  ├─ 1fe4ffd5cf0c3375f1ef9aec5016cf773ccc5ca294293d3f92d92771dacfc15d.mdb
│  │  ├─ 906ee184dc1cd0615164a89ed64e8147b3fdccd1163d80d794c66814b3b09992.mdb
│  │  ├─ ceeeb7ea4cf6c0a8d395a2cf9c08871211fbbd17b9b5dc1005811845307e6b8f.mdb
│  │  ├─ e8535155b1b11ebd894c908e91a1e14e3461dddd1392695ddc90ae54a548d8b2.mdb
```

The `shard_cache` contains shards that are:

- Locally generated and successfully uploaded to the CAS
- Downloaded from CAS as part of the global deduplication algorithm

Shards provide a mapping between files and chunks. During uploads, each file is chunked and the hash of the chunk is saved. Every shard in the cache is then consulted. If a shard contains a chunk hash that is present in the local file being uploaded, then that chunk can be discarded as it is already stored in CAS.

All shards have an expiration date of 3-4 weeks from when they are downloaded. Shards that are expired are not loaded during upload and are deleted one week after expiration.

### Limits and Limitations

The `chunk_cache` is limited to 10GB in size while the `shard_cache` is technically without limits (in practice, the size and use of shards are such that limiting the cache is unnecessary).

By design, both caches are without high-level APIs. These caches are used primarily to facilitate the reconstruction (download) or upload of a file. To interact with the assets themselves, it's recommended that you use the [`huggingface_hub` cache system APIs](https://huggingface.co/docs/huggingface_hub/guides/manage-cache).

If you need to reclaim the space utilized by either cache or need to debug any potential cache-related issues, simply remove the `xet` cache entirely by running `rm -rf <CACHE_DIR>/xet`, where `<CACHE_DIR>` is the location of your Hugging Face cache, typically `~/.cache/huggingface`.

Example full `xet` cache directory tree:

```sh
├─ xet
│  ├─ chunk_cache
│  │  ├─ L1
│  │  │  ├─ L1GerURLUcISVivdseeoY1PnYifYkOaCCJ7V5Q9fjgxkZWZhdWx0
│  │  │  │  ├─ AAAAAAEAAAA5DQAAAAAAAIhRLjDI3SS5jYs4ysNKZiJy9XFI8CN7Ww0UyEA9KPD9
│  │  │  │  ├─ AQAAAAIAAABzngAAAAAAAPNqPjd5Zby5aBvabF7Z1itCx0ryMwoCnuQcDwq79jlB
│  ├─ shard_cache
│  │  ├─ 1fe4ffd5cf0c3375f1ef9aec5016cf773ccc5ca294293d3f92d92771dacfc15d.mdb
│  │  ├─ 906ee184dc1cd0615164a89ed64e8147b3fdccd1163d80d794c66814b3b09992.mdb
│  │  ├─ ceeeb7ea4cf6c0a8d395a2cf9c08871211fbbd17b9b5dc1005811845307e6b8f.mdb
│  │  ├─ e8535155b1b11ebd894c908e91a1e14e3461dddd1392695ddc90ae54a548d8b2.mdb
```

To learn more about Xet Storage, see this [section](https://huggingface.co/docs/hub/storage-backends).

## Caching assets

In addition to caching files from the Hub, downstream libraries often need to cache other files related to HF but not handled directly by `huggingface_hub` (example: file downloaded from GitHub, preprocessed data, logs, ...).
In order to cache those files, called `assets`, one can use [`cached_assets_path`]. This small helper generates paths in the HF cache in a unified way based on the name of the library requesting it and optionally on a namespace and a subfolder name. The goal is to let every downstream libraries manage its assets its own way (e.g. no rule on the structure) as long as it stays in the right assets folder. Those libraries can then leverage tools from `huggingface_hub` to manage the cache, in particular scanning and deleting parts of the assets from a CLI command. ```py from huggingface_hub import cached_assets_path assets_path = cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="download") something_path = assets_path / "something.json" # Do anything you like in your assets folder ! ``` [`cached_assets_path`] is the recommended way to store assets but is not mandatory. If your library already uses its own cache, feel free to use it! ### Assets in practice In practice, your assets cache should look like the following tree: ```text assets/ └── datasets/ │ ├── SQuAD/ │ │ ├── downloaded/ │ │ ├── extracted/ │ │ └── processed/ │ ├── Helsinki-NLP--tatoeba_mt/ │ ├── downloaded/ │ ├── extracted/ │ └── processed/ └── transformers/ ├── default/ │ ├── something/ ├── bert-base-cased/ │ ├── default/ │ └── training/ hub/ └── models--julien-c--EsperBERTo-small/ ├── blobs/ │ ├── (...) │ ├── (...) ├── refs/ │ └── (...) └── [ 128] snapshots/ ├── 2439f60ef33a0d46d85da5001d52aeda5b00ce9f/ │ ├── (...) └── bbc77c8132af1cc5cf678da3f1ddf2de43606d48/ └── (...) ``` ## Manage your file-based cache ### Scan your cache At the moment, cached files are never deleted from your local directory: when you download a new revision of a branch, previous files are kept in case you need them again. Therefore it can be useful to scan your cache directory in order to know which repos and revisions are taking the most disk space. `huggingface_hub` provides an helper to do so that can be used via `huggingface-cli` or in a python script. **Scan cache from the terminal** The easiest way to scan your HF cache-system is to use the `scan-cache` command from `huggingface-cli` tool. This command scans the cache and prints a report with information like repo id, repo type, disk usage, refs and full local path. The snippet below shows a scan report in a folder in which 4 models and 2 datasets are cached. ```text ➜ huggingface-cli scan-cache REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH --------------------------- --------- ------------ -------- ------------- ------------- ------------------- ------------------------------------------------------------------------- glue dataset 116.3K 15 4 days ago 4 days ago 2.4.0, main, 1.17.0 /home/wauplin/.cache/huggingface/hub/datasets--glue google/fleurs dataset 64.9M 6 1 week ago 1 week ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs Jean-Baptiste/camembert-ner model 441.0M 7 2 weeks ago 16 hours ago main /home/wauplin/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner bert-base-cased model 1.9G 13 1 week ago 2 years ago /home/wauplin/.cache/huggingface/hub/models--bert-base-cased t5-base model 10.1K 3 3 months ago 3 months ago main /home/wauplin/.cache/huggingface/hub/models--t5-base t5-small model 970.7M 11 3 days ago 3 days ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/models--t5-small Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G. Got 1 warning(s) while scanning. 
Use -vvv to print details. ``` To get a more detailed report, use the `--verbose` option. For each repo, you get a list of all revisions that have been downloaded. As explained above, the files that don't change between 2 revisions are shared thanks to the symlinks. This means that the size of the repo on disk is expected to be less than the sum of the size of each of its revisions. For example, here `bert-base-cased` has 2 revisions of 1.4G and 1.5G but the total disk usage is only 1.9G. ```text ➜ huggingface-cli scan-cache -v REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH --------------------------- --------- ---------------------------------------- ------------ -------- ------------- ----------- ---------------------------------------------------------------------------------------------------------------------------- glue dataset 9338f7b671827df886678df2bdd7cc7b4f36dffd 97.7K 14 4 days ago main, 2.4.0 /home/wauplin/.cache/huggingface/hub/datasets--glue/snapshots/9338f7b671827df886678df2bdd7cc7b4f36dffd glue dataset f021ae41c879fcabcf823648ec685e3fead91fe7 97.8K 14 1 week ago 1.17.0 /home/wauplin/.cache/huggingface/hub/datasets--glue/snapshots/f021ae41c879fcabcf823648ec685e3fead91fe7 google/fleurs dataset 129b6e96cf1967cd5d2b9b6aec75ce6cce7c89e8 25.4K 3 2 weeks ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs/snapshots/129b6e96cf1967cd5d2b9b6aec75ce6cce7c89e8 google/fleurs dataset 24f85a01eb955224ca3946e70050869c56446805 64.9M 4 1 week ago main /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs/snapshots/24f85a01eb955224ca3946e70050869c56446805 Jean-Baptiste/camembert-ner model dbec8489a1c44ecad9da8a9185115bccabd799fe 441.0M 7 16 hours ago main /home/wauplin/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner/snapshots/dbec8489a1c44ecad9da8a9185115bccabd799fe bert-base-cased model 378aa1bda6387fd00e824948ebe3488630ad8565 1.5G 9 2 years ago /home/wauplin/.cache/huggingface/hub/models--bert-base-cased/snapshots/378aa1bda6387fd00e824948ebe3488630ad8565 bert-base-cased model a8d257ba9925ef39f3036bfc338acf5283c512d9 1.4G 9 3 days ago main /home/wauplin/.cache/huggingface/hub/models--bert-base-cased/snapshots/a8d257ba9925ef39f3036bfc338acf5283c512d9 t5-base model 23aa4f41cb7c08d4b05c8f327b22bfa0eb8c7ad9 10.1K 3 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-base/snapshots/23aa4f41cb7c08d4b05c8f327b22bfa0eb8c7ad9 t5-small model 98ffebbb27340ec1b1abd7c45da12c253ee1882a 726.2M 6 1 week ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/98ffebbb27340ec1b1abd7c45da12c253ee1882a t5-small model d0a119eedb3718e34c648e594394474cf95e0617 485.8M 6 4 weeks ago /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d0a119eedb3718e34c648e594394474cf95e0617 t5-small model d78aea13fa7ecd06c29e3e46195d6341255065d5 970.7M 9 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d78aea13fa7ecd06c29e3e46195d6341255065d5 Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G. Got 1 warning(s) while scanning. Use -vvv to print details. ``` **Grep example** Since the output is in tabular format, you can combine it with any `grep`-like tools to filter the entries. Here is an example to filter only revisions from the "t5-small" model on a Unix-based machine. 
```text ➜ eval "huggingface-cli scan-cache -v" | grep "t5-small" t5-small model 98ffebbb27340ec1b1abd7c45da12c253ee1882a 726.2M 6 1 week ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/98ffebbb27340ec1b1abd7c45da12c253ee1882a t5-small model d0a119eedb3718e34c648e594394474cf95e0617 485.8M 6 4 weeks ago /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d0a119eedb3718e34c648e594394474cf95e0617 t5-small model d78aea13fa7ecd06c29e3e46195d6341255065d5 970.7M 9 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d78aea13fa7ecd06c29e3e46195d6341255065d5 ``` **Scan cache from Python** For a more advanced usage, use [`scan_cache_dir`] which is the python utility called by the CLI tool. You can use it to get a detailed report structured around 4 dataclasses: - [`HFCacheInfo`]: complete report returned by [`scan_cache_dir`] - [`CachedRepoInfo`]: information about a cached repo - [`CachedRevisionInfo`]: information about a cached revision (e.g. "snapshot") inside a repo - [`CachedFileInfo`]: information about a cached file in a snapshot Here is a simple usage example. See reference for details. ```py >>> from huggingface_hub import scan_cache_dir >>> hf_cache_info = scan_cache_dir() HFCacheInfo( size_on_disk=3398085269, repos=frozenset({ CachedRepoInfo( repo_id='t5-small', repo_type='model', repo_path=PosixPath(...), size_on_disk=970726914, nb_files=11, last_accessed=1662971707.3567169, last_modified=1662971107.3567169, revisions=frozenset({ CachedRevisionInfo( commit_hash='d78aea13fa7ecd06c29e3e46195d6341255065d5', size_on_disk=970726339, snapshot_path=PosixPath(...), # No `last_accessed` as blobs are shared among revisions last_modified=1662971107.3567169, files=frozenset({ CachedFileInfo( file_name='config.json', size_on_disk=1197 file_path=PosixPath(...), blob_path=PosixPath(...), blob_last_accessed=1662971707.3567169, blob_last_modified=1662971107.3567169, ), CachedFileInfo(...), ... }), ), CachedRevisionInfo(...), ... }), ), CachedRepoInfo(...), ... }), warnings=[ CorruptedCacheException("Snapshots dir doesn't exist in cached repo: ..."), CorruptedCacheException(...), ... ], ) ``` ### Clean your cache Scanning your cache is interesting but what you really want to do next is usually to delete some portions to free up some space on your drive. This is possible using the `delete-cache` CLI command. One can also programmatically use the [`~HFCacheInfo.delete_revisions`] helper from [`HFCacheInfo`] object returned when scanning the cache. **Delete strategy** To delete some cache, you need to pass a list of revisions to delete. The tool will define a strategy to free up the space based on this list. It returns a [`DeleteCacheStrategy`] object that describes which files and folders will be deleted. The [`DeleteCacheStrategy`] allows give you how much space is expected to be freed. Once you agree with the deletion, you must execute it to make the deletion effective. In order to avoid discrepancies, you cannot edit a strategy object manually. The strategy to delete revisions is the following: - the `snapshot` folder containing the revision symlinks is deleted. - blobs files that are targeted only by revisions to be deleted are deleted as well. - if a revision is linked to 1 or more `refs`, references are deleted. - if all revisions from a repo are deleted, the entire cached repository is deleted. Revision hashes are unique across all repositories. 
This means you don't need to provide any `repo_id` or `repo_type` when removing revisions. If a revision is not found in the cache, it will be silently ignored. Besides, if a file or folder cannot be found while trying to delete it, a warning will be logged but no error is thrown. The deletion continues for other paths contained in the [`DeleteCacheStrategy`] object. **Clean cache from the terminal** The easiest way to delete some revisions from your HF cache-system is to use the `delete-cache` command from `huggingface-cli` tool. The command has two modes. By default, a TUI (Terminal User Interface) is displayed to the user to select which revisions to delete. This TUI is currently in beta as it has not been tested on all platforms. If the TUI doesn't work on your machine, you can disable it using the `--disable-tui` flag. **Using the TUI** This is the default mode. To use it, you first need to install extra dependencies by running the following command: ``` pip install huggingface_hub["cli"] ``` Then run the command: ``` huggingface-cli delete-cache ``` You should now see a list of revisions that you can select/deselect:
Instructions:
- Press keyboard arrow keys `<up>` and `<down>` to move the cursor.
- Press `<space>` to toggle (select/unselect) an item.
- When a revision is selected, the first line is updated to show you how much space will be freed.
- Press `<enter>` to confirm your selection.
- If you want to cancel the operation and quit, you can select the first item ("None of the following"). If this item is selected, the delete process will be cancelled, no matter what other items are selected. Otherwise you can also press `<ctrl+c>` to quit the TUI.

Once you've selected the revisions you want to delete and pressed `<enter>`, a final confirmation message will be prompted. Press `<enter>` again and the deletion will be effective. If you want to cancel, enter `n`.

```txt
✗ huggingface-cli delete-cache --dir ~/.cache/huggingface/hub
? Select revisions to delete: 2 revision(s) selected.
? 2 revisions selected counting for 3.1G. Confirm deletion ? Yes
Start deletion.
Done. Deleted 1 repo(s) and 0 revision(s) for a total of 3.1G.
```

**Without TUI**

As mentioned above, the TUI mode is currently in beta and is optional. It may be the case that it doesn't work on your machine or that you don't find it convenient.

Another approach is to use the `--disable-tui` flag. The process is very similar: you will be asked to manually review the list of revisions to delete. However, this manual step will not take place in the terminal directly but in a temporary file generated on the fly that you can manually edit.

This file has all the instructions you need in the header. Open it in your favorite text editor. To select/deselect a revision, simply comment/uncomment it with a `#`. Once the manual review is done and the file is edited, you can save it. Go back to your terminal and press `<enter>`. By default it will compute how much space would be freed with the updated list of revisions. You can continue to edit the file or confirm with `"y"`.

```sh
huggingface-cli delete-cache --disable-tui
```

Example of command file:

```txt
# INSTRUCTIONS
# ------------
# This is a temporary file created by running `huggingface-cli delete-cache` with the
# `--disable-tui` option. It contains a set of revisions that can be deleted from your
# local cache directory.
#
# Please manually review the revisions you want to delete:
#   - Revision hashes can be commented out with '#'.
#   - Only non-commented revisions in this file will be deleted.
#   - Revision hashes that are removed from this file are ignored as well.
#   - If `CANCEL_DELETION` line is uncommented, the all cache deletion is cancelled and
#     no changes will be applied.
#
# Once you've manually reviewed this file, please confirm deletion in the terminal. This
# file will be automatically removed once done.
# ------------

# KILL SWITCH
# ------------
# Un-comment following line to completely cancel the deletion process
# CANCEL_DELETION
# ------------

# REVISIONS
# ------------
# Dataset chrisjay/crowd-speech-africa (761.7M, used 5 days ago)
    ebedcd8c55c90d39fd27126d29d8484566cd27ca # Refs: main # modified 5 days ago

# Dataset oscar (3.3M, used 4 days ago)
#    916f956518279c5e60c63902ebdf3ddf9fa9d629 # Refs: main # modified 4 days ago

# Dataset wikiann (804.1K, used 2 weeks ago)
    89d089624b6323d69dcd9e5eb2def0551887a73a # Refs: main # modified 2 weeks ago

# Dataset z-uo/male-LJSpeech-italian (5.5G, used 5 days ago)
#    9cfa5647b32c0a30d0adfca06bf198d82192a0d1 # Refs: main # modified 5 days ago
```

**Clean cache from Python**

For more flexibility, you can also use the [`~HFCacheInfo.delete_revisions`] method programmatically.
Here is a simple example. See reference for details.

```py
>>> from huggingface_hub import scan_cache_dir

>>> delete_strategy = scan_cache_dir().delete_revisions(
...     "81fd1d6e7847c99f5862c9fb81387956d99ec7aa",
...     "e2983b237dccf3ab4937c97fa717319a9ca1a96d",
...     "6c0e6080953db56375760c0471a8c5f2929baf11",
... )
>>> print("Will free " + delete_strategy.expected_freed_size_str)
Will free 8.6G

>>> delete_strategy.execute()
Cache deletion done. Saved 8.6G.
```
huggingface_hub-0.31.1/docs/source/en/guides/manage-spaces.md000066400000000000000000000304021500667546600240570ustar00rootroot00000000000000
# Manage your Space

In this guide, we will see how to manage your Space runtime ([secrets](https://huggingface.co/docs/hub/spaces-overview#managing-secrets), [hardware](https://huggingface.co/docs/hub/spaces-gpus), and [storage](https://huggingface.co/docs/hub/spaces-storage#persistent-storage)) using `huggingface_hub`.

## A simple example: configure secrets and hardware.

Here is an end-to-end example to create and set up a Space on the Hub.

**1. Create a Space on the Hub.**

```py
>>> from huggingface_hub import HfApi
>>> repo_id = "Wauplin/my-cool-training-space"
>>> api = HfApi()

# For example with a Gradio SDK
>>> api.create_repo(repo_id=repo_id, repo_type="space", space_sdk="gradio")
```

**1. (bis) Duplicate a Space.**

This can prove useful if you want to build up from an existing Space instead of starting from scratch. It is also useful if you want control over the configuration/settings of a public Space. See [`duplicate_space`] for more details.

```py
>>> api.duplicate_space("multimodalart/dreambooth-training")
```

**2. Upload your code using your preferred solution.**

Here is an example to upload the local folder `src/` from your machine to your Space:

```py
>>> api.upload_folder(repo_id=repo_id, repo_type="space", folder_path="src/")
```

At this step, your app should already be running on the Hub for free! However, you might want to configure it further with secrets and upgraded hardware.

**3. Configure secrets and variables**

Your Space might require some secret keys, tokens, or variables to work. See [docs](https://huggingface.co/docs/hub/spaces-overview#managing-secrets) for more details. For example, an HF token to upload an image dataset to the Hub once generated from your Space.

```py
>>> api.add_space_secret(repo_id=repo_id, key="HF_TOKEN", value="hf_api_***")
>>> api.add_space_variable(repo_id=repo_id, key="MODEL_REPO_ID", value="user/repo")
```

Secrets and variables can be deleted as well:

```py
>>> api.delete_space_secret(repo_id=repo_id, key="HF_TOKEN")
>>> api.delete_space_variable(repo_id=repo_id, key="MODEL_REPO_ID")
```

From within your Space, secrets are available as environment variables (or Streamlit Secrets Management if using Streamlit). No need to fetch them via the API!

Any change in your Space configuration (secrets or hardware) will trigger a restart of your app.

**Bonus: set secrets and variables when creating or duplicating the Space!**

Secrets and variables can be set when creating or duplicating a space:

```py
>>> api.create_repo(
...     repo_id=repo_id,
...     repo_type="space",
...     space_sdk="gradio",
...     space_secrets=[{"key": "HF_TOKEN", "value": "hf_api_***"}, ...],
...     space_variables=[{"key": "MODEL_REPO_ID", "value": "user/repo"}, ...],
... )
```

```py
>>> api.duplicate_space(
...     from_id=repo_id,
...     secrets=[{"key": "HF_TOKEN", "value": "hf_api_***"}, ...],
...     variables=[{"key": "MODEL_REPO_ID", "value": "user/repo"}, ...],
... )
```

**4.
Configure the hardware**

By default, your Space will run on a CPU environment for free. You can upgrade the hardware to run it on GPUs. A payment card or a community grant is required to upgrade your Space. See [docs](https://huggingface.co/docs/hub/spaces-gpus) for more details.

```py
# Use `SpaceHardware` enum
>>> from huggingface_hub import SpaceHardware
>>> api.request_space_hardware(repo_id=repo_id, hardware=SpaceHardware.T4_MEDIUM)

# Or simply pass a string value
>>> api.request_space_hardware(repo_id=repo_id, hardware="t4-medium")
```

Hardware updates are not done immediately as your Space has to be reloaded on our servers. At any time, you can check on which hardware your Space is running to see if your request has been met.

```py
>>> runtime = api.get_space_runtime(repo_id=repo_id)
>>> runtime.stage
"RUNNING_BUILDING"
>>> runtime.hardware
"cpu-basic"
>>> runtime.requested_hardware
"t4-medium"
```

You now have a Space fully configured. Make sure to downgrade your Space back to "cpu-basic" when you are done using it.

**Bonus: request hardware when creating or duplicating the Space!**

Upgraded hardware will be automatically assigned to your Space once it's built.

```py
>>> api.create_repo(
...     repo_id=repo_id,
...     repo_type="space",
...     space_sdk="gradio",
...     space_hardware="cpu-upgrade",
...     space_storage="small",
...     space_sleep_time="7200", # 2 hours in secs
... )
```
```py
>>> api.duplicate_space(
...     from_id=repo_id,
...     hardware="cpu-upgrade",
...     storage="small",
...     sleep_time="7200", # 2 hours in secs
... )
```

**5. Pause and restart your Space**

By default, if your Space is running on upgraded hardware, it will never be stopped. However, to avoid getting billed, you might want to pause it when you are not using it. This is possible using [`pause_space`]. A paused Space will be inactive until the owner of the Space restarts it, either with the UI or via API using [`restart_space`]. For more details about paused mode, please refer to [this section](https://huggingface.co/docs/hub/spaces-gpus#pause).

```py
# Pause your Space to avoid getting billed
>>> api.pause_space(repo_id=repo_id)
# (...)
# Restart it when you need it
>>> api.restart_space(repo_id=repo_id)
```

Another possibility is to set a timeout for your Space. If your Space is inactive for more than the timeout duration, it will go to sleep. Any visitor landing on your Space will start it back up. You can set a timeout using [`set_space_sleep_time`]. For more details about sleeping mode, please refer to [this section](https://huggingface.co/docs/hub/spaces-gpus#sleep-time).

```py
# Put your Space to sleep after 1h of inactivity
>>> api.set_space_sleep_time(repo_id=repo_id, sleep_time=3600)
```

Note: if you are using 'cpu-basic' hardware, you cannot configure a custom sleep time. Your Space will automatically be paused after 48h of inactivity.

**Bonus: set a sleep time while requesting hardware**

Upgraded hardware will be automatically assigned to your Space once it's built.

```py
>>> api.request_space_hardware(repo_id=repo_id, hardware=SpaceHardware.T4_MEDIUM, sleep_time=3600)
```

**Bonus: set a sleep time when creating or duplicating the Space!**

```py
>>> api.create_repo(
...     repo_id=repo_id,
...     repo_type="space",
...     space_sdk="gradio",
...     space_hardware="t4-medium",
...     space_sleep_time="3600",
... )
```
```py
>>> api.duplicate_space(
...     from_id=repo_id,
...     hardware="t4-medium",
...     sleep_time="3600",
... )
```

**6.
Add persistent storage to your Space**

You can choose a storage tier to access disk space that persists across restarts of your Space. This means you can read and write from disk like you would with a traditional hard drive. See [docs](https://huggingface.co/docs/hub/spaces-storage#persistent-storage) for more details.

```py
>>> from huggingface_hub import SpaceStorage
>>> api.request_space_storage(repo_id=repo_id, storage=SpaceStorage.LARGE)
```

You can also delete your storage, losing all the data permanently.

```py
>>> api.delete_space_storage(repo_id=repo_id)
```

Note: You cannot decrease the storage tier of your space once it's been granted. To do so, you must delete the storage first and then request the new desired tier.

**Bonus: request storage when creating or duplicating the Space!**

```py
>>> api.create_repo(
...     repo_id=repo_id,
...     repo_type="space",
...     space_sdk="gradio",
...     space_storage="large",
... )
```
```py
>>> api.duplicate_space(
...     from_id=repo_id,
...     storage="large",
... )
```

## More advanced: temporarily upgrade your Space!

Spaces allow for a lot of different use cases. Sometimes, you might want to temporarily run a Space on specific hardware, do something, and then shut it down. In this section, we will explore how to benefit from Spaces to finetune a model on demand. This is only one way of solving this particular problem. It has to be taken as a suggestion and adapted to your use case.

Let's assume we have a Space to finetune a model. It is a Gradio app that takes as input a model id and a dataset id. The workflow is as follows:

0. (Prompt the user for a model and a dataset)
1. Load the model from the Hub.
2. Load the dataset from the Hub.
3. Finetune the model on the dataset.
4. Upload the new model to the Hub.

Step 3 requires custom hardware but you don't want your Space to be running all the time on a paid GPU. A solution is to dynamically request hardware for the training and shut it down afterwards. Since requesting hardware restarts your Space, your app must somehow "remember" the current task it is performing. There are multiple ways of doing this. In this guide we will see one solution using a Dataset as "task scheduler".

### App skeleton

Here is what your app would look like. On startup, check if a task is scheduled and if yes, run it on the correct hardware. Once done, set back hardware to the free-plan CPU and prompt the user for a new task.

Such a workflow does not support concurrent access as normal demos do. In particular, the interface will be disabled when training occurs. It is preferable to set your repo as private to ensure you are the only user.

```py
import os

import gradio as gr
from huggingface_hub import HfApi, SpaceHardware

# Space will need your token to request hardware: set it as a Secret!
HF_TOKEN = os.environ.get("HF_TOKEN")

# Space own repo_id
TRAINING_SPACE_ID = "Wauplin/dreambooth-training"

api = HfApi(token=HF_TOKEN)

# On Space startup, check if a task is scheduled. If yes, finetune the model. If not,
# display an interface to request a new task.
task = get_task()
if task is None:
    # Start Gradio app
    def gradio_fn(task):
        # On user request, add task and request hardware
        add_task(task)
        api.request_space_hardware(repo_id=TRAINING_SPACE_ID, hardware=SpaceHardware.T4_MEDIUM)

    gr.Interface(fn=gradio_fn, ...).launch()
else:
    runtime = api.get_space_runtime(repo_id=TRAINING_SPACE_ID)

    # Check if Space is loaded with a GPU.
    if runtime.hardware == SpaceHardware.T4_MEDIUM:
        # If yes, finetune base model on dataset!
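        # Note: `get_task`, `add_task` and `mark_as_done` are user-defined helpers;
        # example implementations are shown in the "Task scheduler" section below.
        # `train_and_upload` stands for your own training + upload logic and is not
        # defined in this guide.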
        train_and_upload(task)

        # Then, mark the task as "DONE"
        mark_as_done(task)

        # DO NOT FORGET: set back CPU hardware
        api.request_space_hardware(repo_id=TRAINING_SPACE_ID, hardware=SpaceHardware.CPU_BASIC)
    else:
        api.request_space_hardware(repo_id=TRAINING_SPACE_ID, hardware=SpaceHardware.T4_MEDIUM)
```

### Task scheduler

Scheduling tasks can be done in many ways. Here is an example of how it could be done using a simple CSV stored as a Dataset.

```py
import csv

from huggingface_hub import hf_hub_download

# `api` and `HF_TOKEN` are defined in the app skeleton above.

# Dataset ID in which a `tasks.csv` file contains the tasks to perform.
# Here is a basic example for `tasks.csv` containing inputs (base model and dataset)
# and status (PENDING or DONE).
#     multimodalart/sd-fine-tunable,Wauplin/concept-1,DONE
#     multimodalart/sd-fine-tunable,Wauplin/concept-2,PENDING
TASK_DATASET_ID = "Wauplin/dreambooth-task-scheduler"

def _get_csv_file():
    return hf_hub_download(repo_id=TASK_DATASET_ID, filename="tasks.csv", repo_type="dataset", token=HF_TOKEN)

def get_task():
    with open(_get_csv_file()) as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        for row in csv_reader:
            if row[2] == "PENDING":
                return row[0], row[1]  # model_id, dataset_id

def add_task(task):
    model_id, dataset_id = task
    with open(_get_csv_file()) as csv_file:
        tasks = csv_file.read()
    api.upload_file(
        repo_id=TASK_DATASET_ID,
        repo_type="dataset",
        path_in_repo="tasks.csv",
        # Quick and dirty way to add a task
        path_or_fileobj=(tasks + f"\n{model_id},{dataset_id},PENDING").encode()
    )

def mark_as_done(task):
    model_id, dataset_id = task
    with open(_get_csv_file()) as csv_file:
        tasks = csv_file.read()
    api.upload_file(
        repo_id=TASK_DATASET_ID,
        repo_type="dataset",
        path_in_repo="tasks.csv",
        # Quick and dirty way to set the task as DONE
        path_or_fileobj=tasks.replace(
            f"{model_id},{dataset_id},PENDING",
            f"{model_id},{dataset_id},DONE"
        ).encode()
    )
```
huggingface_hub-0.31.1/docs/source/en/guides/model-cards.md000066400000000000000000000241721500667546600235540ustar00rootroot00000000000000
# Create and share Model Cards

The `huggingface_hub` library provides a Python interface to create, share, and update Model Cards. Visit [the dedicated documentation page](https://huggingface.co/docs/hub/models-cards) for a deeper view of what Model Cards on the Hub are, and how they work under the hood.

## Load a Model Card from the Hub

To load an existing card from the Hub, you can use the [`ModelCard.load`] function. Here, we'll load the card from [`nateraw/vit-base-beans`](https://huggingface.co/nateraw/vit-base-beans).

```python
from huggingface_hub import ModelCard

card = ModelCard.load('nateraw/vit-base-beans')
```

This card has some helpful attributes that you may want to access/leverage:

- `card.data`: Returns a [`ModelCardData`] instance with the model card's metadata. Call `.to_dict()` on this instance to get the representation as a dictionary.
- `card.text`: Returns the text of the card, *excluding the metadata header*.
- `card.content`: Returns the text content of the card, *including the metadata header*.

## Create Model Cards

### From Text

To initialize a Model Card from text, just pass the text content of the card to `ModelCard` on init.

```python
content = """
---
language: en
license: mit
---

# My Model Card
"""

card = ModelCard(content)
card.data.to_dict() == {'language': 'en', 'license': 'mit'}  # True
```

Another way you might want to do this is with f-strings. In the following example, we:

- Use [`ModelCardData.to_yaml`] to convert metadata we defined to YAML so we can use it to insert the YAML block in the model card.
- Show how you might use a template variable via Python f-strings. ```python card_data = ModelCardData(language='en', license='mit', library='timm') example_template_var = 'nateraw' content = f""" --- { card_data.to_yaml() } --- # My Model Card This model created by [@{example_template_var}](https://github.com/{example_template_var}) """ card = ModelCard(content) print(card) ``` The above example would leave us with a card that looks like this: ``` --- language: en license: mit library: timm --- # My Model Card This model created by [@nateraw](https://github.com/nateraw) ``` ### From a Jinja Template If you have `Jinja2` installed, you can create Model Cards from a jinja template file. Let's see a basic example: ```python from pathlib import Path from huggingface_hub import ModelCard, ModelCardData # Define your jinja template template_text = """ --- {{ card_data }} --- # Model Card for MyCoolModel This model does this and that. This model was created by [@{{ author }}](https://hf.co/{{author}}). """.strip() # Write the template to a file Path('custom_template.md').write_text(template_text) # Define card metadata card_data = ModelCardData(language='en', license='mit', library_name='keras') # Create card from template, passing it any jinja template variables you want. # In our case, we'll pass author card = ModelCard.from_template(card_data, template_path='custom_template.md', author='nateraw') card.save('my_model_card_1.md') print(card) ``` The resulting card's markdown looks like this: ``` --- language: en license: mit library_name: keras --- # Model Card for MyCoolModel This model does this and that. This model was created by [@nateraw](https://hf.co/nateraw). ``` If you update any card.data, it'll reflect in the card itself. ``` card.data.library_name = 'timm' card.data.language = 'fr' card.data.license = 'apache-2.0' print(card) ``` Now, as you can see, the metadata header has been updated: ``` --- language: fr license: apache-2.0 library_name: timm --- # Model Card for MyCoolModel This model does this and that. This model was created by [@nateraw](https://hf.co/nateraw). ``` As you update the card data, you can validate the card is still valid against the Hub by calling [`ModelCard.validate`]. This ensures that the card passes any validation rules set up on the Hugging Face Hub. ### From the Default Template Instead of using your own template, you can also use the [default template](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md), which is a fully featured model card with tons of sections you may want to fill out. Under the hood, it uses [Jinja2](https://jinja.palletsprojects.com/en/3.1.x/) to fill out a template file. Note that you will have to have Jinja2 installed to use `from_template`. You can do so with `pip install Jinja2`. ```python card_data = ModelCardData(language='en', license='mit', library_name='keras') card = ModelCard.from_template( card_data, model_id='my-cool-model', model_description="this model does this and that", developers="Nate Raw", repo="https://github.com/huggingface/huggingface_hub", ) card.save('my_model_card_2.md') print(card) ``` ## Share Model Cards If you're authenticated with the Hugging Face Hub (either by using `huggingface-cli login` or [`login`]), you can push cards to the Hub by simply calling [`ModelCard.push_to_hub`]. Let's take a look at how to do that... 
First, we'll create a new repo called 'hf-hub-modelcards-pr-test' under the authenticated user's namespace:

```python
from huggingface_hub import whoami, create_repo

user = whoami()['name']
repo_id = f'{user}/hf-hub-modelcards-pr-test'
url = create_repo(repo_id, exist_ok=True)
```

Then, we'll create a card from the default template (same as the one defined in the section above):

```python
card_data = ModelCardData(language='en', license='mit', library_name='keras')
card = ModelCard.from_template(
    card_data,
    model_id='my-cool-model',
    model_description="this model does this and that",
    developers="Nate Raw",
    repo="https://github.com/huggingface/huggingface_hub",
)
```

Finally, we'll push that up to the Hub:

```python
card.push_to_hub(repo_id)
```

You can check out the resulting card [here](https://huggingface.co/nateraw/hf-hub-modelcards-pr-test/blob/main/README.md).

If you instead wanted to push a card as a pull request, you can just pass `create_pr=True` when calling `push_to_hub`:

```python
card.push_to_hub(repo_id, create_pr=True)
```

A resulting PR created from this command can be seen [here](https://huggingface.co/nateraw/hf-hub-modelcards-pr-test/discussions/3).

## Update metadata

In this section, we will see what metadata repo cards contain and how to update it.

`metadata` refers to a hash map (or key-value) context that provides some high-level information about a model, dataset or Space. That information can include details such as the model's `pipeline type`, `model_id` or `model_description`. For more detail, you can take a look at these guides: [Model Card](https://huggingface.co/docs/hub/model-cards#model-card-metadata), [Dataset Card](https://huggingface.co/docs/hub/datasets-cards#dataset-card-metadata) and [Spaces Settings](https://huggingface.co/docs/hub/spaces-settings#spaces-settings). Now let's see some examples of how to update that metadata.

Let's start with a first example:

```python
>>> from huggingface_hub import metadata_update
>>> metadata_update("username/my-cool-model", {"pipeline_tag": "image-classification"})
```

With these two lines of code you will update the metadata to set a new `pipeline_tag`.

By default, you cannot update a key that already exists on the card. If you want to do so, you must pass `overwrite=True` explicitly:

```python
>>> from huggingface_hub import metadata_update
>>> metadata_update("username/my-cool-model", {"pipeline_tag": "text-generation"}, overwrite=True)
```

It often happens that you want to suggest some changes to a repository on which you don't have write permission. You can do that by creating a PR on that repo, which will allow the owners to review and merge your suggestions.

```python
>>> from huggingface_hub import metadata_update
>>> metadata_update("someone/model", {"pipeline_tag": "text-classification"}, create_pr=True)
```

## Include Evaluation Results

To include evaluation results in the metadata `model-index`, you can pass an [`EvalResult`] or a list of `EvalResult` with your associated evaluation results. Under the hood it'll create the `model-index` when you call `card.data.to_dict()`. For more information on how this works, you can check out [this section of the Hub docs](https://huggingface.co/docs/hub/models-cards#evaluation-results).

Note that using this function requires you to include the `model_name` attribute in [`ModelCardData`].
```python card_data = ModelCardData( language='en', license='mit', model_name='my-cool-model', eval_results = EvalResult( task_type='image-classification', dataset_type='beans', dataset_name='Beans', metric_type='accuracy', metric_value=0.7 ) ) card = ModelCard.from_template(card_data) print(card.data) ``` The resulting `card.data` should look like this: ``` language: en license: mit model-index: - name: my-cool-model results: - task: type: image-classification dataset: name: Beans type: beans metrics: - type: accuracy value: 0.7 ``` If you have more than one evaluation result you'd like to share, just pass a list of `EvalResult`: ```python card_data = ModelCardData( language='en', license='mit', model_name='my-cool-model', eval_results = [ EvalResult( task_type='image-classification', dataset_type='beans', dataset_name='Beans', metric_type='accuracy', metric_value=0.7 ), EvalResult( task_type='image-classification', dataset_type='beans', dataset_name='Beans', metric_type='f1', metric_value=0.65 ) ] ) card = ModelCard.from_template(card_data) card.data ``` Which should leave you with the following `card.data`: ``` language: en license: mit model-index: - name: my-cool-model results: - task: type: image-classification dataset: name: Beans type: beans metrics: - type: accuracy value: 0.7 - type: f1 value: 0.65 ``` huggingface_hub-0.31.1/docs/source/en/guides/overview.md000066400000000000000000000141371500667546600232300ustar00rootroot00000000000000 # How-to guides In this section, you will find practical guides to help you achieve a specific goal. Take a look at these guides to learn how to use huggingface_hub to solve real-world problems: huggingface_hub-0.31.1/docs/source/en/guides/repository.md000066400000000000000000000264501500667546600236020ustar00rootroot00000000000000 # Create and manage a repository The Hugging Face Hub is a collection of git repositories. [Git](https://git-scm.com/) is a widely used tool in software development to easily version projects when working collaboratively. This guide will show you how to interact with the repositories on the Hub, especially: - Create and delete a repository. - Manage branches and tags. - Rename your repository. - Update your repository visibility. - Manage a local copy of your repository. If you are used to working with platforms such as GitLab/GitHub/Bitbucket, your first instinct might be to use `git` CLI to clone your repo (`git clone`), commit changes (`git add, git commit`) and push them (`git push`). This is valid when using the Hugging Face Hub. However, software engineering and machine learning do not share the same requirements and workflows. Model repositories might maintain large model weight files for different frameworks and tools, so cloning the repository can lead to you maintaining large local folders with massive sizes. As a result, it may be more efficient to use our custom HTTP methods. You can read our [Git vs HTTP paradigm](../concepts/git_vs_http) explanation page for more details. If you want to create and manage a repository on the Hub, your machine must be logged in. If you are not, please refer to [this section](../quick-start#authentication). In the rest of this guide, we will assume that your machine is logged in. ## Repo creation and deletion The first step is to know how to create and delete repositories. You can only manage repositories that you own (under your username namespace) or from organizations in which you have write permissions. 
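If you are unsure which namespaces you can push repositories to, a quick account check can help. The snippet below is a minimal sketch (assuming you are already logged in) that uses [`whoami`] to print your username and the organizations you belong to; the exact structure of the returned dictionary may vary, so treat the `orgs` handling as illustrative:

```py
>>> from huggingface_hub import whoami

>>> info = whoami()  # requires a valid token (e.g. after `huggingface-cli login`)
>>> info["name"]  # your username, i.e. the default namespace for your repos
'Wauplin'
>>> [org.get("name") for org in info.get("orgs", [])]  # organizations you are a member of
['huggingface']
```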
### Create a repository Create an empty repository with [`create_repo`] and give it a name with the `repo_id` parameter. The `repo_id` is your namespace followed by the repository name: `username_or_org/repo_name`. ```py >>> from huggingface_hub import create_repo >>> create_repo("lysandre/test-model") 'https://huggingface.co/lysandre/test-model' ``` By default, [`create_repo`] creates a model repository. But you can use the `repo_type` parameter to specify another repository type. For example, if you want to create a dataset repository: ```py >>> from huggingface_hub import create_repo >>> create_repo("lysandre/test-dataset", repo_type="dataset") 'https://huggingface.co/datasets/lysandre/test-dataset' ``` When you create a repository, you can set your repository visibility with the `private` parameter. ```py >>> from huggingface_hub import create_repo >>> create_repo("lysandre/test-private", private=True) ``` If you want to change the repository visibility at a later time, you can use the [`update_repo_settings`] function. If you are part of an organization with an Enterprise plan, you can create a repo in a specific resource group by passing `resource_group_id` as parameter to [`create_repo`]. Resource groups are a security feature to control which members from your org can access a given resource. You can get the resource group ID by copying it from your org settings page url on the Hub (e.g. `"https://huggingface.co/organizations/huggingface/settings/resource-groups/66670e5163145ca562cb1988"` => `"66670e5163145ca562cb1988"`). For more details about resource group, check out this [guide](https://huggingface.co/docs/hub/en/security-resource-groups). ### Delete a repository Delete a repository with [`delete_repo`]. Make sure you want to delete a repository because this is an irreversible process! Specify the `repo_id` of the repository you want to delete: ```py >>> delete_repo(repo_id="lysandre/my-corrupted-dataset", repo_type="dataset") ``` ### Duplicate a repository (only for Spaces) In some cases, you want to copy someone else's repo to adapt it to your use case. This is possible for Spaces using the [`duplicate_space`] method. It will duplicate the whole repository. You will still need to configure your own settings (hardware, sleep-time, storage, variables and secrets). Check out our [Manage your Space](./manage-spaces) guide for more details. ```py >>> from huggingface_hub import duplicate_space >>> duplicate_space("multimodalart/dreambooth-training", private=False) RepoUrl('https://huggingface.co/spaces/nateraw/dreambooth-training',...) ``` ## Upload and download files Now that you have created your repository, you are interested in pushing changes to it and downloading files from it. These 2 topics deserve their own guides. Please refer to the [upload](./upload) and the [download](./download) guides to learn how to use your repository. ## Branches and tags Git repositories often make use of branches to store different versions of a same repository. Tags can also be used to flag a specific state of your repository, for example, when releasing a version. More generally, branches and tags are referred as [git references](https://git-scm.com/book/en/v2/Git-Internals-Git-References). 
### Create branches and tags

You can create new branches and tags using [`create_branch`] and [`create_tag`]:

```py
>>> from huggingface_hub import create_branch, create_tag

# Create a branch on a Space repo from `main` branch
>>> create_branch("Matthijs/speecht5-tts-demo", repo_type="space", branch="handle-dog-speaker")

# Create a tag on a Dataset repo from `v0.1-release` branch
>>> create_tag("bigcode/the-stack", repo_type="dataset", revision="v0.1-release", tag="v0.1.1", tag_message="Bump release version.")
```

You can use the [`delete_branch`] and [`delete_tag`] functions in the same way to delete a branch or a tag.

### List all branches and tags

You can also list the existing git refs from a repository using [`list_repo_refs`]:

```py
>>> from huggingface_hub import list_repo_refs
>>> list_repo_refs("bigcode/the-stack", repo_type="dataset")
GitRefs(
   branches=[
         GitRefInfo(name='main', ref='refs/heads/main', target_commit='18edc1591d9ce72aa82f56c4431b3c969b210ae3'),
         GitRefInfo(name='v1.1.a1', ref='refs/heads/v1.1.a1', target_commit='f9826b862d1567f3822d3d25649b0d6d22ace714')
   ],
   converts=[],
   tags=[
         GitRefInfo(name='v1.0', ref='refs/tags/v1.0', target_commit='c37a8cd1e382064d8aced5e05543c5f7753834da')
   ]
)
```

## Change repository settings

Repositories come with some settings that you can configure. Most of the time, you will want to do that manually in the repo settings page in your browser. You must have write access to a repo to configure it (either own it or be part of an organization). In this section, we will see the settings that you can also configure programmatically using `huggingface_hub`.

Some settings are specific to Spaces (hardware, environment variables,...). To configure those, please refer to our [Manage your Spaces](../guides/manage-spaces) guide.

### Update visibility

A repository can be public or private. A private repository is only visible to you or members of the organization in which the repository is located. Change a repository to private as shown in the following:

```py
>>> from huggingface_hub import update_repo_settings
>>> update_repo_settings(repo_id=repo_id, private=True)
```

### Setup gated access

To give more control over how repos are used, the Hub allows repo authors to enable **access requests** for their repos. When enabled, users must agree to share their contact information (username and email address) with the repo authors to access the files. A repo with access requests enabled is called a **gated repo**.

You can set a repo as gated using [`update_repo_settings`]:

```py
>>> from huggingface_hub import HfApi

>>> api = HfApi()
>>> api.update_repo_settings(repo_id=repo_id, gated="auto")  # Set automatic gating for a model
```

### Rename your repository

You can rename your repository on the Hub using [`move_repo`]. Using this method, you can also move the repo from a user to an organization. When doing so, there are a [few limitations](https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo) that you should be aware of. For example, you can't transfer your repo to another user.

```py
>>> from huggingface_hub import move_repo
>>> move_repo(from_id="Wauplin/cool-model", to_id="huggingface/cool-model")
```

## Manage a local copy of your repository

All the actions described above can be done using HTTP requests. However, in some cases you might be interested in having a local copy of your repository and interacting with it using the Git commands you are familiar with.
The [`Repository`] class allows you to interact with files and repositories on the Hub with functions similar to Git commands. It is a wrapper over Git and Git-LFS methods to use the Git commands you already know and love. Before starting, please make sure you have Git-LFS installed (see [here](https://git-lfs.github.com/) for installation instructions).

[`Repository`] is deprecated in favor of the http-based alternatives implemented in [`HfApi`]. Given its large adoption in legacy code, the complete removal of [`Repository`] will only happen in release `v1.0`. For more details, please read [this explanation page](./concepts/git_vs_http).

### Use a local repository

Instantiate a [`Repository`] object with a path to a local repository:

```py
>>> from huggingface_hub import Repository

>>> repo = Repository(local_dir="<path>/<to>/<folder>")
```

### Clone

The `clone_from` parameter clones a repository from a Hugging Face repository ID to a local directory specified by the `local_dir` argument:

```py
>>> from huggingface_hub import Repository

>>> repo = Repository(local_dir="w2v2", clone_from="facebook/wav2vec2-large-960h-lv60")
```

`clone_from` can also clone a repository using a URL:

```py
>>> repo = Repository(local_dir="huggingface-hub", clone_from="https://huggingface.co/facebook/wav2vec2-large-960h-lv60")
```

You can combine the `clone_from` parameter with [`create_repo`] to create and clone a repository:

```py
>>> repo_url = create_repo(repo_id="repo_name")
>>> repo = Repository(local_dir="repo_local_path", clone_from=repo_url)
```

You can also configure a Git username and email for a cloned repository by specifying the `git_user` and `git_email` parameters when you clone a repository. When users commit to that repository, Git will be aware of the commit author.

```py
>>> repo = Repository(
...   "my-dataset",
...   clone_from="<user>/<dataset_id>",
...   token=True,
...   repo_type="dataset",
...   git_user="MyName",
...   git_email="me@cool.mail"
... )
```

### Branch

Branches are important for collaboration and experimentation without impacting your current files and code. Switch between branches with [`~Repository.git_checkout`]. For example, if you want to switch from `branch1` to `branch2`:

```py
>>> from huggingface_hub import Repository

>>> repo = Repository(local_dir="huggingface-hub", clone_from="<user>/<repo_id>", revision='branch1')
>>> repo.git_checkout("branch2")
```

### Pull

[`~Repository.git_pull`] allows you to update a current local branch with changes from a remote repository:

```py
>>> from huggingface_hub import Repository

>>> repo.git_pull()
```

Set `rebase=True` if you want your local commits to occur after your branch is updated with the new commits from the remote:

```py
>>> repo.git_pull(rebase=True)
```
huggingface_hub-0.31.1/docs/source/en/guides/search.md000066400000000000000000000036661500667546600226320ustar00rootroot00000000000000
# Search the Hub

In this tutorial, you will learn how to search models, datasets and spaces on the Hub using `huggingface_hub`.

## How to list repositories?

The `huggingface_hub` library includes an HTTP client, [`HfApi`], to interact with the Hub. Among other things, it can list models, datasets and spaces stored on the Hub:

```py
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> models = api.list_models()
```

The output of [`list_models`] is an iterator over the models stored on the Hub.

Similarly, you can use [`list_datasets`] to list datasets and [`list_spaces`] to list Spaces.

## How to filter repositories?

Listing repositories is great but now you might want to filter your search.
The list helpers have several attributes like:

- `filter`
- `author`
- `search`
- ...

Let's see an example to get all models on the Hub that do image classification, have been trained on the imagenet dataset, and that run with PyTorch.

```py
models = api.list_models(
    task="image-classification",
    library="pytorch",
    trained_dataset="imagenet",
)
```

While filtering, you can also sort the models and take only the top results. For example, the following snippet fetches the top 5 most downloaded datasets on the Hub:

```py
>>> list(list_datasets(sort="downloads", direction=-1, limit=5))
[DatasetInfo(
    id='argilla/databricks-dolly-15k-curated-en',
    author='argilla',
    sha='4dcd1dedbe148307a833c931b21ca456a1fc4281',
    last_modified=datetime.datetime(2023, 10, 2, 12, 32, 53, tzinfo=datetime.timezone.utc),
    private=False,
    downloads=8889377,
    (...)
```

To explore available filters on the Hub, visit [models](https://huggingface.co/models) and [datasets](https://huggingface.co/datasets) pages in your browser, search for some parameters and look at the values in the URL.
huggingface_hub-0.31.1/docs/source/en/guides/upload.md000066400000000000000000001013641500667546600226450ustar00rootroot00000000000000
# Upload files to the Hub

Sharing your files and work is an important aspect of the Hub. The `huggingface_hub` library offers several options for uploading your files to the Hub. You can use these functions independently or integrate them into your library, making it more convenient for your users to interact with the Hub. This guide will show you how to push files:

- without using Git.
- that are very large with [Git LFS](https://git-lfs.github.com/).
- with the `commit` context manager.
- with the [`~Repository.push_to_hub`] function.

Whenever you want to upload files to the Hub, you need to log in to your Hugging Face account. For more details about authentication, check out [this section](../quick-start#authentication).

## Upload a file

Once you've created a repository with [`create_repo`], you can upload a file to your repository using [`upload_file`].

Specify the path of the file to upload, where you want to upload the file to in the repository, and the name of the repository you want to add the file to. Depending on your repository type, you can optionally set the repository type as a `dataset`, `model`, or `space`.

```py
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> api.upload_file(
...     path_or_fileobj="/path/to/local/folder/README.md",
...     path_in_repo="README.md",
...     repo_id="username/test-dataset",
...     repo_type="dataset",
... )
```

## Upload a folder

Use the [`upload_folder`] function to upload a local folder to an existing repository. Specify the path of the local folder to upload, where you want to upload the folder to in the repository, and the name of the repository you want to add the folder to. Depending on your repository type, you can optionally set the repository type as a `dataset`, `model`, or `space`.

```py
>>> from huggingface_hub import HfApi
>>> api = HfApi()

# Upload all the content from the local folder to your remote Space.
# By default, files are uploaded at the root of the repo
>>> api.upload_folder(
...     folder_path="/path/to/local/space",
...     repo_id="username/my-cool-space",
...     repo_type="space",
... )
```

By default, the `.gitignore` file is taken into account to know which files should be committed or not. We first check if a `.gitignore` file is present in the commit, and if not, we check if one exists on the Hub.
Please be aware that only a `.gitignore` file present at the root of the directory will be used. We do not check for `.gitignore` files in subdirectories.

If you don't want to use a hardcoded `.gitignore` file, you can use the `allow_patterns` and `ignore_patterns` arguments to filter which files to upload. These parameters accept either a single pattern or a list of patterns. Patterns are Standard Wildcards (globbing patterns) as documented [here](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm). If both `allow_patterns` and `ignore_patterns` are provided, both constraints apply.

Besides the `.gitignore` file and allow/ignore patterns, any `.git/` folder present in any subdirectory will be ignored.

```py
>>> api.upload_folder(
...     folder_path="/path/to/local/folder",
...     path_in_repo="my-dataset/train", # Upload to a specific folder
...     repo_id="username/test-dataset",
...     repo_type="dataset",
...     ignore_patterns="**/logs/*.txt", # Ignore all text logs
... )
```

You can also use the `delete_patterns` argument to specify files you want to delete from the repo in the same commit. This can prove useful if you want to clean a remote folder before pushing files in it and you don't know which files already exist.

The example below uploads the local `./logs` folder to the remote `/experiment/logs/` folder. Only txt files are uploaded but, before that, all previous logs on the repo are deleted. All of this in a single commit.

```py
>>> api.upload_folder(
...     folder_path="/path/to/local/folder/logs",
...     repo_id="username/trained-model",
...     path_in_repo="experiment/logs/",
...     allow_patterns="*.txt", # Upload all local text files
...     delete_patterns="*.txt", # Delete all remote text files before
... )
```

## Upload from the CLI

You can use the `huggingface-cli upload` command from the terminal to directly upload files to the Hub. Internally it uses the same [`upload_file`] and [`upload_folder`] helpers described above.

You can either upload a single file or an entire folder:

```bash
# Usage:  huggingface-cli upload [repo_id] [local_path] [path_in_repo]
>>> huggingface-cli upload Wauplin/my-cool-model ./models/model.safetensors model.safetensors
https://huggingface.co/Wauplin/my-cool-model/blob/main/model.safetensors

>>> huggingface-cli upload Wauplin/my-cool-model ./models .
https://huggingface.co/Wauplin/my-cool-model/tree/main
```

`local_path` and `path_in_repo` are optional and can be implicitly inferred. If `local_path` is not set, the tool will check if a local folder or file has the same name as the `repo_id`. If that's the case, its content will be uploaded. Otherwise, an exception is raised asking the user to explicitly set `local_path`. In any case, if `path_in_repo` is not set, files are uploaded at the root of the repo.

For more details about the CLI upload command, please refer to the [CLI guide](./cli#huggingface-cli-upload).

## Upload a large folder

In most cases, the [`upload_folder`] method and `huggingface-cli upload` command should be the go-to solutions to upload files to the Hub. They ensure a single commit will be made, handle a lot of use cases, and fail explicitly when something wrong happens. However, when dealing with a large amount of data, you will usually prefer a resilient process even if it leads to more commits or requires more CPU usage. The [`upload_large_folder`] method has been implemented in that spirit:

- it is resumable: the upload process is split into many small tasks (hashing files, pre-uploading them, and committing them).
Each time a task is completed, the result is cached locally in a `.cache/huggingface/` folder inside the folder you are trying to upload. By doing so, restarting the process after an interruption will resume all completed tasks.
- it is multi-threaded: hashing large files and pre-uploading them benefits a lot from multithreading if your machine allows it.
- it is resilient to errors: a high-level retry-mechanism has been added to retry each independent task indefinitely until it passes (no matter if it's an OSError, ConnectionError, PermissionError, etc.). This mechanism is double-edged. If transient errors happen, the process will continue and retry. If permanent errors happen (e.g. permission denied), it will retry indefinitely without solving the root cause.

If you want more technical details about how `upload_large_folder` is implemented under the hood, please have a look at the [`upload_large_folder`] package reference.

Here is how to use [`upload_large_folder`] in a script. The method signature is very similar to [`upload_folder`]:

```py
>>> api.upload_large_folder(
...     repo_id="HuggingFaceM4/Docmatix",
...     repo_type="dataset",
...     folder_path="/path/to/local/docmatix",
... )
```

You will see the following output in your terminal:
```
Repo created: https://huggingface.co/datasets/HuggingFaceM4/Docmatix
Found 5 candidate files to upload
Recovering from metadata files: 100%|█████████████████████████████████████| 5/5 [00:00<00:00, 542.66it/s]

---------- 2024-07-22 17:23:17 (0:00:00) ----------
Files:   hashed 5/5 (5.0G/5.0G) | pre-uploaded: 0/5 (0.0/5.0G) | committed: 0/5 (0.0/5.0G) | ignored: 0
Workers: hashing: 0 | get upload mode: 0 | pre-uploading: 5 | committing: 0 | waiting: 11
---------------------------------------------------
```

First, the repo is created if it didn't exist before. Then, the local folder is scanned for files to upload. For each file, we try to recover metadata information (from a previously interrupted upload). From there, it is able to launch workers and print an update status every minute. Here, we can see that 5 files have already been hashed but not pre-uploaded. 5 workers are pre-uploading files while the 11 others are waiting for a task.

A command line is also provided. You can define the number of workers and the level of verbosity in the terminal:

```sh
huggingface-cli upload-large-folder HuggingFaceM4/Docmatix --repo-type=dataset /path/to/local/docmatix --num-workers=16
```

For large uploads, you have to set `repo_type="model"` or `--repo-type=model` explicitly. Usually, this information is implicit in all other `HfApi` methods. This is to avoid having data uploaded to a repository with a wrong type. If that's the case, you'll have to re-upload everything.

While being much more robust for uploading large folders, `upload_large_folder` is more limited than [`upload_folder`] feature-wise. In practice:

- you cannot set a custom `path_in_repo`. If you want to upload to a subfolder, you need to set the proper structure locally.
- you cannot set a custom `commit_message` and `commit_description` since multiple commits are created.
- you cannot delete from the repo while uploading. Please make a separate commit first.
- you cannot create a PR directly. Please create a PR first (from the UI or using [`create_pull_request`]) and then commit to it by passing `revision`.

### Tips and tricks for large uploads

There are some limitations to be aware of when dealing with a large amount of data in your repo.
Given the time it takes to stream the data, getting an upload/push to fail at the end of the process or encountering a degraded experience, be it on hf.co or when working locally, can be very annoying. Check out our [Repository limitations and recommendations](https://huggingface.co/docs/hub/repositories-recommendations) guide for best practices on how to structure your repositories on the Hub. Let's move on with some practical tips to make your upload process as smooth as possible. - **Start small**: We recommend starting with a small amount of data to test your upload script. It's easier to iterate on a script when failing takes only a little time. - **Expect failures**: Streaming large amounts of data is challenging. You don't know what can happen, but it's always best to consider that something will fail at least once -no matter if it's due to your machine, your connection, or our servers. For example, if you plan to upload a large number of files, it's best to keep track locally of which files you already uploaded before uploading the next batch. You are ensured that an LFS file that is already committed will never be re-uploaded twice but checking it client-side can still save some time. This is what [`upload_large_folder`] does for you. - **Use `hf_xet`**: this leverages the new storage backend for Hub, is written in Rust, and is being rolled out to users right now. In order to upload using `hf_xet` your repo must be enabled to use the Xet storage backend. It is being rolled out now, so join the [waitlist](https://huggingface.co/join/xet) to get onboarded soon! - **Use `hf_transfer`**: this is a Rust-based [library](https://github.com/huggingface/hf_transfer) meant to speed up uploads on machines with very high bandwidth (uploads LFS files). To use `hf_transfer`: 1. Specify the `hf_transfer` extra when installing `huggingface_hub` (i.e., `pip install huggingface_hub[hf_transfer]`). 2. Set `HF_HUB_ENABLE_HF_TRANSFER=1` as an environment variable. `hf_transfer` is a power user tool for uploading LFS files! It is tested and production-ready, but it is less future-proof and lacks user-friendly features like advanced error handling or proxies. For more details, please take a look at this [section](https://huggingface.co/docs/huggingface_hub/hf_transfer). Note that `hf_xet` and `hf_transfer` tools are mutually exclusive. The former is used to upload files to Xet-enabled repos while the later uploads LFS files to regular repos. ## Advanced features In most cases, you won't need more than [`upload_file`] and [`upload_folder`] to upload your files to the Hub. However, `huggingface_hub` has more advanced features to make things easier. Let's have a look at them! ### Faster Uploads Take advantage of faster uploads through `hf_xet`, the Python binding to the [`xet-core`](https://github.com/huggingface/xet-core) library that enables chunk-based deduplication for faster uploads and downloads. `hf_xet` integrates seamlessly with `huggingface_hub`, but uses the Rust `xet-core` library and Xet storage instead of LFS. Xet storage is being rolled out to Hugging Face Hub users at this time, so xet uploads may need to be enabled for your repo for `hf_xet` to actually upload to the Xet backend. Join the [waitlist](https://huggingface.co/join/xet) to get onboarded soon! Also, `hf_xet` today only works with files on the file system, so cannot be used with file-like objects (byte-arrays, buffers). 
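If you are not sure whether the optional `hf_xet` package is present in your environment, a quick standard-library check such as the sketch below can tell you before you rely on Xet-backed transfers (this is purely illustrative and not an official `huggingface_hub` API):

```py
import importlib.util

# `hf_xet` is the optional Rust-backed package installed via `pip install "huggingface_hub[hf_xet]"`.
if importlib.util.find_spec("hf_xet") is None:
    print("hf_xet is not installed; transfers will use the regular HTTP/LFS path.")
else:
    print("hf_xet is available.")
```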
`hf_xet` uses the Xet storage system, which breaks files down into immutable chunks, storing collections of these chunks (called blocks or xorbs) remotely and retrieving them to reassemble the file when requested. When uploading, after confirming the user is authorized to write to this repo, `hf_xet` will scan the files, breaking them down into their chunks and collecting those chunks into xorbs (and deduplicating across known chunks), and will then upload these xorbs to the Xet content-addressable service (CAS), which will verify the integrity of the xorbs, register the xorb metadata along with the LFS SHA256 hash (to support lookup/download), and write the xorbs to remote storage.

To enable it, specify the `hf_xet` extra when installing `huggingface_hub`:

```bash
pip install -U "huggingface_hub[hf_xet]"
```

All other `huggingface_hub` APIs will continue to work without any modification. To learn more about the benefits of Xet storage and `hf_xet`, refer to this [section](https://huggingface.co/docs/hub/storage-backends).

**Cluster / Distributed Filesystem Upload Considerations**

When uploading from a cluster, the files being uploaded often reside on a distributed or networked filesystem (NFS, EBS, Lustre, Fsx, etc). Xet storage will chunk those files and write them into blocks (also called xorbs) locally, and upload them once each block is completed. For better performance when uploading from a distributed filesystem, make sure to set [`HF_XET_CACHE`](../package_reference/environment_variables#hfxetcache) to a directory that is on a local disk (ex. a local NVMe or SSD disk). The default location for the Xet cache is under `HF_HOME` (`~/.cache/huggingface/xet`); since this is in the user's home directory, it is often also located on the distributed filesystem.

### Non-blocking uploads

In some cases, you want to push data without blocking your main thread. This is particularly useful to upload logs and artifacts while continuing a training. To do so, you can use the `run_as_future` argument in both [`upload_file`] and [`upload_folder`]. This will return a [`concurrent.futures.Future`](https://docs.python.org/3/library/concurrent.futures.html#future-objects) object that you can use to check the status of the upload.

```py
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> future = api.upload_folder( # Upload in the background (non-blocking action)
...     repo_id="username/my-model",
...     folder_path="checkpoints-001",
...     run_as_future=True,
... )
>>> future
Future(...)
>>> future.done()
False
>>> future.result()  # Wait for the upload to complete (blocking action)
...
```

Background jobs are queued when using `run_as_future=True`. This means that you are guaranteed that the jobs will be executed in the correct order.

Even though background jobs are mostly useful to upload data/create commits, you can queue any method you like using [`run_as_future`]. For instance, you can use it to create a repo and then upload data to it in the background. The built-in `run_as_future` argument in upload methods is just an alias around it.

```py
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> api.run_as_future(api.create_repo, "username/my-model", exists_ok=True)
Future(...)
>>> api.upload_file(
...     repo_id="username/my-model",
...     path_in_repo="file.txt",
...     path_or_fileobj=b"file content",
...     run_as_future=True,
... )
Future(...)
```

### Upload a folder by chunks

[`upload_folder`] makes it easy to upload an entire folder to the Hub.
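For reference, a minimal [`upload_folder`] call looks like this (the folder path and repo id below are placeholders):

```py
>>> api.upload_folder(
...     folder_path="/path/to/local/folder",
...     repo_id="username/my-dataset",
...     repo_type="dataset",
... )
```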
However, for large folders (thousands of files or hundreds of GB), we recommend using [`upload_large_folder`], which splits the upload into multiple commits. See the [Upload a large folder](#upload-a-large-folder) section for more details. ### Scheduled uploads The Hugging Face Hub makes it easy to save and version data. However, there are some limitations when updating the same file thousands of times. For instance, you might want to save logs of a training process or user feedback on a deployed Space. In these cases, uploading the data as a dataset on the Hub makes sense, but it can be hard to do properly. The main reason is that you don't want to version every update of your data because it'll make the git repository unusable. The [`CommitScheduler`] class offers a solution to this problem. The idea is to run a background job that regularly pushes a local folder to the Hub. Let's assume you have a Gradio Space that takes as input some text and generates two translations of it. Then, the user can select their preferred translation. For each run, you want to save the input, output, and user preference to analyze the results. This is a perfect use case for [`CommitScheduler`]; you want to save data to the Hub (potentially millions of user feedback), but you don't _need_ to save in real-time each user's input. Instead, you can save the data locally in a JSON file and upload it every 10 minutes. For example: ```py >>> import json >>> import uuid >>> from pathlib import Path >>> import gradio as gr >>> from huggingface_hub import CommitScheduler # Define the file where to save the data. Use UUID to make sure not to overwrite existing data from a previous run. >>> feedback_file = Path("user_feedback/") / f"data_{uuid.uuid4()}.json" >>> feedback_folder = feedback_file.parent # Schedule regular uploads. Remote repo and local folder are created if they don't already exist. >>> scheduler = CommitScheduler( ... repo_id="report-translation-feedback", ... repo_type="dataset", ... folder_path=feedback_folder, ... path_in_repo="data", ... every=10, ... ) # Define the function that will be called when the user submits its feedback (to be called in Gradio) >>> def save_feedback(input_text:str, output_1: str, output_2:str, user_choice: int) -> None: ... """ ... Append input/outputs and user feedback to a JSON Lines file using a thread lock to avoid concurrent writes from different users. ... """ ... with scheduler.lock: ... with feedback_file.open("a") as f: ... f.write(json.dumps({"input": input_text, "output_1": output_1, "output_2": output_2, "user_choice": user_choice})) ... f.write("\n") # Start Gradio >>> with gr.Blocks() as demo: >>> ... # define Gradio demo + use `save_feedback` >>> demo.launch() ``` And that's it! User input/outputs and feedback will be available as a dataset on the Hub. By using a unique JSON file name, you are guaranteed you won't overwrite data from a previous run or data from another Spaces/replicas pushing concurrently to the same repository. For more details about the [`CommitScheduler`], here is what you need to know: - **append-only:** It is assumed that you will only add content to the folder. You must only append data to existing files or create new files. Deleting or overwriting a file might corrupt your repository. - **git history**: The scheduler will commit the folder every `every` minutes. To avoid polluting the git repository too much, it is recommended to set a minimal value of 5 minutes. Besides, the scheduler is designed to avoid empty commits. 
If no new content is detected in the folder, the scheduled commit is dropped.
- **errors:** The scheduler runs as a background thread. It is started when you instantiate the class and never stops. In particular, if an error occurs during the upload (example: connection issue), the scheduler will silently ignore it and retry at the next scheduled commit.
- **thread-safety:** In most cases it is safe to assume that you can write to a file without having to worry about a lock file. The scheduler will not crash or be corrupted if you write content to the folder while it's uploading. In practice, _it is possible_ that concurrency issues happen for heavy-loaded apps. In this case, we advise using the `scheduler.lock` lock to ensure thread-safety. The lock is blocked only when the scheduler scans the folder for changes, not when it uploads data. You can safely assume that it will not affect the user experience on your Space.

#### Space persistence demo

Persisting data from a Space to a Dataset on the Hub is the main use case for [`CommitScheduler`]. Depending on the use case, you might want to structure your data differently. The structure has to be robust to concurrent users and restarts which often implies generating UUIDs. Besides robustness, you should upload data in a format readable by the 🤗 Datasets library for later reuse. We created a [Space](https://huggingface.co/spaces/Wauplin/space_to_dataset_saver) that demonstrates how to save several different data formats (you may need to adapt it for your own specific needs).

#### Custom uploads

[`CommitScheduler`] assumes your data is append-only and should be uploaded "as is". However, you might want to customize the way data is uploaded. You can do that by creating a class inheriting from [`CommitScheduler`] and overwriting the `push_to_hub` method (feel free to overwrite it any way you want). You are guaranteed it will be called every `every` minutes in a background thread. You don't have to worry about concurrency and errors but you must be careful about other aspects, such as pushing empty commits or duplicated data.

In the (simplified) example below, we overwrite `push_to_hub` to zip all PNG files in a single archive to avoid overloading the repo on the Hub:

```py
import tempfile
import zipfile
from pathlib import Path

from huggingface_hub import CommitScheduler


class ZipScheduler(CommitScheduler):
    def push_to_hub(self):
        # 1. List PNG files
        png_files = list(self.folder_path.glob("*.png"))
        if len(png_files) == 0:
            return None  # return early if nothing to commit

        # 2. Zip png files in a single archive
        with tempfile.TemporaryDirectory() as tmpdir:
            archive_path = Path(tmpdir) / "train.zip"
            with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as archive:
                for png_file in png_files:
                    archive.write(filename=png_file, arcname=png_file.name)

            # 3. Upload archive
            self.api.upload_file(..., path_or_fileobj=archive_path)

        # 4. Delete local png files to avoid re-uploading them later
        for png_file in png_files:
            png_file.unlink()
```

When you overwrite `push_to_hub`, you have access to the attributes of [`CommitScheduler`] and especially:
- [`HfApi`] client: `api`
- Folder parameters: `folder_path` and `path_in_repo`
- Repo parameters: `repo_id`, `repo_type`, `revision`
- The thread lock: `lock`

For more examples of custom schedulers, check out our [demo Space](https://huggingface.co/spaces/Wauplin/space_to_dataset_saver) containing different implementations depending on your use cases.

### create_commit

The [`upload_file`] and [`upload_folder`] functions are high-level APIs that are generally convenient to use.
We recommend trying these functions first if you don't need to work at a lower level. However, if you want to work at the commit level, you can use the [`create_commit`] function directly.

There are three types of operations supported by [`create_commit`]:

- [`CommitOperationAdd`] uploads a file to the Hub. If the file already exists, the file contents are overwritten. This operation accepts two arguments:
  - `path_in_repo`: the repository path to upload a file to.
  - `path_or_fileobj`: either a path to a file on your filesystem or a file-like object. This is the content of the file to upload to the Hub.
- [`CommitOperationDelete`] removes a file or a folder from a repository. This operation accepts `path_in_repo` as an argument.
- [`CommitOperationCopy`] copies a file within a repository. This operation accepts three arguments:
  - `src_path_in_repo`: the repository path of the file to copy.
  - `path_in_repo`: the repository path where the file should be copied.
  - `src_revision`: optional - the revision of the file to copy if you want to copy a file from a different branch/revision.

For example, if you want to upload two files and delete a file in a Hub repository:

1. Use the appropriate `CommitOperation` to add, delete or copy files and to delete a folder:

```py
>>> from huggingface_hub import HfApi, CommitOperationAdd, CommitOperationCopy, CommitOperationDelete
>>> api = HfApi()
>>> operations = [
...     CommitOperationAdd(path_in_repo="LICENSE.md", path_or_fileobj="~/repo/LICENSE.md"),
...     CommitOperationAdd(path_in_repo="weights.h5", path_or_fileobj="~/repo/weights-final.h5"),
...     CommitOperationDelete(path_in_repo="old-weights.h5"),
...     CommitOperationDelete(path_in_repo="logs/"),
...     CommitOperationCopy(src_path_in_repo="image.png", path_in_repo="duplicate_image.png"),
... ]
```

2. Pass your operations to [`create_commit`]:

```py
>>> api.create_commit(
...     repo_id="lysandre/test-model",
...     operations=operations,
...     commit_message="Upload my model weights and license",
... )
```

In addition to [`upload_file`] and [`upload_folder`], the following functions also use [`create_commit`] under the hood:

- [`delete_file`] deletes a single file from a repository on the Hub.
- [`delete_folder`] deletes an entire folder from a repository on the Hub.
- [`metadata_update`] updates a repository's metadata.

For more detailed information, take a look at the [`HfApi`] reference.

### Preupload LFS files before commit

In some cases, you might want to upload huge files to S3 **before** making the commit call. For example, if you are committing a dataset in several shards that are generated in-memory, you would need to upload the shards one by one to avoid an out-of-memory issue. A solution is to upload each shard as a separate commit on the repo. While being perfectly valid, this solution has the drawback of potentially messing the git history by generating tens of commits. To overcome this issue, you can upload your files one by one to S3 and then create a single commit at the end. This is possible using [`preupload_lfs_files`] in combination with [`create_commit`].

This is a power-user method. Directly using [`upload_file`], [`upload_folder`] or [`create_commit`] instead of handling the low-level logic of pre-uploading files is the way to go in the vast majority of cases. The main caveat of [`preupload_lfs_files`] is that until the commit is actually made, the uploaded files are not accessible on the repo on the Hub. If you have a question, feel free to ping us on our Discord or in a GitHub issue.
Here is a simple example illustrating how to pre-upload files:

```py
>>> from huggingface_hub import CommitOperationAdd, preupload_lfs_files, create_commit, create_repo

>>> repo_id = create_repo("test_preupload").repo_id

>>> operations = [] # List of all `CommitOperationAdd` objects that will be generated
>>> for i in range(5):
...     content = ... # generate binary content
...     addition = CommitOperationAdd(path_in_repo=f"shard_{i}_of_5.bin", path_or_fileobj=content)
...     preupload_lfs_files(repo_id, additions=[addition])
...     operations.append(addition)

>>> # Create commit
>>> create_commit(repo_id, operations=operations, commit_message="Commit all shards")
```

First, we create the [`CommitOperationAdd`] objects one by one. In a real-world example, those would contain the generated shards. Each file is uploaded before generating the next one. During the [`preupload_lfs_files`] step, **the `CommitOperationAdd` object is mutated**. You should only use it to pass it directly to [`create_commit`]. The main update of the object is that **the binary content is removed** from it, meaning that it will be garbage-collected if you don't store another reference to it. This is expected as we don't want to keep in memory the content that is already uploaded. Finally we create the commit by passing all the operations to [`create_commit`]. You can pass additional operations (add, delete or copy) that have not been processed yet and they will be handled correctly.

## (legacy) Upload files with Git LFS

All the methods described above use the Hub's API to upload files. This is the recommended way to upload files to the Hub. However, we also provide [`Repository`], a wrapper around the git tool to manage a local repository.

Although [`Repository`] is not formally deprecated, we recommend using the HTTP-based methods described above instead. For more details about this recommendation, please have a look at [this guide](../concepts/git_vs_http) explaining the core differences between HTTP-based and Git-based approaches.

Git LFS automatically handles files larger than 10MB. But for very large files (>5GB), you need to install a custom transfer agent for Git LFS:

```bash
huggingface-cli lfs-enable-largefiles
```

You should install this for each repository that has a very large file. Once installed, you'll be able to push files larger than 5GB.

### commit context manager

The `commit` context manager handles four of the most common Git commands: pull, add, commit, and push. `git-lfs` automatically tracks any file larger than 10MB. In the following example, the `commit` context manager:

1. Pulls from the `text-files` repository.
2. Adds a change made to `file.txt`.
3. Commits the change.
4. Pushes the change to the `text-files` repository.

```python
>>> import json
>>> from huggingface_hub import Repository

>>> with Repository(local_dir="text-files", clone_from="<user>/text-files").commit(commit_message="My first file :)"):
...     with open("file.txt", "w+") as f:
...         f.write(json.dumps({"hey": 8}))
```

Here is another example of how to use the `commit` context manager to save and upload a file to a repository:

```python
>>> import torch
>>> from huggingface_hub import Repository

>>> model = torch.nn.Transformer()

>>> with Repository("torch-model", clone_from="<user>/torch-model", token=True).commit(commit_message="My cool model :)"):
...     torch.save(model.state_dict(), "model.pt")
```

Set `blocking=False` if you would like to push your commits asynchronously. Non-blocking behavior is helpful when you want to continue running your script while your commits are being pushed.
```python >>> with repo.commit(commit_message="My cool model :)", blocking=False) ``` You can check the status of your push with the `command_queue` method: ```python >>> last_command = repo.command_queue[-1] >>> last_command.status ``` Refer to the table below for the possible statuses: | Status | Description | | -------- | ------------------------------------ | | -1 | The push is ongoing. | | 0 | The push has completed successfully. | | Non-zero | An error has occurred. | When `blocking=False`, commands are tracked, and your script will only exit when all pushes are completed, even if other errors occur in your script. Some additional useful commands for checking the status of a push include: ```python # Inspect an error. >>> last_command.stderr # Check whether a push is completed or ongoing. >>> last_command.is_done # Check whether a push command has errored. >>> last_command.failed ``` ### push_to_hub The [`Repository`] class has a [`~Repository.push_to_hub`] function to add files, make a commit, and push them to a repository. Unlike the `commit` context manager, you'll need to pull from a repository first before calling [`~Repository.push_to_hub`]. For example, if you've already cloned a repository from the Hub, then you can initialize the `repo` from the local directory: ```python >>> from huggingface_hub import Repository >>> repo = Repository(local_dir="path/to/local/repo") ``` Update your local clone with [`~Repository.git_pull`] and then push your file to the Hub: ```py >>> repo.git_pull() >>> repo.push_to_hub(commit_message="Commit my-awesome-file to the Hub") ``` However, if you aren't ready to push a file yet, you can use [`~Repository.git_add`] and [`~Repository.git_commit`] to only add and commit your file: ```py >>> repo.git_add("path/to/file") >>> repo.git_commit(commit_message="add my first model config file :)") ``` When you're ready, push the file to your repository with [`~Repository.git_push`]: ```py >>> repo.git_push() ``` huggingface_hub-0.31.1/docs/source/en/guides/webhooks.md000066400000000000000000000277071500667546600232120ustar00rootroot00000000000000 # Webhooks Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on specific repos or to all repos belonging to particular users/organizations you're interested in following. This guide will first explain how to manage webhooks programmatically. Then we'll see how to leverage `huggingface_hub` to create a server listening to webhooks and deploy it to a Space. This guide assumes you are familiar with the concept of webhooks on the Huggingface Hub. To learn more about webhooks themselves, you should read this [guide](https://huggingface.co/docs/hub/webhooks) first. ## Managing Webhooks `huggingface_hub` allows you to manage your webhooks programmatically. You can list your existing webhooks, create new ones, and update, enable, disable or delete them. This section guides you through the procedures using the Hugging Face Hub's API functions. ### Creating a Webhook To create a new webhook, use [`create_webhook`] and specify the URL where payloads should be sent, what events should be watched, and optionally set a domain and a secret for security. 
```python from huggingface_hub import create_webhook # Example: Creating a webhook webhook = create_webhook( url="https://webhook.site/your-custom-url", watched=[{"type": "user", "name": "your-username"}, {"type": "org", "name": "your-org-name"}], domains=["repo", "discussion"], secret="your-secret" ) ``` ### Listing Webhooks To see all the webhooks you have configured, you can list them with [`list_webhooks`]. This is useful to review their IDs, URLs, and statuses. ```python from huggingface_hub import list_webhooks # Example: Listing all webhooks webhooks = list_webhooks() for webhook in webhooks: print(webhook) ``` ### Updating a Webhook If you need to change the configuration of an existing webhook, such as the URL or the events it watches, you can update it using [`update_webhook`]. ```python from huggingface_hub import update_webhook # Example: Updating a webhook updated_webhook = update_webhook( webhook_id="your-webhook-id", url="https://new.webhook.site/url", watched=[{"type": "user", "name": "new-username"}], domains=["repo"] ) ``` ### Enabling and Disabling Webhooks You might want to temporarily disable a webhook without deleting it. This can be done using [`disable_webhook`], and the webhook can be re-enabled later with [`enable_webhook`]. ```python from huggingface_hub import enable_webhook, disable_webhook # Example: Enabling a webhook enabled_webhook = enable_webhook("your-webhook-id") print("Enabled:", enabled_webhook) # Example: Disabling a webhook disabled_webhook = disable_webhook("your-webhook-id") print("Disabled:", disabled_webhook) ``` ### Deleting a Webhook When a webhook is no longer needed, it can be permanently deleted using [`delete_webhook`]. ```python from huggingface_hub import delete_webhook # Example: Deleting a webhook delete_webhook("your-webhook-id") ``` ## Webhooks Server The base class that we will use in this guides section is [`WebhooksServer`]. It is a class for easily configuring a server that can receive webhooks from the Huggingface Hub. The server is based on a [Gradio](https://gradio.app/) app. It has a UI to display instructions for you or your users and an API to listen to webhooks. To see a running example of a webhook server, check out the [Spaces CI Bot](https://huggingface.co/spaces/spaces-ci-bot/webhook) one. It is a Space that launches ephemeral environments when a PR is opened on a Space. This is an [experimental feature](../package_reference/environment_variables#hfhubdisableexperimentalwarning). This means that we are still working on improving the API. Breaking changes might be introduced in the future without prior notice. Make sure to pin the version of `huggingface_hub` in your requirements. ### Create an endpoint Implementing a webhook endpoint is as simple as decorating a function. Let's see a first example to explain the main concepts: ```python # app.py from huggingface_hub import webhook_endpoint, WebhookPayload @webhook_endpoint async def trigger_training(payload: WebhookPayload) -> None: if payload.repo.type == "dataset" and payload.event.action == "update": # Trigger a training job if a dataset is updated ... ``` Save this snippet in a file called `'app.py'` and run it with `'python app.py'`. You should see a message like this: ```text Webhook secret is not defined. This means your webhook endpoints will be open to everyone. 
To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: `app = WebhooksServer(webhook_secret='my_secret', ...)` For more details about webhook secrets, please refer to https://huggingface.co/docs/hub/webhooks#webhook-secret. Running on local URL: http://127.0.0.1:7860 Running on public URL: https://1fadb0f52d8bf825fc.gradio.live This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces Webhooks are correctly setup and ready to use: - POST https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training Go to https://huggingface.co/settings/webhooks to setup your webhooks. ``` Good job! You just launched a webhook server! Let's break down what happened exactly: 1. By decorating a function with [`webhook_endpoint`], a [`WebhooksServer`] object has been created in the background. As you can see, this server is a Gradio app running on http://127.0.0.1:7860. If you open this URL in your browser, you will see a landing page with instructions about the registered webhooks. 2. A Gradio app is a FastAPI server under the hood. A new POST route `/webhooks/trigger_training` has been added to it. This is the route that will listen to webhooks and run the `trigger_training` function when triggered. FastAPI will automatically parse the payload and pass it to the function as a [`WebhookPayload`] object. This is a `pydantic` object that contains all the information about the event that triggered the webhook. 3. The Gradio app also opened a tunnel to receive requests from the internet. This is the interesting part: you can configure a Webhook on https://huggingface.co/settings/webhooks pointing to your local machine. This is useful for debugging your webhook server and quickly iterating before deploying it to a Space. 4. Finally, the logs also tell you that your server is currently not secured by a secret. This is not problematic for local debugging but is to keep in mind for later. By default, the server is started at the end of your script. If you are running it in a notebook, you can start the server manually by calling `decorated_function.run()`. Since a unique server is used, you only have to start the server once even if you have multiple endpoints. ### Configure a Webhook Now that you have a webhook server running, you want to configure a Webhook to start receiving messages. Go to https://huggingface.co/settings/webhooks, click on "Add a new webhook" and configure your Webhook. Set the target repositories you want to watch and the Webhook URL, here `https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training`.
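If you prefer to configure the webhook programmatically instead of using the settings UI, you can use the [`create_webhook`] helper described earlier in this guide. Here is a minimal sketch (the URL below is the one printed by the example server above, and the username is a placeholder):

```python
from huggingface_hub import create_webhook

webhook = create_webhook(
    url="https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training",
    watched=[{"type": "user", "name": "your-username"}],
    domains=["repo"],
)
```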
And that's it! You can now trigger that webhook by updating the target repository (e.g. push a commit). Check the Activity tab of your Webhook to see the events that have been triggered. Now that you have a working setup, you can test it and quickly iterate. If you modify your code and restart the server, your public URL might change. Make sure to update the webhook configuration on the Hub if needed.

### Deploy to a Space

Now that you have a working webhook server, the goal is to deploy it to a Space. Go to https://huggingface.co/new-space to create a Space. Give it a name, select the Gradio SDK and click on "Create Space". Upload your code to the Space in a file called `app.py`. Your Space will start automatically! For more details about Spaces, please refer to this [guide](https://huggingface.co/docs/hub/spaces-overview).

Your webhook server is now running on a public Space. In most cases, you will want to secure it with a secret. Go to your Space settings > Section "Repository secrets" > "Add a secret". Set the `WEBHOOK_SECRET` environment variable to the value of your choice. Go back to the [Webhooks settings](https://huggingface.co/settings/webhooks) and set the secret in the webhook configuration. Now, only requests with the correct secret will be accepted by your server.

And this is it! Your Space is now ready to receive webhooks from the Hub. Please keep in mind that if you run the Space on a free 'cpu-basic' hardware, it will be shut down after 48 hours of inactivity. If you need a permanent Space, you should consider [upgrading your hardware](https://huggingface.co/docs/hub/spaces-gpus#hardware-specs).

### Advanced usage

The guide above explained the quickest way to set up a [`WebhooksServer`]. In this section, we will see how to customize it further.

#### Multiple endpoints

You can register multiple endpoints on the same server. For example, you might want to have one endpoint to trigger a training job and another one to trigger a model evaluation. You can do this by adding multiple `@webhook_endpoint` decorators:

```python
# app.py
from huggingface_hub import webhook_endpoint, WebhookPayload

@webhook_endpoint
async def trigger_training(payload: WebhookPayload) -> None:
    if payload.repo.type == "dataset" and payload.event.action == "update":
        # Trigger a training job if a dataset is updated
        ...

@webhook_endpoint
async def trigger_evaluation(payload: WebhookPayload) -> None:
    if payload.repo.type == "model" and payload.event.action == "update":
        # Trigger an evaluation job if a model is updated
        ...
```

Which will create two endpoints:

```text
(...)
Webhooks are correctly setup and ready to use:
  - POST https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training
  - POST https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_evaluation
```

#### Custom server

To get more flexibility, you can also create a [`WebhooksServer`] object directly. This is useful if you want to customize the landing page of your server. You can do this by passing a [Gradio UI](https://gradio.app/docs/#blocks) that will overwrite the default one. For example, you can add instructions for your users or add a form to manually trigger the webhooks. When creating a [`WebhooksServer`], you can register new webhooks using the [`~WebhooksServer.add_webhook`] decorator.

Here is a complete example:

```python
import gradio as gr
from fastapi import Request
from huggingface_hub import WebhooksServer, WebhookPayload

# 1. Define UI
with gr.Blocks() as ui:
    ...

# 2.
Create WebhooksServer with custom UI and secret app = WebhooksServer(ui=ui, webhook_secret="my_secret_key") # 3. Register webhook with explicit name @app.add_webhook("/say_hello") async def hello(payload: WebhookPayload): return {"message": "hello"} # 4. Register webhook with implicit name @app.add_webhook async def goodbye(payload: WebhookPayload): return {"message": "goodbye"} # 5. Start server (optional) app.run() ``` 1. We define a custom UI using Gradio blocks. This UI will be displayed on the landing page of the server. 2. We create a [`WebhooksServer`] object with a custom UI and a secret. The secret is optional and can be set with the `WEBHOOK_SECRET` environment variable. 3. We register a webhook with an explicit name. This will create an endpoint at `/webhooks/say_hello`. 4. We register a webhook with an implicit name. This will create an endpoint at `/webhooks/goodbye`. 5. We start the server. This is optional as your server will automatically be started at the end of the script. huggingface_hub-0.31.1/docs/source/en/index.md000066400000000000000000000073231500667546600212100ustar00rootroot00000000000000 # 🤗 Hub client library The `huggingface_hub` library allows you to interact with the [Hugging Face Hub](https://hf.co), a machine learning platform for creators and collaborators. Discover pre-trained models and datasets for your projects or play with the hundreds of machine learning apps hosted on the Hub. You can also create and share your own models and datasets with the community. The `huggingface_hub` library provides a simple way to do all these things with Python. Read the [quick start guide](quick-start) to get up and running with the `huggingface_hub` library. You will learn how to download files from the Hub, create a repository, and upload files to the Hub. Keep reading to learn more about how to manage your repositories on the 🤗 Hub, how to interact in discussions or even how to access the Inference API. ## Contribute All contributions to the `huggingface_hub` are welcomed and equally valued! 🤗 Besides adding or fixing existing issues in the code, you can also help improve the documentation by making sure it is accurate and up-to-date, help answer questions on issues, and request new features you think will improve the library. Take a look at the [contribution guide](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md) to learn more about how to submit a new issue or feature request, how to submit a pull request, and how to test your contributions to make sure everything works as expected. Contributors should also be respectful of our [code of conduct](https://github.com/huggingface/huggingface_hub/blob/main/CODE_OF_CONDUCT.md) to create an inclusive and welcoming collaborative space for everyone. huggingface_hub-0.31.1/docs/source/en/installation.md000066400000000000000000000151721500667546600226030ustar00rootroot00000000000000 # Installation Before you start, you will need to setup your environment by installing the appropriate packages. `huggingface_hub` is tested on **Python 3.8+**. ## Install with pip It is highly recommended to install `huggingface_hub` in a [virtual environment](https://docs.python.org/3/library/venv.html). If you are unfamiliar with Python virtual environments, take a look at this [guide](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/). A virtual environment makes it easier to manage different projects, and avoid compatibility issues between dependencies. 
Start by creating a virtual environment in your project directory: ```bash python -m venv .env ``` Activate the virtual environment. On Linux and macOS: ```bash source .env/bin/activate ``` Activate virtual environment on Windows: ```bash .env/Scripts/activate ``` Now you're ready to install `huggingface_hub` [from the PyPi registry](https://pypi.org/project/huggingface-hub/): ```bash pip install --upgrade huggingface_hub ``` Once done, [check installation](#check-installation) is working correctly. ### Install optional dependencies Some dependencies of `huggingface_hub` are [optional](https://setuptools.pypa.io/en/latest/userguide/dependency_management.html#optional-dependencies) because they are not required to run the core features of `huggingface_hub`. However, some features of the `huggingface_hub` may not be available if the optional dependencies aren't installed. You can install optional dependencies via `pip`: ```bash # Install dependencies for tensorflow-specific features # /!\ Warning: this is not equivalent to `pip install tensorflow` pip install 'huggingface_hub[tensorflow]' # Install dependencies for both torch-specific and CLI-specific features. pip install 'huggingface_hub[cli,torch]' ``` Here is the list of optional dependencies in `huggingface_hub`: - `cli`: provide a more convenient CLI interface for `huggingface_hub`. - `fastai`, `torch`, `tensorflow`: dependencies to run framework-specific features. - `dev`: dependencies to contribute to the lib. Includes `testing` (to run tests), `typing` (to run type checker) and `quality` (to run linters). ### Install from source In some cases, it is interesting to install `huggingface_hub` directly from source. This allows you to use the bleeding edge `main` version rather than the latest stable version. The `main` version is useful for staying up-to-date with the latest developments, for instance if a bug has been fixed since the last official release but a new release hasn't been rolled out yet. However, this means the `main` version may not always be stable. We strive to keep the `main` version operational, and most issues are usually resolved within a few hours or a day. If you run into a problem, please open an Issue so we can fix it even sooner! ```bash pip install git+https://github.com/huggingface/huggingface_hub ``` When installing from source, you can also specify a specific branch. This is useful if you want to test a new feature or a new bug-fix that has not been merged yet: ```bash pip install git+https://github.com/huggingface/huggingface_hub@my-feature-branch ``` Once done, [check installation](#check-installation) is working correctly. ### Editable install Installing from source allows you to setup an [editable install](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs). This is a more advanced installation if you plan to contribute to `huggingface_hub` and need to test changes in the code. You need to clone a local copy of `huggingface_hub` on your machine. ```bash # First, clone repo locally git clone https://github.com/huggingface/huggingface_hub.git # Then, install with -e flag cd huggingface_hub pip install -e . ``` These commands will link the folder you cloned the repository to and your Python library paths. Python will now look inside the folder you cloned to in addition to the normal library paths. For example, if your Python packages are typically installed in `./.venv/lib/python3.13/site-packages/`, Python will also search the folder you cloned `./huggingface_hub/`. 
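To double-check that the editable install is picked up, you can print the location Python resolves the package from; it should point inside the folder you just cloned:

```bash
python -c "import huggingface_hub; print(huggingface_hub.__file__)"
```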
## Install with conda If you are more familiar with it, you can install `huggingface_hub` using the [conda-forge channel](https://anaconda.org/conda-forge/huggingface_hub): ```bash conda install -c conda-forge huggingface_hub ``` Once done, [check installation](#check-installation) is working correctly. ## Check installation Once installed, check that `huggingface_hub` works properly by running the following command: ```bash python -c "from huggingface_hub import model_info; print(model_info('gpt2'))" ``` This command will fetch information from the Hub about the [gpt2](https://huggingface.co/gpt2) model. Output should look like this: ```text Model Name: gpt2 Tags: ['pytorch', 'tf', 'jax', 'tflite', 'rust', 'safetensors', 'gpt2', 'text-generation', 'en', 'doi:10.57967/hf/0039', 'transformers', 'exbert', 'license:mit', 'has_space'] Task: text-generation ``` ## Windows limitations With our goal of democratizing good ML everywhere, we built `huggingface_hub` to be a cross-platform library and in particular to work correctly on both Unix-based and Windows systems. However, there are a few cases where `huggingface_hub` has some limitations when run on Windows. Here is an exhaustive list of known issues. Please let us know if you encounter any undocumented problem by opening [an issue on Github](https://github.com/huggingface/huggingface_hub/issues/new/choose). - `huggingface_hub`'s cache system relies on symlinks to efficiently cache files downloaded from the Hub. On Windows, you must activate developer mode or run your script as admin to enable symlinks. If they are not activated, the cache-system still works but in a non-optimized manner. Please read [the cache limitations](./guides/manage-cache#limitations) section for more details. - Filepaths on the Hub can have special characters (e.g. `"path/to?/my/file"`). Windows is more restrictive on [special characters](https://learn.microsoft.com/en-us/windows/win32/intl/character-sets-used-in-file-names) which makes it impossible to download those files on Windows. Hopefully this is a rare case. Please reach out to the repo owner if you think this is a mistake or to us to figure out a solution. ## Next steps Once `huggingface_hub` is properly installed on your machine, you might want [configure environment variables](package_reference/environment_variables) or [check one of our guides](guides/overview) to get started. huggingface_hub-0.31.1/docs/source/en/package_reference/000077500000000000000000000000001500667546600231635ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/en/package_reference/authentication.md000066400000000000000000000013211500667546600265210ustar00rootroot00000000000000 # Authentication The `huggingface_hub` library allows users to programmatically manage authentication to the Hub. This includes logging in, logging out, switching between tokens, and listing available tokens. For more details about authentication, check out [this section](../quick-start#authentication). ## login [[autodoc]] login ## interpreter_login [[autodoc]] interpreter_login ## notebook_login [[autodoc]] notebook_login ## logout [[autodoc]] logout ## auth_switch [[autodoc]] auth_switch ## auth_list [[autodoc]] auth_list huggingface_hub-0.31.1/docs/source/en/package_reference/cache.md000066400000000000000000000024341500667546600245530ustar00rootroot00000000000000 # Cache-system reference The caching system was updated in v0.8.0 to become the central cache-system shared across libraries that depend on the Hub. 
Read the [cache-system guide](../guides/manage-cache) for a detailed presentation of caching at HF.

## Helpers

### try_to_load_from_cache

[[autodoc]] huggingface_hub.try_to_load_from_cache

### cached_assets_path

[[autodoc]] huggingface_hub.cached_assets_path

### scan_cache_dir

[[autodoc]] huggingface_hub.scan_cache_dir

## Data structures

All structures are built and returned by [`scan_cache_dir`] and are immutable.

### HFCacheInfo

[[autodoc]] huggingface_hub.HFCacheInfo

### CachedRepoInfo

[[autodoc]] huggingface_hub.CachedRepoInfo
    - size_on_disk_str
    - refs

### CachedRevisionInfo

[[autodoc]] huggingface_hub.CachedRevisionInfo
    - size_on_disk_str
    - nb_files

### CachedFileInfo

[[autodoc]] huggingface_hub.CachedFileInfo
    - size_on_disk_str

### DeleteCacheStrategy

[[autodoc]] huggingface_hub.DeleteCacheStrategy
    - expected_freed_size_str

## Exceptions

### CorruptedCacheException

[[autodoc]] huggingface_hub.CorruptedCacheException

huggingface_hub-0.31.1/docs/source/en/package_reference/cards.md000066400000000000000000000033331500667546600246030ustar00rootroot00000000000000

# Repository Cards

The huggingface_hub library provides a Python interface to create, share, and update Model/Dataset Cards. Visit the [dedicated documentation page](https://huggingface.co/docs/hub/models-cards) for a deeper view of what Model Cards on the Hub are, and how they work under the hood. You can also check out our [Model Cards guide](../how-to-model-cards) to get a feel for how you would use these utilities in your own projects.

## Repo Card

The `RepoCard` object is the parent class of [`ModelCard`], [`DatasetCard`] and `SpaceCard`.

[[autodoc]] huggingface_hub.repocard.RepoCard
    - __init__
    - all

## Card Data

The [`CardData`] object is the parent class of [`ModelCardData`] and [`DatasetCardData`].

[[autodoc]] huggingface_hub.repocard_data.CardData

## Model Cards

### ModelCard

[[autodoc]] ModelCard

### ModelCardData

[[autodoc]] ModelCardData

## Dataset Cards

Dataset cards are also known as Data Cards in the ML Community.

### DatasetCard

[[autodoc]] DatasetCard

### DatasetCardData

[[autodoc]] DatasetCardData

## Space Cards

### SpaceCard

[[autodoc]] SpaceCard

### SpaceCardData

[[autodoc]] SpaceCardData

## Utilities

### EvalResult

[[autodoc]] EvalResult

### model_index_to_eval_results

[[autodoc]] huggingface_hub.repocard_data.model_index_to_eval_results

### eval_results_to_model_index

[[autodoc]] huggingface_hub.repocard_data.eval_results_to_model_index

### metadata_eval_result

[[autodoc]] huggingface_hub.repocard.metadata_eval_result

### metadata_update

[[autodoc]] huggingface_hub.repocard.metadata_update

huggingface_hub-0.31.1/docs/source/en/package_reference/collections.md000066400000000000000000000013751500667546600260310ustar00rootroot00000000000000

# Managing collections

Check out the [`HfApi`] documentation page for the reference of methods to manage collections on the Hub.
- Get collection content: [`get_collection`] - Create new collection: [`create_collection`] - Update a collection: [`update_collection_metadata`] - Delete a collection: [`delete_collection`] - Add an item to a collection: [`add_collection_item`] - Update an item in a collection: [`update_collection_item`] - Remove an item from a collection: [`delete_collection_item`] ### Collection [[autodoc]] Collection ### CollectionItem [[autodoc]] CollectionItem huggingface_hub-0.31.1/docs/source/en/package_reference/community.md000066400000000000000000000015141500667546600255320ustar00rootroot00000000000000 # Interacting with Discussions and Pull Requests Check the [`HfApi`] documentation page for the reference of methods enabling interaction with Pull Requests and Discussions on the Hub. - [`get_repo_discussions`] - [`get_discussion_details`] - [`create_discussion`] - [`create_pull_request`] - [`rename_discussion`] - [`comment_discussion`] - [`edit_discussion_comment`] - [`change_discussion_status`] - [`merge_pull_request`] ## Data structures [[autodoc]] Discussion [[autodoc]] DiscussionWithDetails [[autodoc]] DiscussionEvent [[autodoc]] DiscussionComment [[autodoc]] DiscussionStatusChange [[autodoc]] DiscussionCommit [[autodoc]] DiscussionTitleChange huggingface_hub-0.31.1/docs/source/en/package_reference/environment_variables.md000066400000000000000000000314071500667546600301060ustar00rootroot00000000000000 # Environment variables `huggingface_hub` can be configured using environment variables. If you are unfamiliar with environment variable, here are generic articles about them [on macOS and Linux](https://linuxize.com/post/how-to-set-and-list-environment-variables-in-linux/) and on [Windows](https://phoenixnap.com/kb/windows-set-environment-variable). This page will guide you through all environment variables specific to `huggingface_hub` and their meaning. ## Generic ### HF_INFERENCE_ENDPOINT To configure the inference api base url. You might want to set this variable if your organization is pointing at an API Gateway rather than directly at the inference api. Defaults to `"https://api-inference.huggingface.co"`. ### HF_HOME To configure where `huggingface_hub` will locally store data. In particular, your token and the cache will be stored in this folder. Defaults to `"~/.cache/huggingface"` unless [XDG_CACHE_HOME](#xdgcachehome) is set. ### HF_HUB_CACHE To configure where repositories from the Hub will be cached locally (models, datasets and spaces). Defaults to `"$HF_HOME/hub"` (e.g. `"~/.cache/huggingface/hub"` by default). ### HF_XET_CACHE To configure where Xet chunks (byte ranges from files managed by Xet backend) are cached locally. Defaults to `"$HF_HOME/xet"` (e.g. `"~/.cache/huggingface/xet"` by default). ### HF_ASSETS_CACHE To configure where [assets](../guides/manage-cache#caching-assets) created by downstream libraries will be cached locally. Those assets can be preprocessed data, files downloaded from GitHub, logs,... Defaults to `"$HF_HOME/assets"` (e.g. `"~/.cache/huggingface/assets"` by default). ### HF_TOKEN To configure the User Access Token to authenticate to the Hub. If set, this value will overwrite the token stored on the machine (in either `$HF_TOKEN_PATH` or `"$HF_HOME/token"` if the former is not set). For more details about authentication, check out [this section](../quick-start#authentication). ### HF_TOKEN_PATH To configure where `huggingface_hub` should store the User Access Token. Defaults to `"$HF_HOME/token"` (e.g. `~/.cache/huggingface/token` by default). 
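For example, both variables can be set like any regular environment variable before running your script (the paths below are only illustrative):

```bash
export HF_HOME=/data/huggingface        # store all Hugging Face data on another disk
export HF_TOKEN_PATH=/secrets/hf_token  # keep the token in a dedicated file
```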
### HF_HUB_VERBOSITY Set the verbosity level of the `huggingface_hub`'s logger. Must be one of `{"debug", "info", "warning", "error", "critical"}`. Defaults to `"warning"`. For more details, see [logging reference](../package_reference/utilities#huggingface_hub.utils.logging.get_verbosity). ### HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD This environment variable has been deprecated and is now ignored by `huggingface_hub`. Downloading files to the local dir does not rely on symlinks anymore. ### HF_HUB_ETAG_TIMEOUT Integer value to define the number of seconds to wait for server response when fetching the latest metadata from a repo before downloading a file. If the request times out, `huggingface_hub` will default to the locally cached files. Setting a lower value speeds up the workflow for machines with a slow connection that have already cached files. A higher value guarantees the metadata call to succeed in more cases. Default to 10s. ### HF_HUB_DOWNLOAD_TIMEOUT Integer value to define the number of seconds to wait for server response when downloading a file. If the request times out, a TimeoutError is raised. Setting a higher value is beneficial on machine with a slow connection. A smaller value makes the process fail quicker in case of complete network outage. Default to 10s. ## Xet ### Other Xet environment variables * [`HF_HUB_DISABLE_XET`](../package_reference/environment_variables#hfhubdisablexet) * [`HF_XET_CACHE`](../package_reference/environment_variables#hfxetcache) * [`HF_XET_HIGH_PERFORMANCE`](../package_reference/environment_variables#hfxethighperformance) * [`HF_XET_RECONSTRUCT_WRITE_SEQUENTIALLY`](../package_reference/environment_variables#hfxetreconstructwritesequentially) ### HF_XET_CHUNK_CACHE_SIZE_BYTES To set the size of the Xet cache locally. Increasing this will give more space for caching terms/chunks fetched from S3. A larger cache can better take advantage of deduplication across repos & files. If your network speed is much greater than your local disk speed (ex 10Gbps vs SSD or worse) then consider disabling the Xet cache for increased performance. To disable the Xet cache, set `HF_XET_CHUNK_CACHE_SIZE_BYTES=0`. Defaults to `10737418240` (10GiB). ### HF_XET_NUM_CONCURRENT_RANGE_GETS To set the number of concurrent terms (range of bytes from within a xorb, often called a chunk) downloaded from S3 per file. Increasing this will help with the speed of downloading a file if there is network bandwidth available. Defaults to `16`. ## Boolean values The following environment variables expect a boolean value. The variable will be considered as `True` if its value is one of `{"1", "ON", "YES", "TRUE"}` (case-insensitive). Any other value (or undefined) will be considered as `False`. ### HF_DEBUG If set, the log level for the `huggingface_hub` logger is set to DEBUG. Additionally, all requests made by HF libraries will be logged as equivalent cURL commands for easier debugging and reproducibility. ### HF_HUB_OFFLINE If set, no HTTP calls will be made to the Hugging Face Hub. If you try to download files, only the cached files will be accessed. If no cache file is detected, an error is raised This is useful in case your network is slow and you don't care about having the latest version of a file. If `HF_HUB_OFFLINE=1` is set as environment variable and you call any method of [`HfApi`], an [`~huggingface_hub.utils.OfflineModeIsEnabled`] exception will be raised. 
**Note:** even if the latest version of a file is cached, calling `hf_hub_download` still triggers a HTTP request to check that a new version is not available. Setting `HF_HUB_OFFLINE=1` will skip this call which speeds up your loading time. ### HF_HUB_DISABLE_IMPLICIT_TOKEN Authentication is not mandatory for every requests to the Hub. For instance, requesting details about `"gpt2"` model does not require to be authenticated. However, if a user is [logged in](../package_reference/login), the default behavior will be to always send the token in order to ease user experience (never get a HTTP 401 Unauthorized) when accessing private or gated repositories. For privacy, you can disable this behavior by setting `HF_HUB_DISABLE_IMPLICIT_TOKEN=1`. In this case, the token will be sent only for "write-access" calls (example: create a commit). **Note:** disabling implicit sending of token can have weird side effects. For example, if you want to list all models on the Hub, your private models will not be listed. You would need to explicitly pass `token=True` argument in your script. ### HF_HUB_DISABLE_PROGRESS_BARS For time consuming tasks, `huggingface_hub` displays a progress bar by default (using tqdm). You can disable all the progress bars at once by setting `HF_HUB_DISABLE_PROGRESS_BARS=1`. ### HF_HUB_DISABLE_SYMLINKS_WARNING If you are on a Windows machine, it is recommended to enable the developer mode or to run `huggingface_hub` in admin mode. If not, `huggingface_hub` will not be able to create symlinks in your cache system. You will be able to execute any script but your user experience will be degraded as some huge files might end-up duplicated on your hard-drive. A warning message is triggered to warn you about this behavior. Set `HF_HUB_DISABLE_SYMLINKS_WARNING=1`, to disable this warning. For more details, see [cache limitations](../guides/manage-cache#limitations). ### HF_HUB_DISABLE_EXPERIMENTAL_WARNING Some features of `huggingface_hub` are experimental. This means you can use them but we do not guarantee they will be maintained in the future. In particular, we might update the API or behavior of such features without any deprecation cycle. A warning message is triggered when using an experimental feature to warn you about it. If you're comfortable debugging any potential issues using an experimental feature, you can set `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` to disable the warning. If you are using an experimental feature, please let us know! Your feedback can help us design and improve it. ### HF_HUB_DISABLE_TELEMETRY By default, some data is collected by HF libraries (`transformers`, `datasets`, `gradio`,..) to monitor usage, debug issues and help prioritize features. Each library defines its own policy (i.e. which usage to monitor) but the core implementation happens in `huggingface_hub` (see [`send_telemetry`]). You can set `HF_HUB_DISABLE_TELEMETRY=1` as environment variable to globally disable telemetry. ### HF_HUB_DISABLE_XET Set to disable using `hf-xet`, even if it is available in your Python environment. This is since `hf-xet` will be used automatically if it is found, this allows explicitly disabling its usage. ### HF_HUB_ENABLE_HF_TRANSFER Set to `True` for faster uploads and downloads from the Hub using `hf_transfer`. By default, `huggingface_hub` uses the Python-based `requests.get` and `requests.post` functions. Although these are reliable and versatile, they may not be the most efficient choice for machines with high bandwidth. 
[`hf_transfer`](https://github.com/huggingface/hf_transfer) is a Rust-based package developed to maximize the bandwidth used by dividing large files into smaller parts and transferring them simultaneously using multiple threads. This approach can potentially double the transfer speed. To use `hf_transfer`: 1. Specify the `hf_transfer` extra when installing `huggingface_hub` (e.g. `pip install huggingface_hub[hf_transfer]`). 2. Set `HF_HUB_ENABLE_HF_TRANSFER=1` as an environment variable. Please note that using `hf_transfer` comes with certain limitations. Since it is not purely Python-based, debugging errors may be challenging. Additionally, `hf_transfer` lacks several user-friendly features such as resumable downloads and proxies. These omissions are intentional to maintain the simplicity and speed of the Rust logic. Consequently, `hf_transfer` is not enabled by default in `huggingface_hub`. `hf_xet` is an alternative to `hf_transfer`. It provides efficient file transfers through a chunk-based deduplication strategy, custom Xet storage (replacing Git LFS), and a seamless integration with `huggingface_hub`. [Read more about the package](https://huggingface.co/docs/hub/storage-backends) and enable with `pip install "huggingface_hub[hf_xet]"`. ### HF_XET_HIGH_PERFORMANCE Set `hf-xet` to operate with increased settings to maximize network and disk resources on the machine. Enabling high performance mode will try to saturate the network bandwidth of this machine and utilize all CPU cores for parallel upload/download activity. Consider this analogous to setting `HF_HUB_ENABLE_HF_TRANSFER=True` when uploading / downloading using `hf-xet` to the Xet storage backend. ### HF_XET_RECONSTRUCT_WRITE_SEQUENTIALLY To have `hf-xet` write sequentially to local disk, instead of in parallel. `hf-xet` is designed for SSD/NVMe disks (using parallel writes with direct addressing). If you are using an HDD (spinning hard disk), setting this will change disk writes to be sequential instead of parallel. For slower hard disks, this can improve overall write performance, as the disk is not spinning to seek for parallel writes. ## Deprecated environment variables In order to standardize all environment variables within the Hugging Face ecosystem, some variables have been marked as deprecated. Although they remain functional, they no longer take precedence over their replacements. The following table outlines the deprecated variables and their corresponding alternatives: | Deprecated Variable | Replacement | | --------------------------- | ------------------ | | `HUGGINGFACE_HUB_CACHE` | `HF_HUB_CACHE` | | `HUGGINGFACE_ASSETS_CACHE` | `HF_ASSETS_CACHE` | | `HUGGING_FACE_HUB_TOKEN` | `HF_TOKEN` | | `HUGGINGFACE_HUB_VERBOSITY` | `HF_HUB_VERBOSITY` | ## From external tools Some environment variables are not specific to `huggingface_hub` but are still taken into account when they are set. ### DO_NOT_TRACK Boolean value. Equivalent to `HF_HUB_DISABLE_TELEMETRY`. When set to true, telemetry is globally disabled in the Hugging Face Python ecosystem (`transformers`, `diffusers`, `gradio`, etc.). See https://consoledonottrack.com/ for more details. ### NO_COLOR Boolean value. When set, `huggingface-cli` tool will not print any ANSI color. See [no-color.org](https://no-color.org/). ### XDG_CACHE_HOME Used only when `HF_HOME` is not set! This is the default way to configure where [user-specific non-essential (cached) data should be written](https://wiki.archlinux.org/title/XDG_Base_Directory) on linux machines. 
If `HF_HOME` is not set, the default home will be `"$XDG_CACHE_HOME/huggingface"` instead of `"~/.cache/huggingface"`. huggingface_hub-0.31.1/docs/source/en/package_reference/file_download.md000066400000000000000000000016621500667546600263200ustar00rootroot00000000000000 # Downloading files ## Download a single file ### hf_hub_download [[autodoc]] huggingface_hub.hf_hub_download ### hf_hub_url [[autodoc]] huggingface_hub.hf_hub_url ## Download a snapshot of the repo [[autodoc]] huggingface_hub.snapshot_download ## Get metadata about a file ### get_hf_file_metadata [[autodoc]] huggingface_hub.get_hf_file_metadata ### HfFileMetadata [[autodoc]] huggingface_hub.HfFileMetadata ## Caching The methods displayed above are designed to work with a caching system that prevents re-downloading files. The caching system was updated in v0.8.0 to become the central cache-system shared across libraries that depend on the Hub. Read the [cache-system guide](../guides/manage-cache) for a detailed presentation of caching at HF. huggingface_hub-0.31.1/docs/source/en/package_reference/hf_api.md000066400000000000000000000051101500667546600247300ustar00rootroot00000000000000 # HfApi Client Below is the documentation for the `HfApi` class, which serves as a Python wrapper for the Hugging Face Hub's API. All methods from the `HfApi` are also accessible from the package's root directly. Both approaches are detailed below. Using the root method is more straightforward but the [`HfApi`] class gives you more flexibility. In particular, you can pass a token that will be reused in all HTTP calls. This is different from `huggingface-cli login` or [`login`] as the token is not persisted on the machine. It is also possible to provide a different endpoint or configure a custom user-agent. ```python from huggingface_hub import HfApi, list_models # Use root method models = list_models() # Or configure a HfApi client hf_api = HfApi( endpoint="https://huggingface.co", # Can be a Private Hub endpoint. token="hf_xxx", # Token is not persisted on the machine.
) models = hf_api.list_models() ``` ## HfApi [[autodoc]] HfApi ## API Dataclasses ### AccessRequest [[autodoc]] huggingface_hub.hf_api.AccessRequest ### CommitInfo [[autodoc]] huggingface_hub.hf_api.CommitInfo ### DatasetInfo [[autodoc]] huggingface_hub.hf_api.DatasetInfo ### GitRefInfo [[autodoc]] huggingface_hub.hf_api.GitRefInfo ### GitCommitInfo [[autodoc]] huggingface_hub.hf_api.GitCommitInfo ### GitRefs [[autodoc]] huggingface_hub.hf_api.GitRefs ### LFSFileInfo [[autodoc]] huggingface_hub.hf_api.LFSFileInfo ### ModelInfo [[autodoc]] huggingface_hub.hf_api.ModelInfo ### RepoSibling [[autodoc]] huggingface_hub.hf_api.RepoSibling ### RepoFile [[autodoc]] huggingface_hub.hf_api.RepoFile ### RepoUrl [[autodoc]] huggingface_hub.hf_api.RepoUrl ### SafetensorsRepoMetadata [[autodoc]] huggingface_hub.utils.SafetensorsRepoMetadata ### SafetensorsFileMetadata [[autodoc]] huggingface_hub.utils.SafetensorsFileMetadata ### SpaceInfo [[autodoc]] huggingface_hub.hf_api.SpaceInfo ### TensorInfo [[autodoc]] huggingface_hub.utils.TensorInfo ### User [[autodoc]] huggingface_hub.hf_api.User ### UserLikes [[autodoc]] huggingface_hub.hf_api.UserLikes ### WebhookInfo [[autodoc]] huggingface_hub.hf_api.WebhookInfo ### WebhookWatchedItem [[autodoc]] huggingface_hub.hf_api.WebhookWatchedItem ## CommitOperation Below are the supported values for [`CommitOperation`]: [[autodoc]] CommitOperationAdd [[autodoc]] CommitOperationDelete [[autodoc]] CommitOperationCopy ## CommitScheduler [[autodoc]] CommitScheduler huggingface_hub-0.31.1/docs/source/en/package_reference/hf_file_system.md000066400000000000000000000013601500667546600265050ustar00rootroot00000000000000 # Filesystem API The `HfFileSystem` class provides a pythonic file interface to the Hugging Face Hub based on [`fsspec`](https://filesystem-spec.readthedocs.io/en/latest/). ## HfFileSystem `HfFileSystem` is based on [fsspec](https://filesystem-spec.readthedocs.io/en/latest/), so it is compatible with most of the APIs that it offers. For more details, check out [our guide](../guides/hf_file_system) and fsspec's [API Reference](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem). [[autodoc]] HfFileSystem - __init__ - all huggingface_hub-0.31.1/docs/source/en/package_reference/inference_client.md000066400000000000000000000053061500667546600270050ustar00rootroot00000000000000 # Inference Inference is the process of using a trained model to make predictions on new data. Because this process can be compute-intensive, running on a dedicated or external service can be an interesting option. The `huggingface_hub` library provides a unified interface to run inference across multiple services for models hosted on the Hugging Face Hub: 1. [Inference API](https://huggingface.co/docs/api-inference/index): a serverless solution that allows you to run accelerated inference on Hugging Face's infrastructure for free. This service is a fast way to get started, test different models, and prototype AI products. 2. Third-party providers: various serverless solutions provided by external providers (Together, Sambanova, etc.). These providers offer production-ready APIs on a pay-as-you-go model. This is the fastest way to integrate AI into your products with a maintenance-free and scalable solution. Refer to the [Supported providers and tasks](../guides/inference#supported-providers-and-tasks) section for a list of supported providers. 3.
[Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index): a product to easily deploy models to production. Inference is run by Hugging Face in a dedicated, fully managed infrastructure on a cloud provider of your choice. These services can be called with the [`InferenceClient`] object. Please refer to [this guide](../guides/inference) for more information on how to use it. ## Inference Client [[autodoc]] InferenceClient ## Async Inference Client An async version of the client is also provided, based on `asyncio` and `aiohttp`. To use it, you can either install `aiohttp` directly or use the `[inference]` extra: ```sh pip install --upgrade huggingface_hub[inference] # or # pip install aiohttp ``` [[autodoc]] AsyncInferenceClient ## InferenceTimeoutError [[autodoc]] InferenceTimeoutError ### ModelStatus [[autodoc]] huggingface_hub.inference._common.ModelStatus ## InferenceAPI [`InferenceAPI`] is the legacy way to call the Inference API. The interface is simpler but requires knowing the input parameters and output format for each task. It also lacks the ability to connect to other services like Inference Endpoints or AWS SageMaker. [`InferenceAPI`] will soon be deprecated so we recommend using [`InferenceClient`] whenever possible. Check out [this guide](../guides/inference#legacy-inferenceapi-client) to learn how to switch from [`InferenceAPI`] to [`InferenceClient`] in your scripts. [[autodoc]] InferenceApi - __init__ - __call__ - all huggingface_hub-0.31.1/docs/source/en/package_reference/inference_endpoints.md000066400000000000000000000040621500667546600275300ustar00rootroot00000000000000# Inference Endpoints Inference Endpoints provides a secure production solution to easily deploy models on a dedicated and autoscaling infrastructure managed by Hugging Face. An Inference Endpoint is built from a model from the [Hub](https://huggingface.co/models). This page is a reference for `huggingface_hub`'s integration with Inference Endpoints. For more information about the Inference Endpoints product, check out its [official documentation](https://huggingface.co/docs/inference-endpoints/index). Check out the [related guide](../guides/inference_endpoints) to learn how to use `huggingface_hub` to manage your Inference Endpoints programmatically. Inference Endpoints can be fully managed via API. The endpoints are documented with [Swagger](https://api.endpoints.huggingface.cloud/). The [`InferenceEndpoint`] class is a simple wrapper built on top of this API. ## Methods A subset of the Inference Endpoint features is implemented in [`HfApi`]: - [`get_inference_endpoint`] and [`list_inference_endpoints`] to get information about your Inference Endpoints - [`create_inference_endpoint`], [`update_inference_endpoint`] and [`delete_inference_endpoint`] to deploy and manage Inference Endpoints - [`pause_inference_endpoint`] and [`resume_inference_endpoint`] to pause and resume an Inference Endpoint - [`scale_to_zero_inference_endpoint`] to manually scale an Endpoint to 0 replicas ## InferenceEndpoint The main dataclass is [`InferenceEndpoint`]. It contains information about a deployed `InferenceEndpoint`, including its configuration and current state. Once deployed, you can run inference on the Endpoint using the [`InferenceEndpoint.client`] and [`InferenceEndpoint.async_client`] properties that respectively return an [`InferenceClient`] and an [`AsyncInferenceClient`] object.
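As a minimal sketch (it assumes an Endpoint named `"my-endpoint-name"` already exists in your namespace, is running, and serves a text-generation model):

```py
>>> from huggingface_hub import get_inference_endpoint

>>> endpoint = get_inference_endpoint("my-endpoint-name")  # fetch metadata about the deployed Endpoint
>>> client = endpoint.client  # an InferenceClient pointing at the Endpoint URL
>>> client.text_generation("The huggingface_hub library is ")  # run inference on the deployed model
'...'
```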
[[autodoc]] InferenceEndpoint - from_raw - client - async_client - all ## InferenceEndpointStatus [[autodoc]] InferenceEndpointStatus ## InferenceEndpointType [[autodoc]] InferenceEndpointType ## InferenceEndpointError [[autodoc]] InferenceEndpointError huggingface_hub-0.31.1/docs/source/en/package_reference/inference_types.md000066400000000000000000000220351500667546600266710ustar00rootroot00000000000000 # Inference types This page lists the types (e.g. dataclasses) available for each task supported on the Hugging Face Hub. Each task is specified using a JSON schema, and the types are generated from these schemas - with some customization due to Python requirements. Visit [@huggingface.js/tasks](https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks) to find the JSON schemas for each task. This part of the lib is still under development and will be improved in future releases. ## audio_classification [[autodoc]] huggingface_hub.AudioClassificationInput [[autodoc]] huggingface_hub.AudioClassificationOutputElement [[autodoc]] huggingface_hub.AudioClassificationParameters ## audio_to_audio [[autodoc]] huggingface_hub.AudioToAudioInput [[autodoc]] huggingface_hub.AudioToAudioOutputElement ## automatic_speech_recognition [[autodoc]] huggingface_hub.AutomaticSpeechRecognitionGenerationParameters [[autodoc]] huggingface_hub.AutomaticSpeechRecognitionInput [[autodoc]] huggingface_hub.AutomaticSpeechRecognitionOutput [[autodoc]] huggingface_hub.AutomaticSpeechRecognitionOutputChunk [[autodoc]] huggingface_hub.AutomaticSpeechRecognitionParameters ## chat_completion [[autodoc]] huggingface_hub.ChatCompletionInput [[autodoc]] huggingface_hub.ChatCompletionInputFunctionDefinition [[autodoc]] huggingface_hub.ChatCompletionInputFunctionName [[autodoc]] huggingface_hub.ChatCompletionInputGrammarType [[autodoc]] huggingface_hub.ChatCompletionInputMessage [[autodoc]] huggingface_hub.ChatCompletionInputMessageChunk [[autodoc]] huggingface_hub.ChatCompletionInputStreamOptions [[autodoc]] huggingface_hub.ChatCompletionInputTool [[autodoc]] huggingface_hub.ChatCompletionInputToolCall [[autodoc]] huggingface_hub.ChatCompletionInputToolChoiceClass [[autodoc]] huggingface_hub.ChatCompletionInputURL [[autodoc]] huggingface_hub.ChatCompletionOutput [[autodoc]] huggingface_hub.ChatCompletionOutputComplete [[autodoc]] huggingface_hub.ChatCompletionOutputFunctionDefinition [[autodoc]] huggingface_hub.ChatCompletionOutputLogprob [[autodoc]] huggingface_hub.ChatCompletionOutputLogprobs [[autodoc]] huggingface_hub.ChatCompletionOutputMessage [[autodoc]] huggingface_hub.ChatCompletionOutputToolCall [[autodoc]] huggingface_hub.ChatCompletionOutputTopLogprob [[autodoc]] huggingface_hub.ChatCompletionOutputUsage [[autodoc]] huggingface_hub.ChatCompletionStreamOutput [[autodoc]] huggingface_hub.ChatCompletionStreamOutputChoice [[autodoc]] huggingface_hub.ChatCompletionStreamOutputDelta [[autodoc]] huggingface_hub.ChatCompletionStreamOutputDeltaToolCall [[autodoc]] huggingface_hub.ChatCompletionStreamOutputFunction [[autodoc]] huggingface_hub.ChatCompletionStreamOutputLogprob [[autodoc]] huggingface_hub.ChatCompletionStreamOutputLogprobs [[autodoc]] huggingface_hub.ChatCompletionStreamOutputTopLogprob [[autodoc]] huggingface_hub.ChatCompletionStreamOutputUsage ## depth_estimation [[autodoc]] huggingface_hub.DepthEstimationInput [[autodoc]] huggingface_hub.DepthEstimationOutput ## document_question_answering [[autodoc]] huggingface_hub.DocumentQuestionAnsweringInput [[autodoc]] 
huggingface_hub.DocumentQuestionAnsweringInputData [[autodoc]] huggingface_hub.DocumentQuestionAnsweringOutputElement [[autodoc]] huggingface_hub.DocumentQuestionAnsweringParameters ## feature_extraction [[autodoc]] huggingface_hub.FeatureExtractionInput ## fill_mask [[autodoc]] huggingface_hub.FillMaskInput [[autodoc]] huggingface_hub.FillMaskOutputElement [[autodoc]] huggingface_hub.FillMaskParameters ## image_classification [[autodoc]] huggingface_hub.ImageClassificationInput [[autodoc]] huggingface_hub.ImageClassificationOutputElement [[autodoc]] huggingface_hub.ImageClassificationParameters ## image_segmentation [[autodoc]] huggingface_hub.ImageSegmentationInput [[autodoc]] huggingface_hub.ImageSegmentationOutputElement [[autodoc]] huggingface_hub.ImageSegmentationParameters ## image_to_image [[autodoc]] huggingface_hub.ImageToImageInput [[autodoc]] huggingface_hub.ImageToImageOutput [[autodoc]] huggingface_hub.ImageToImageParameters [[autodoc]] huggingface_hub.ImageToImageTargetSize ## image_to_text [[autodoc]] huggingface_hub.ImageToTextGenerationParameters [[autodoc]] huggingface_hub.ImageToTextInput [[autodoc]] huggingface_hub.ImageToTextOutput [[autodoc]] huggingface_hub.ImageToTextParameters ## object_detection [[autodoc]] huggingface_hub.ObjectDetectionBoundingBox [[autodoc]] huggingface_hub.ObjectDetectionInput [[autodoc]] huggingface_hub.ObjectDetectionOutputElement [[autodoc]] huggingface_hub.ObjectDetectionParameters ## question_answering [[autodoc]] huggingface_hub.QuestionAnsweringInput [[autodoc]] huggingface_hub.QuestionAnsweringInputData [[autodoc]] huggingface_hub.QuestionAnsweringOutputElement [[autodoc]] huggingface_hub.QuestionAnsweringParameters ## sentence_similarity [[autodoc]] huggingface_hub.SentenceSimilarityInput [[autodoc]] huggingface_hub.SentenceSimilarityInputData ## summarization [[autodoc]] huggingface_hub.SummarizationInput [[autodoc]] huggingface_hub.SummarizationOutput [[autodoc]] huggingface_hub.SummarizationParameters ## table_question_answering [[autodoc]] huggingface_hub.TableQuestionAnsweringInput [[autodoc]] huggingface_hub.TableQuestionAnsweringInputData [[autodoc]] huggingface_hub.TableQuestionAnsweringOutputElement [[autodoc]] huggingface_hub.TableQuestionAnsweringParameters ## text2text_generation [[autodoc]] huggingface_hub.Text2TextGenerationInput [[autodoc]] huggingface_hub.Text2TextGenerationOutput [[autodoc]] huggingface_hub.Text2TextGenerationParameters ## text_classification [[autodoc]] huggingface_hub.TextClassificationInput [[autodoc]] huggingface_hub.TextClassificationOutputElement [[autodoc]] huggingface_hub.TextClassificationParameters ## text_generation [[autodoc]] huggingface_hub.TextGenerationInput [[autodoc]] huggingface_hub.TextGenerationInputGenerateParameters [[autodoc]] huggingface_hub.TextGenerationInputGrammarType [[autodoc]] huggingface_hub.TextGenerationOutput [[autodoc]] huggingface_hub.TextGenerationOutputBestOfSequence [[autodoc]] huggingface_hub.TextGenerationOutputDetails [[autodoc]] huggingface_hub.TextGenerationOutputPrefillToken [[autodoc]] huggingface_hub.TextGenerationOutputToken [[autodoc]] huggingface_hub.TextGenerationStreamOutput [[autodoc]] huggingface_hub.TextGenerationStreamOutputStreamDetails [[autodoc]] huggingface_hub.TextGenerationStreamOutputToken ## text_to_audio [[autodoc]] huggingface_hub.TextToAudioGenerationParameters [[autodoc]] huggingface_hub.TextToAudioInput [[autodoc]] huggingface_hub.TextToAudioOutput [[autodoc]] huggingface_hub.TextToAudioParameters ## text_to_image [[autodoc]] 
huggingface_hub.TextToImageInput [[autodoc]] huggingface_hub.TextToImageOutput [[autodoc]] huggingface_hub.TextToImageParameters ## text_to_speech [[autodoc]] huggingface_hub.TextToSpeechGenerationParameters [[autodoc]] huggingface_hub.TextToSpeechInput [[autodoc]] huggingface_hub.TextToSpeechOutput [[autodoc]] huggingface_hub.TextToSpeechParameters ## text_to_video [[autodoc]] huggingface_hub.TextToVideoInput [[autodoc]] huggingface_hub.TextToVideoOutput [[autodoc]] huggingface_hub.TextToVideoParameters ## token_classification [[autodoc]] huggingface_hub.TokenClassificationInput [[autodoc]] huggingface_hub.TokenClassificationOutputElement [[autodoc]] huggingface_hub.TokenClassificationParameters ## translation [[autodoc]] huggingface_hub.TranslationInput [[autodoc]] huggingface_hub.TranslationOutput [[autodoc]] huggingface_hub.TranslationParameters ## video_classification [[autodoc]] huggingface_hub.VideoClassificationInput [[autodoc]] huggingface_hub.VideoClassificationOutputElement [[autodoc]] huggingface_hub.VideoClassificationParameters ## visual_question_answering [[autodoc]] huggingface_hub.VisualQuestionAnsweringInput [[autodoc]] huggingface_hub.VisualQuestionAnsweringInputData [[autodoc]] huggingface_hub.VisualQuestionAnsweringOutputElement [[autodoc]] huggingface_hub.VisualQuestionAnsweringParameters ## zero_shot_classification [[autodoc]] huggingface_hub.ZeroShotClassificationInput [[autodoc]] huggingface_hub.ZeroShotClassificationOutputElement [[autodoc]] huggingface_hub.ZeroShotClassificationParameters ## zero_shot_image_classification [[autodoc]] huggingface_hub.ZeroShotImageClassificationInput [[autodoc]] huggingface_hub.ZeroShotImageClassificationOutputElement [[autodoc]] huggingface_hub.ZeroShotImageClassificationParameters ## zero_shot_object_detection [[autodoc]] huggingface_hub.ZeroShotObjectDetectionBoundingBox [[autodoc]] huggingface_hub.ZeroShotObjectDetectionInput [[autodoc]] huggingface_hub.ZeroShotObjectDetectionOutputElement [[autodoc]] huggingface_hub.ZeroShotObjectDetectionParameters huggingface_hub-0.31.1/docs/source/en/package_reference/mixins.md000066400000000000000000000015531500667546600250200ustar00rootroot00000000000000 # Mixins & serialization methods ## Mixins The `huggingface_hub` library offers a range of mixins that can be used as a parent class for your objects, in order to provide simple uploading and downloading functions. Check out our [integration guide](../guides/integrations) to learn how to integrate any ML framework with the Hub. ### Generic [[autodoc]] ModelHubMixin - all - _save_pretrained - _from_pretrained ### PyTorch [[autodoc]] PyTorchModelHubMixin ### Keras [[autodoc]] KerasModelHubMixin [[autodoc]] from_pretrained_keras [[autodoc]] push_to_hub_keras [[autodoc]] save_pretrained_keras ### Fastai [[autodoc]] from_pretrained_fastai [[autodoc]] push_to_hub_fastai huggingface_hub-0.31.1/docs/source/en/package_reference/overview.md000066400000000000000000000004441500667546600253550ustar00rootroot00000000000000 # Overview This section contains an exhaustive and technical description of `huggingface_hub` classes and methods. huggingface_hub-0.31.1/docs/source/en/package_reference/repository.md000066400000000000000000000026521500667546600257310ustar00rootroot00000000000000 # Managing local and online repositories The `Repository` class is a helper class that wraps `git` and `git-lfs` commands. It provides tooling adapted for managing repositories which can be very large. 
It is the recommended tool as soon as any `git` operation is involved, or when collaboration will be a point of focus with the repository itself. ## The Repository class [[autodoc]] Repository - __init__ - current_branch - all ## Helper methods [[autodoc]] huggingface_hub.repository.is_git_repo [[autodoc]] huggingface_hub.repository.is_local_clone [[autodoc]] huggingface_hub.repository.is_tracked_with_lfs [[autodoc]] huggingface_hub.repository.is_git_ignored [[autodoc]] huggingface_hub.repository.files_to_be_staged [[autodoc]] huggingface_hub.repository.is_tracked_upstream [[autodoc]] huggingface_hub.repository.commits_to_push ## Following asynchronous commands The `Repository` utility offers several methods which can be launched asynchronously: - `git_push` - `git_pull` - `push_to_hub` - The `commit` context manager See below for utilities to manage such asynchronous methods. [[autodoc]] Repository - commands_failed - commands_in_progress - wait_for_commands [[autodoc]] huggingface_hub.repository.CommandInProgress huggingface_hub-0.31.1/docs/source/en/package_reference/serialization.md000066400000000000000000000177421500667546600263750ustar00rootroot00000000000000 # Serialization `huggingface_hub` provides helpers to save and load ML model weights in a standardized way. This part of the library is still under development and will be improved in future releases. The goal is to harmonize how weights are saved and loaded across the Hub, both to remove code duplication across libraries and to establish consistent conventions. ## DDUF file format DDUF is a file format designed for diffusion models. It allows saving all the information to run a model in a single file. This work is inspired by the [GGUF](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) format. `huggingface_hub` provides helpers to save and load DDUF files, ensuring the file format is respected. This is a very early version of the parser. The API and implementation can evolve in the near future. The parser currently does very little validation. For more details about the file format, check out https://github.com/huggingface/huggingface.js/tree/main/packages/dduf. ### How to write a DDUF file? Here is how to export a folder containing different parts of a diffusion model using [`export_folder_as_dduf`]: ```python # Export a folder as a DDUF file >>> from huggingface_hub import export_folder_as_dduf >>> export_folder_as_dduf("FLUX.1-dev.dduf", folder_path="path/to/FLUX.1-dev") ``` For more flexibility, you can use [`export_entries_as_dduf`] and pass a list of files to include in the final DDUF file: ```python # Export specific files from the local disk. >>> from huggingface_hub import export_entries_as_dduf >>> export_entries_as_dduf( ... dduf_path="stable-diffusion-v1-4-FP16.dduf", ... entries=[ # List entries to add to the DDUF file (here, only FP16 weights) ... ("model_index.json", "path/to/model_index.json"), ... ("vae/config.json", "path/to/vae/config.json"), ... ("vae/diffusion_pytorch_model.fp16.safetensors", "path/to/vae/diffusion_pytorch_model.fp16.safetensors"), ... ("text_encoder/config.json", "path/to/text_encoder/config.json"), ... ("text_encoder/model.fp16.safetensors", "path/to/text_encoder/model.fp16.safetensors"), ... # ... add more entries here ... ] ... ) ``` The `entries` parameter also supports passing an iterable of paths or bytes. 
This can prove useful if you have a loaded model and want to serialize it directly into a DDUF file instead of having to serialize each component to disk first and then as a DDUF file. Here is an example of how a `StableDiffusionPipeline` can be serialized as DDUF: ```python # Export state_dicts one by one from a loaded pipeline >>> from diffusers import DiffusionPipeline >>> from typing import Generator, Tuple >>> import safetensors.torch >>> from huggingface_hub import export_entries_as_dduf >>> pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") ... # ... do some work with the pipeline >>> def as_entries(pipe: DiffusionPipeline) -> Generator[Tuple[str, bytes], None, None]: ... # Build a generator that yields the entries to add to the DDUF file. ... # The first element of the tuple is the filename in the DDUF archive (must use UNIX separator!). The second element is the content of the file. ... # Entries will be evaluated lazily when the DDUF file is created (only 1 entry is loaded in memory at a time) ... yield "vae/config.json", pipe.vae.to_json_string().encode() ... yield "vae/diffusion_pytorch_model.safetensors", safetensors.torch.save(pipe.vae.state_dict()) ... yield "text_encoder/config.json", pipe.text_encoder.config.to_json_string().encode() ... yield "text_encoder/model.safetensors", safetensors.torch.save(pipe.text_encoder.state_dict()) ... # ... add more entries here >>> export_entries_as_dduf(dduf_path="stable-diffusion-v1-4.dduf", entries=as_entries(pipe)) ``` **Note:** in practice, `diffusers` provides a method to directly serialize a pipeline in a DDUF file. The snippet above is only meant as an example. ### How to read a DDUF file? ```python >>> import json >>> import safetensors.torch >>> from huggingface_hub import read_dduf_file # Read DDUF metadata >>> dduf_entries = read_dduf_file("FLUX.1-dev.dduf") # Returns a mapping filename <> DDUFEntry >>> dduf_entries["model_index.json"] DDUFEntry(filename='model_index.json', offset=66, length=587) # Load model index as JSON >>> json.loads(dduf_entries["model_index.json"].read_text()) {'_class_name': 'FluxPipeline', '_diffusers_version': '0.32.0.dev0', '_name_or_path': 'black-forest-labs/FLUX.1-dev', 'scheduler': ['diffusers', 'FlowMatchEulerDiscreteScheduler'], 'text_encoder': ['transformers', 'CLIPTextModel'], 'text_encoder_2': ['transformers', 'T5EncoderModel'], 'tokenizer': ['transformers', 'CLIPTokenizer'], 'tokenizer_2': ['transformers', 'T5TokenizerFast'], 'transformer': ['diffusers', 'FluxTransformer2DModel'], 'vae': ['diffusers', 'AutoencoderKL']} # Load VAE weights using safetensors >>> with dduf_entries["vae/diffusion_pytorch_model.safetensors"].as_mmap() as mm: ... state_dict = safetensors.torch.load(mm) ``` ### Helpers [[autodoc]] huggingface_hub.export_entries_as_dduf [[autodoc]] huggingface_hub.export_folder_as_dduf [[autodoc]] huggingface_hub.read_dduf_file [[autodoc]] huggingface_hub.DDUFEntry ### Errors [[autodoc]] huggingface_hub.errors.DDUFError [[autodoc]] huggingface_hub.errors.DDUFCorruptedFileError [[autodoc]] huggingface_hub.errors.DDUFExportError [[autodoc]] huggingface_hub.errors.DDUFInvalidEntryNameError ## Saving tensors The main helper of the `serialization` module takes a torch `nn.Module` as input and saves it to disk. It handles the logic to save shared tensors (see [safetensors explanation](https://huggingface.co/docs/safetensors/torch_shared_tensors)) as well as logic to split the state dictionary into shards, using [`split_torch_state_dict_into_shards`] under the hood. 
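As a minimal sketch (assuming `torch` is installed and using a toy model; adapt the target directory to your needs):

```py
>>> import torch
>>> from huggingface_hub import save_torch_model

>>> model = torch.nn.Linear(4, 2)  # any nn.Module works here
>>> save_torch_model(model, "path/to/folder")  # writes safetensors file(s), sharded with an index if the model is large
```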
At the moment, only `torch` framework is supported. If you want to save a state dictionary (e.g. a mapping between layer names and related tensors) instead of a `nn.Module`, you can use [`save_torch_state_dict`] which provides the same features. This is useful for example if you want to apply custom logic to the state dict before saving it. ### save_torch_model [[autodoc]] huggingface_hub.save_torch_model ### save_torch_state_dict [[autodoc]] huggingface_hub.save_torch_state_dict The `serialization` module also contains low-level helpers to split a state dictionary into several shards, while creating a proper index in the process. These helpers are available for `torch` and `tensorflow` tensors and are designed to be easily extended to any other ML frameworks. ### split_tf_state_dict_into_shards [[autodoc]] huggingface_hub.split_tf_state_dict_into_shards ### split_torch_state_dict_into_shards [[autodoc]] huggingface_hub.split_torch_state_dict_into_shards ### split_state_dict_into_shards_factory This is the underlying factory from which each framework-specific helper is derived. In practice, you are not expected to use this factory directly except if you need to adapt it to a framework that is not yet supported. If that is the case, please let us know by [opening a new issue](https://github.com/huggingface/huggingface_hub/issues/new) on the `huggingface_hub` repo. [[autodoc]] huggingface_hub.split_state_dict_into_shards_factory ## Loading tensors The loading helpers support both single-file and sharded checkpoints in either safetensors or pickle format. [`load_torch_model`] takes a `nn.Module` and a checkpoint path (either a single file or a directory) as input and load the weights into the model. ### load_torch_model [[autodoc]] huggingface_hub.load_torch_model ### load_state_dict_from_file [[autodoc]] huggingface_hub.load_state_dict_from_file ## Tensors helpers ### get_torch_storage_id [[autodoc]] huggingface_hub.get_torch_storage_id ### get_torch_storage_size [[autodoc]] huggingface_hub.get_torch_storage_sizehuggingface_hub-0.31.1/docs/source/en/package_reference/space_runtime.md000066400000000000000000000014731500667546600263500ustar00rootroot00000000000000 # Managing your Space runtime Check the [`HfApi`] documentation page for the reference of methods to manage your Space on the Hub. - Duplicate a Space: [`duplicate_space`] - Fetch current runtime: [`get_space_runtime`] - Manage secrets: [`add_space_secret`] and [`delete_space_secret`] - Manage hardware: [`request_space_hardware`] - Manage state: [`pause_space`], [`restart_space`], [`set_space_sleep_time`] ## Data structures ### SpaceRuntime [[autodoc]] SpaceRuntime ### SpaceHardware [[autodoc]] SpaceHardware ### SpaceStage [[autodoc]] SpaceStage ### SpaceStorage [[autodoc]] SpaceStorage ### SpaceVariable [[autodoc]] SpaceVariable huggingface_hub-0.31.1/docs/source/en/package_reference/tensorboard.md000066400000000000000000000022051500667546600260260ustar00rootroot00000000000000 # TensorBoard logger TensorBoard is a visualization toolkit for machine learning experimentation. TensorBoard allows tracking and visualizing metrics such as loss and accuracy, visualizing the model graph, viewing histograms, displaying images and much more. TensorBoard is well integrated with the Hugging Face Hub. The Hub automatically detects TensorBoard traces (such as `tfevents`) when pushed to the Hub which starts an instance to visualize them. 
To get more information about TensorBoard integration on the Hub, check out [this guide](https://huggingface.co/docs/hub/tensorboard). To benefit from this integration, `huggingface_hub` provides a custom logger to push logs to the Hub. It works as a drop-in replacement for [SummaryWriter](https://tensorboardx.readthedocs.io/en/latest/tensorboard.html) with no extra code needed. Traces are still saved locally and a background job pushes them to the Hub at regular intervals. ## HFSummaryWriter [[autodoc]] HFSummaryWriter huggingface_hub-0.31.1/docs/source/en/package_reference/utilities.md000066400000000000000000000233421500667546600255240ustar00rootroot00000000000000 # Utilities ## Configure logging The `huggingface_hub` package exposes a `logging` utility to control the logging level of the package itself. You can import it as such: ```py from huggingface_hub import logging ``` Then, you may define the verbosity in order to update the amount of logs you'll see: ```python from huggingface_hub import logging logging.set_verbosity_error() logging.set_verbosity_warning() logging.set_verbosity_info() logging.set_verbosity_debug() logging.set_verbosity(...) ``` The levels should be understood as follows: - `error`: only show critical logs about usage which may result in an error or unexpected behavior. - `warning`: show logs that aren't critical but usage may result in unintended behavior. Additionally, important informative logs may be shown. - `info`: show most logs, including some verbose logging regarding what is happening under the hood. If something is behaving in an unexpected manner, we recommend switching the verbosity level to this in order to get more information. - `debug`: show all logs, including some internal logs which may be used to track exactly what's happening under the hood. [[autodoc]] logging.get_verbosity [[autodoc]] logging.set_verbosity [[autodoc]] logging.set_verbosity_info [[autodoc]] logging.set_verbosity_debug [[autodoc]] logging.set_verbosity_warning [[autodoc]] logging.set_verbosity_error [[autodoc]] logging.disable_propagation [[autodoc]] logging.enable_propagation ### Repo-specific helper methods The methods exposed below are relevant when modifying modules from the `huggingface_hub` library itself. Using these shouldn't be necessary if you use `huggingface_hub` and you don't modify them. [[autodoc]] logging.get_logger ## Configure progress bars Progress bars are a useful tool to display information to the user while a long-running task is being executed (e.g. when downloading or uploading files). `huggingface_hub` exposes a [`~utils.tqdm`] wrapper to display progress bars in a consistent way across the library. By default, progress bars are enabled. You can disable them globally by setting the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable. You can also enable/disable them using [`~utils.enable_progress_bars`] and [`~utils.disable_progress_bars`]. If set, the environment variable has priority over the helpers. ```py >>> from huggingface_hub import snapshot_download >>> from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars >>> # Disable progress bars globally >>> disable_progress_bars() >>> # Progress bar will not be shown ! >>> snapshot_download("gpt2") >>> are_progress_bars_disabled() True >>> # Re-enable progress bars globally >>> enable_progress_bars() ``` ### Group-specific control of progress bars You can also enable or disable progress bars for specific groups.
This allows you to manage progress bar visibility more granularly within different parts of your application or library. When a progress bar is disabled for a group, all subgroups under it are also affected unless explicitly overridden. ```py # Disable progress bars for a specific group >>> disable_progress_bars("peft.foo") >>> assert not are_progress_bars_disabled("peft") >>> assert not are_progress_bars_disabled("peft.something") >>> assert are_progress_bars_disabled("peft.foo") >>> assert are_progress_bars_disabled("peft.foo.bar") # Re-enable progress bars for a subgroup >>> enable_progress_bars("peft.foo.bar") >>> assert are_progress_bars_disabled("peft.foo") >>> assert not are_progress_bars_disabled("peft.foo.bar") # Use groups with tqdm # No progress bar for `name="peft.foo"` >>> for _ in tqdm(range(5), name="peft.foo"): ... pass # Progress bar will be shown for `name="peft.foo.bar"` >>> for _ in tqdm(range(5), name="peft.foo.bar"): ... pass 100%|███████████████████████████████████████| 5/5 [00:00<00:00, 117817.53it/s] ``` ### are_progress_bars_disabled [[autodoc]] huggingface_hub.utils.are_progress_bars_disabled ### disable_progress_bars [[autodoc]] huggingface_hub.utils.disable_progress_bars ### enable_progress_bars [[autodoc]] huggingface_hub.utils.enable_progress_bars ## Configure HTTP backend In some environments, you might want to configure how HTTP calls are made, for example if you are using a proxy. `huggingface_hub` lets you configure this globally using [`configure_http_backend`]. All requests made to the Hub will then use your settings. Under the hood, `huggingface_hub` uses `requests.Session` so you might want to refer to the [`requests` documentation](https://requests.readthedocs.io/en/latest/user/advanced) to learn more about the available parameters. Since `requests.Session` is not guaranteed to be thread-safe, `huggingface_hub` creates one session instance per thread. Using sessions allows us to keep the connection open between HTTP calls and ultimately save time. If you are integrating `huggingface_hub` in a third-party library and want to make a custom call to the Hub, use [`get_session`] to get a Session configured by your users (i.e. replace any `requests.get(...)` call by `get_session().get(...)`). [[autodoc]] configure_http_backend [[autodoc]] get_session ## Handle HTTP errors `huggingface_hub` defines its own HTTP errors to refine the `HTTPError` raised by `requests` with additional information sent back by the server. ### Raise for status [`~utils.hf_raise_for_status`] is meant to be the central method to "raise for status" from any request made to the Hub. It wraps the base `requests.raise_for_status` to provide additional information. Any `HTTPError` thrown is converted into a `HfHubHTTPError`. ```py import requests from huggingface_hub.utils import hf_raise_for_status, HfHubHTTPError response = requests.post(...) try: hf_raise_for_status(response) except HfHubHTTPError as e: print(str(e)) # formatted message e.request_id, e.server_message # details returned by server # Complete the error message with additional information once it's raised e.append_to_message("\n`create_commit` expects the repository to exist.") raise ``` [[autodoc]] huggingface_hub.utils.hf_raise_for_status ### HTTP errors Here is a list of HTTP errors thrown in `huggingface_hub`. #### HfHubHTTPError `HfHubHTTPError` is the parent class for any HF Hub HTTP error.
It takes care of parsing the server response and formatting the error message to provide as much information to the user as possible. [[autodoc]] huggingface_hub.utils.HfHubHTTPError #### RepositoryNotFoundError [[autodoc]] huggingface_hub.utils.RepositoryNotFoundError #### GatedRepoError [[autodoc]] huggingface_hub.utils.GatedRepoError #### RevisionNotFoundError [[autodoc]] huggingface_hub.utils.RevisionNotFoundError #### EntryNotFoundError [[autodoc]] huggingface_hub.utils.EntryNotFoundError #### BadRequestError [[autodoc]] huggingface_hub.utils.BadRequestError #### LocalEntryNotFoundError [[autodoc]] huggingface_hub.utils.LocalEntryNotFoundError #### OfflineModeIsEnabled [[autodoc]] huggingface_hub.utils.OfflineModeIsEnabled ## Telemetry `huggingface_hub` includes a helper to send telemetry data. This information helps us debug issues and prioritize new features. Users can disable telemetry collection at any time by setting the `HF_HUB_DISABLE_TELEMETRY=1` environment variable. Telemetry is also disabled in offline mode (i.e. when setting HF_HUB_OFFLINE=1). If you are a maintainer of a third-party library, sending telemetry data is as simple as making a call to [`send_telemetry`]. Data is sent in a separate thread to reduce the impact on users as much as possible. [[autodoc]] utils.send_telemetry ## Validators `huggingface_hub` includes custom validators to validate method arguments automatically. Validation is inspired by the work done in [Pydantic](https://pydantic-docs.helpmanual.io/) to validate type hints but with more limited features. ### Generic decorator [`~utils.validate_hf_hub_args`] is a generic decorator to encapsulate methods that have arguments following `huggingface_hub`'s naming. By default, all arguments that have a validator implemented will be validated. If an input is not valid, a [`~utils.HFValidationError`] is thrown. Only the first non-valid value throws an error and stops the validation process. Usage: ```py >>> from huggingface_hub.utils import validate_hf_hub_args >>> @validate_hf_hub_args ... def my_cool_method(repo_id: str): ... print(repo_id) >>> my_cool_method(repo_id="valid_repo_id") valid_repo_id >>> my_cool_method("other..repo..id") huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'. >>> my_cool_method(repo_id="other..repo..id") huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'. >>> @validate_hf_hub_args ... def my_cool_auth_method(token: str): ... print(token) >>> my_cool_auth_method(token="a token") "a token" >>> my_cool_auth_method(use_auth_token="a use_auth_token") "a use_auth_token" >>> my_cool_auth_method(token="a token", use_auth_token="a use_auth_token") UserWarning: Both `token` and `use_auth_token` are passed (...). `use_auth_token` value will be ignored. "a token" ``` #### validate_hf_hub_args [[autodoc]] utils.validate_hf_hub_args #### HFValidationError [[autodoc]] utils.HFValidationError ### Argument validators Validators can also be used individually. Here is a list of all arguments that can be validated. #### repo_id [[autodoc]] utils.validate_repo_id #### smoothly_deprecate_use_auth_token Not exactly a validator, but run as well. [[autodoc]] utils.smoothly_deprecate_use_auth_token huggingface_hub-0.31.1/docs/source/en/package_reference/webhooks_server.md000066400000000000000000000053331500667546600267200ustar00rootroot00000000000000 # Webhooks Server Webhooks are a foundation for MLOps-related features.
They allow you to listen for new changes on specific repos or to all repos belonging to particular users/organizations you're interested in following. To learn more about webhooks on the Hugging Face Hub, you can read the Webhooks [guide](https://huggingface.co/docs/hub/webhooks). Check out this [guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your webhooks server and deploy it as a Space. This is an experimental feature. This means that we are still working on improving the API. Breaking changes might be introduced in the future without prior notice. Make sure to pin the version of `huggingface_hub` in your requirements. A warning is triggered when you use an experimental feature. You can disable it by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as an environment variable. ## Server The server is a [Gradio](https://gradio.app/) app. It has a UI to display instructions for you or your users and an API to listen to webhooks. Implementing a webhook endpoint is as simple as decorating a function. You can then debug it by redirecting the Webhooks to your machine (using a Gradio tunnel) before deploying it to a Space. ### WebhooksServer [[autodoc]] huggingface_hub.WebhooksServer ### @webhook_endpoint [[autodoc]] huggingface_hub.webhook_endpoint ## Payload [`WebhookPayload`] is the main data structure that contains the payload from Webhooks. This is a `pydantic` class which makes it very easy to use with FastAPI. If you pass it as a parameter to a webhook endpoint, it will be automatically validated and parsed as a Python object. For more information about webhook payloads, you can refer to the Webhooks Payload [guide](https://huggingface.co/docs/hub/webhooks#webhook-payloads). [[autodoc]] huggingface_hub.WebhookPayload ### WebhookPayload [[autodoc]] huggingface_hub.WebhookPayload ### WebhookPayloadComment [[autodoc]] huggingface_hub.WebhookPayloadComment ### WebhookPayloadDiscussion [[autodoc]] huggingface_hub.WebhookPayloadDiscussion ### WebhookPayloadDiscussionChanges [[autodoc]] huggingface_hub.WebhookPayloadDiscussionChanges ### WebhookPayloadEvent [[autodoc]] huggingface_hub.WebhookPayloadEvent ### WebhookPayloadMovedTo [[autodoc]] huggingface_hub.WebhookPayloadMovedTo ### WebhookPayloadRepo [[autodoc]] huggingface_hub.WebhookPayloadRepo ### WebhookPayloadUrl [[autodoc]] huggingface_hub.WebhookPayloadUrl ### WebhookPayloadWebhook [[autodoc]] huggingface_hub.WebhookPayloadWebhook huggingface_hub-0.31.1/docs/source/en/quick-start.md000066400000000000000000000171561500667546600223510ustar00rootroot00000000000000 # Quickstart The [Hugging Face Hub](https://huggingface.co/) is the go-to place for sharing machine learning models, demos, datasets, and metrics. The `huggingface_hub` library helps you interact with the Hub without leaving your development environment. You can create and manage repositories easily, download and upload files, and get useful model and dataset metadata from the Hub. ## Installation To get started, install the `huggingface_hub` library: ```bash pip install --upgrade huggingface_hub ``` For more details, check out the [installation](installation) guide. ## Download files Repositories on the Hub are git version controlled, and users can download a single file or the whole repository. You can use the [`hf_hub_download`] function to download files. This function will download and cache a file on your local disk. The next time you need that file, it will be loaded from your cache, so you don't need to re-download it.
You will need the repository id and the filename of the file you want to download. For example, to download the [Pegasus](https://huggingface.co/google/pegasus-xsum) model configuration file: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download(repo_id="google/pegasus-xsum", filename="config.json") ``` To download a specific version of the file, use the `revision` parameter to specify the branch name, tag, or commit hash. If you choose to use the commit hash, it must be the full-length hash instead of the shorter 7-character commit hash: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download( ... repo_id="google/pegasus-xsum", ... filename="config.json", ... revision="4d33b01d79672f27f001f6abade33f22d993b151" ... ) ``` For more details and options, see the API reference for [`hf_hub_download`]. ## Authentication In a lot of cases, you must be authenticated with a Hugging Face account to interact with the Hub: download private repos, upload files, create PRs,... [Create an account](https://huggingface.co/join) if you don't already have one, and then sign in to get your [User Access Token](https://huggingface.co/docs/hub/security-tokens) from your [Settings page](https://huggingface.co/settings/tokens). The User Access Token is used to authenticate your identity to the Hub. Tokens can have `read` or `write` permissions. Make sure to have a `write` access token if you want to create or edit a repository. Otherwise, it's best to generate a `read` token to reduce risk in case your token is inadvertently leaked. ### Login command The easiest way to authenticate is to save the token on your machine. You can do that from the terminal using the [`login`] command: ```bash huggingface-cli login ``` The command will tell you if you are already logged in and prompt you for your token. The token is then validated and saved in your `HF_HOME` directory (defaults to `~/.cache/huggingface/token`). Any script or library interacting with the Hub will use this token when sending requests. Alternatively, you can programmatically login using [`login`] in a notebook or a script: ```py >>> from huggingface_hub import login >>> login() ``` You can only be logged in to one account at a time. Logging in to a new account will automatically log you out of the previous one. To determine your currently active account, simply run the `huggingface-cli whoami` command. Once logged in, all requests to the Hub - even methods that don't necessarily require authentication - will use your access token by default. If you want to disable the implicit use of your token, you should set `HF_HUB_DISABLE_IMPLICIT_TOKEN=1` as an environment variable (see [reference](../package_reference/environment_variables#hfhubdisableimplicittoken)). ### Manage multiple tokens locally You can save multiple tokens on your machine by simply logging in with the [`login`] command with each token. If you need to switch between these tokens locally, you can use the [`auth switch`] command: ```bash huggingface-cli auth switch ``` This command will prompt you to select a token by its name from a list of saved tokens. Once selected, the chosen token becomes the _active_ token, and it will be used for all interactions with the Hub. You can list all available access tokens on your machine with `huggingface-cli auth list`. ### Environment variable The environment variable `HF_TOKEN` can also be used to authenticate yourself. 
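As a sketch, in a shell session (replace the placeholder with your own token and keep it secret):

```bash
export HF_TOKEN="hf_xxx"  # your User Access Token; picked up automatically by huggingface_hub
```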
This is especially useful in a Space where you can set `HF_TOKEN` as a [Space secret](https://huggingface.co/docs/hub/spaces-overview#managing-secrets). **NEW:** Google Colaboratory lets you define [private keys](https://twitter.com/GoogleColab/status/1719798406195867814) for your notebooks. Define a `HF_TOKEN` secret to be automatically authenticated! Authentication via an environment variable or a secret has priority over the token stored on your machine. ### Method parameters Finally, it is also possible to authenticate by passing your token to any method that accepts `token` as a parameter. ``` from huggingface_hub import whoami user = whoami(token=...) ``` This is usually discouraged except in an environment where you don't want to store your token permanently or if you need to handle several tokens at once. Please be careful when passing tokens as a parameter. It is always best practice to load the token from a secure vault instead of hardcoding it in your codebase or notebook. Hardcoded tokens present a major leak risk if you share your code inadvertently. ## Create a repository Once you've registered and logged in, create a repository with the [`create_repo`] function: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.create_repo(repo_id="super-cool-model") ``` If you want your repository to be private, then: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.create_repo(repo_id="super-cool-model", private=True) ``` Private repositories will not be visible to anyone except yourself. To create a repository or to push content to the Hub, you must provide a User Access Token that has the `write` permission. You can choose the permission when creating the token in your [Settings page](https://huggingface.co/settings/tokens). ## Upload files Use the [`upload_file`] function to add a file to your newly created repository. You need to specify: 1. The path of the file to upload. 2. The path of the file in the repository. 3. The repository id of where you want to add the file. ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.upload_file( ... path_or_fileobj="/home/lysandre/dummy-test/README.md", ... path_in_repo="README.md", ... repo_id="lysandre/test-model", ... ) ``` To upload more than one file at a time, take a look at the [Upload](./guides/upload) guide which will introduce you to several methods for uploading files (with or without git). ## Next steps The `huggingface_hub` library provides an easy way for users to interact with the Hub with Python. To learn more about how you can manage your files and repositories on the Hub, we recommend reading our [how-to guides](./guides/overview) to: - [Manage your repository](./guides/repository). - [Download](./guides/download) files from the Hub. - [Upload](./guides/upload) files to the Hub. - [Search the Hub](./guides/search) for your desired model or dataset. - [Access the Inference API](./guides/inference) for fast inference. huggingface_hub-0.31.1/docs/source/fr/000077500000000000000000000000001500667546600175575ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/fr/_toctree.yml000066400000000000000000000005701500667546600221100ustar00rootroot00000000000000- title: "Introduction" sections: - local: index title: Home - local: quick-start title: Démarrage rapide - local: installation title: Installation - title: "Concepts" sections: - local: concepts/git_vs_http title: Git ou HTTP? 
- title: "Guides" sections: - local: guides/integrations title: Intégrer dans une librariehuggingface_hub-0.31.1/docs/source/fr/concepts/000077500000000000000000000000001500667546600213755ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/fr/concepts/git_vs_http.md000066400000000000000000000111371500667546600242540ustar00rootroot00000000000000 # Git ou HTTP? `huggingface_hub` est une librairie qui permet d'interagir avec le Hugging Face Hub, qui est une collection de dépots Git (modèles, datasets ou spaces). Il y a deux manières principales pour accéder au Hub en utilisant `huggingface_hub`. La première approche, basée sur Git, appelée approche "git-based", est rendue possible par la classe [`Repository`]. Cette méthode utilise un wrapper autour de la commande `git` avec des fonctionnalités supplémentaires conçues pour interagir avec le Hub. La deuxième option, appelée approche "HTTP-based" , consiste à faire des requêtes HTTP en utilisant le client [`HfApi`]. Examinons les avantages et les inconvénients de ces deux méthodes. ## Repository: l'approche historique basée sur git Initialement, `huggingface_hub` était principalement construite autour de la classe [`Repository`]. Elle fournit des wrappers Python pour les commandes `git` usuelles, telles que `"git add"`, `"git commit"`, `"git push"`, `"git tag"`, `"git checkout"`, etc. Cette librairie permet aussi de gérer l'authentification et les fichiers volumineux, souvent présents dans les dépôts Git de machine learning. De plus, ses méthodes sont exécutables en arrière-plan, ce qui est utile pour upload des données durant l'entrainement d'un modèle. L'avantage principal de l'approche [`Repository`] est qu'elle permet de garder une copie en local du dépot Git sur votre machine. Cela peut aussi devenir un désavantage, car cette copie locale doit être mise à jour et maintenue constamment. C'est une méthode analogue au développement de logiciel classique où chaque développeur maintient sa propre copie locale et push ses changements lorsqu'il travaille sur une nouvelle fonctionnalité. Toutefois, dans le contexte du machine learning la taille des fichiers rend peu pertinente cette approche car les utilisateurs ont parfois besoin d'avoir uniquement les poids des modèles pour l'inférence ou de convertir ces poids d'un format à un autre sans avoir à cloner tout le dépôt. [`Repository`] est maintenant obsolète et remplacée par les alternatives basées sur des requêtes HTTP. Étant donné son adoption massive par les utilisateurs, la suppression complète de [`Repository`] ne sera faite que pour la version `v1.0`. ## HfApi: Un client HTTP plus flexible La classe [`HfApi`] a été développée afin de fournir une alternative aux dépôts git locaux, qui peuvent être encombrant à maintenir, en particulier pour des modèles ou datasets volumineux. La classe [`HfApi`] offre les mêmes fonctionnalités que les approches basées sur Git, telles que le téléchargement et le push de fichiers ainsi que la création de branches et de tags, mais sans avoir besoin d'un fichier local qui doit être constamment synchronisé. En plus des fonctionnalités déjà fournies par `git`, La classe [`HfApi`] offre des fonctionnalités additionnelles, telles que la capacité à gérer des dépôts, le téléchargement des fichiers dans le cache (permettant une réutilisation), la recherche dans le Hub pour trouver des dépôts et des métadonnées, l'accès aux fonctionnalités communautaires telles que, les discussions, les pull requests et les commentaires. ## Quelle méthode utiliser et quand ? 
En général, **l'approche HTTP est la méthode recommandée** pour utiliser `huggingface_hub` [`HfApi`] permet de pull et push des changements, de travailler avec les pull requests, les tags et les branches, l'interaction avec les discussions et bien plus encore. Depuis la version `0.16`, les méthodes HTTP-based peuvent aussi être exécutées en arrière-plan, ce qui constituait le dernier gros avantage de la classe [`Repository`]. Toutefois, certaines commandes restent indisponibles en utilisant [`HfApi`]. Peut être que certaines ne le seront jamais, mais nous essayons toujours de réduire le fossé entre ces deux approches. Si votre cas d'usage n'est pas couvert, nous serions ravis de vous aider. Pour cela, ouvrez [une issue sur Github](https://github.com/huggingface/huggingface_hub)! Nous écoutons tous les retours nous permettant de construire l'écosystème 🤗 avec les utilisateurs et pour les utilisateurs. Cette préférence pour l'approche basée sur [`HfApi`] plutôt que [`Repository`] ne signifie pas que les dépôts stopperons d'être versionnés avec git sur le Hugging Face Hub. Il sera toujours possible d'utiliser les commandes `git` en local lorsque nécessaire.huggingface_hub-0.31.1/docs/source/fr/guides/000077500000000000000000000000001500667546600210375ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/fr/guides/integrations.md000066400000000000000000000351341500667546600240750ustar00rootroot00000000000000 # Intégrez votre framework de ML avec le Hub Le Hugging Face Hub facilite l'hébergement et le partage de modèles et de jeux de données. Des [dizaines de librairies](https://huggingface.co/docs/hub/models-libraries) sont intégrées à cet écosystème. La communauté travaille constamment à en intégrer de nouvelles et contribue ainsi à faciliter la collaboration dans le milieu du machine learning. La librairie `huggingface_hub` joue un rôle clé dans ce processus puisqu'elle permet d'interagir avec le Hub depuis n'importe quel script Python. Il existe quatre façons principales d'intégrer une bibliothèque au Hub : 1. **Push to Hub** implémente une méthode pour upload un modèle sur le Hub. Cela inclut les paramètres du modèle, sa fiche descriptive (appelée [Model Card](https://huggingface.co/docs/huggingface_hub/how-to-model-cards)) et toute autre information pertinente liée au modèle (par exemple, les logs d'entraînement). Cette méthode est souvent appelée `push_to_hub()`. 2. **Download from Hub** implémente une méthode pour charger un modèle depuis le Hub. La méthode doit télécharger la configuration et les poids du modèle puis instancier celui-ci. Cette méthode est souvent appelée `from_pretrained` ou `load_from_hub()`. 3. **Inference API** utilise nos serveurs pour faire de l'inférence gratuitement sur des modèles supportés par votre librairie. 4. **Widgets** affiche un widget sur la page d'accueil de votre modèle dans le Hub. Les widgets permettent aux utilisateurs de rapidement tester un modèle depuis le navigateur. Dans ce guide, nous nous concentrerons sur les deux premiers sujets. Nous présenterons les deux approches principales que vous pouvez utiliser pour intégrer une librairie, avec leurs avantages et leurs inconvénients. Tout est résumé à la fin du guide pour vous aider à choisir entre les deux. Veuillez garder à l'esprit que ce ne sont que des conseils, et vous êtes libres de les adapter à votre cas d'usage. Si l'Inference API et les Widgets vous intéressent, vous pouvez suivre [ce guide](https://huggingface.co/docs/hub/models-adding-libraries#set-up-the-inference-api). 
Dans les deux cas, vous pouvez nous contacter si vous intégrez une librairie au Hub et que vous voulez être listé [dans la documentation officielle](https://huggingface.co/docs/hub/models-libraries). ## Une approche flexible: les helpers La première approche pour intégrer une librairie au Hub est d'implémenter les méthodes `push_to_hub` et `from_pretrained` vous-même. Ceci vous donne une flexibilité totale sur le choix du fichier que vous voulez upload/download et sur comment gérer les inputs spécifiques à votre framework. Vous pouvez vous référer aux guides : [upload des fichiers](./upload) et [télécharger des fichiers](./download) pour en savoir plus sur la manière de faire. Par example, c'est de cette manière que l'intégration de FastAI est implémentée (voir [`push_to_hub_fastai`] et [`from_pretrained_fastai`]). L'implémentation peut varier entre différentes librairies, mais le workflow est souvent similaire. ### from_pretrained Voici un exemple classique pour implémenter la méthode `from_pretrained`: ```python def from_pretrained(model_id: str) -> MyModelClass: # Téléchargement des paramètres depuis le Hub cached_model = hf_hub_download( repo_id=repo_id, filename="model.pkl", library_name="fastai", library_version=get_fastai_version(), ) # Instanciation du modèle return load_model(cached_model) ``` ### push_to_hub La méthode `push_to_hub` demande souvent un peu plus de complexité pour gérer la création du dépôt git, générer la model card et enregistrer les paramètres. Une approche commune est de sauvegarder tous ces fichiers dans un dossier temporaire, les transférer sur le Hub avant de les supprimer localement. ```python def push_to_hub(model: MyModelClass, repo_name: str) -> None: api = HfApi() # Créez d'un dépôt s'il n'existe pas encore, et obtenez le repo_id associé repo_id = api.create_repo(repo_name, exist_ok=True).repo_id # Sauvegardez tous les fichiers dans un chemin temporaire, et pushez les en un seul commit with TemporaryDirectory() as tmpdir: tmpdir = Path(tmpdir) # Sauvegardez les poids save_model(model, tmpdir / "model.safetensors") # Générez la model card card = generate_model_card(model) (tmpdir / "README.md").write_text(card) # Sauvegardez les logs # Sauvegardez le métriques d'évaluation # ... # Pushez vers le Hub return api.upload_folder(repo_id=repo_id, folder_path=tmpdir) ``` Ceci n'est qu'un exemple. Si vous êtes intéressés par des manipulations plus complexes (supprimer des fichiers distants, upload des poids à la volée, maintenir les poids localement, etc.) consultez le guide [upload des fichiers](./upload). ### Limitations Bien que très flexible, cette approche a quelques défauts, particulièrement en termes de maintenance. Les utilisateurs d'Hugging Face sont habitués à utiliser certaines fonctionnalités lorsqu'ils travaillent avec `huggingface_hub`. 
Par exemple, lors du chargement de fichiers depuis le Hub, il est commun de passer des paramètres tels que: - `token`: pour télécharger depuis un dépôt privé - `revision`: pour télécharger depuis une branche spécifique - `cache_dir`: pour paramétrer la mise en cache des fichiers - `force_download`: pour désactiver le cache - `api_endpoint`/`proxies`: pour configurer la session HTTP Lorsque vous pushez des modèles, des paramètres similaires sont utilisables: - `commit_message`: message de commit personnalisé - `private`: crée un dépôt privé s'il en manque un - `create_pr`: crée une pull request au lieu de push vers `main` - `branch`: push vers une branche au lieu de push sur `main` - `allow_patterns`/`ignore_patterns`: filtre les fichiers à upload - `token` - `api_endpoint` - ... Tous ces paramètres peuvent être ajoutés aux implémentations vues ci-dessus et passés aux méthodes de `huggingface_hub`. Toutefois, si un paramètre change ou qu'une nouvelle fonctionnalité est ajoutée, vous devrez mettre à jour votre package. Supporter ces paramètres implique aussi plus de documentation à maintenir de votre côté. Dans la prochaine section, nous allons voir comment dépasser ces limitations. ## Une approche plus complexe: l'héritage de classe Comme vu ci-dessus, deux méthodes principales sont à inclure dans votre librairie pour l'intégrer avec le Hub: la méthode permettant d'upload des fichiers (`push_to_hub`) et celle pour télécharger des fichiers (`from_pretrained`). Vous pouvez implémenter ces méthodes vous-même mais cela a des inconvénients. Pour gérer ça, `huggingface_hub` fournit un outil qui utilise l'héritage de classe. Regardons comment ça marche ! Dans beaucoup de cas, une librairie définit déjà les modèles comme des classes Python. La classe contient les propriétés du modèle et des méthodes pour charger, lancer, entraîner et évaluer le modèle. Notre approche est d'étendre cette classe pour inclure les fonctionnalités upload et download en utilisant les mixins. Une [mixin](https://stackoverflow.com/a/547714) est une classe qui est faite pour étendre une classe existante avec une liste de fonctionnalités spécifiques en utilisant l'héritage de classe. `huggingface_hub` offre son propre mixin, le [`ModelHubMixin`]. La clef ici est de comprendre son comportement et comment le customiser. La classe [`ModelHubMixin`] implémente 3 méthodes *public* (`push_to_hub`, `save_pretrained` et `from_pretrained`). Ce sont les méthodes que vos utilisateurs appelleront pour charger/enregistrer des modèles avec votre librairie. [`ModelHubMixin`] définit aussi 2 méthodes *private* (`_save_pretrained` et `from_pretrained`). Ce sont celles que vous devez implémenter. Ainsi, pour intégrer votre librairie, vous devez : 1. Faire en sorte que votre classe Model hérite de [`ModelHubMixin`]. 2. Implémenter les méthodes privées: - [`~ModelHubMixin._save_pretrained`]: méthode qui prend en entrée un chemin vers un directory et qui sauvegarde le modèle. Vous devez écrire toute la logique pour dump votre modèle de cette manière: model card, poids du modèle, fichiers de configuration, et logs d'entraînement. Toute information pertinente pour ce modèle doit être gérée par cette méthode. Les [model cards](https://huggingface.co/docs/hub/model-cards) sont particulièrement importantes pour décrire votre modèle. Vérifiez [notre guide d'implémentation](./model-cards) pour plus de détails. - [`~ModelHubMixin._from_pretrained`]: **méthode de classe** prenant en entrée un `model_id` et qui retourne un modèle instancié. 
Cette méthode doit télécharger un ou plusieurs fichier(s) et le(s) charger. 3. Fini! L'avantage d'utiliser [`ModelHubMixin`] est qu'une fois que vous vous êtes occupés de la sérialisation et du chargement du fichier, vous êtes prêts. Vous n'avez pas besoin de vous soucier de la création du dépôt, des commits, des pull requests ou des révisions. Tout ceci est géré par le mixin et est disponible pour vos utilisateurs. Le Mixin s'assure aussi que les méthodes publiques sont bien documentées et que les annotations de typage sont spécifiées. ### Un exemple concret: PyTorch Un bon exemple de ce que nous avons vu ci-dessus est [`PyTorchModelHubMixin`], notre intégration pour le framework PyTorch. C'est une intégration prête à l'emploi. #### Comment l'utiliser ? Voici comment n'importe quel utilisateur peut charger/enregistrer un modèle Pytorch depuis/vers le Hub: ```python >>> import torch >>> import torch.nn as nn >>> from huggingface_hub import PyTorchModelHubMixin # 1. Définissez votre modèle Pytorch exactement comme vous êtes habitués à le faire >>> class MyModel(nn.Module, PyTorchModelHubMixin): # héritage multiple ... def __init__(self): ... super().__init__() ... self.param = nn.Parameter(torch.rand(3, 4)) ... self.linear = nn.Linear(4, 5) ... def forward(self, x): ... return self.linear(x + self.param) >>> model = MyModel() # 2. (optionnel) Sauvegarder le modèle dans un chemin local >>> model.save_pretrained("path/to/my-awesome-model") # 3. Pushez les poids du modèle vers le Hub >>> model.push_to_hub("my-awesome-model") # 4. initialisez le modèle depuis le Hub >>> model = MyModel.from_pretrained("username/my-awesome-model") ``` #### Implémentation L'implémentation est très succincte (voir [ici](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/hub_mixin.py)). 1. Premièrement, faites hériter votre classe de `ModelHubMixin`: ```python from huggingface_hub import ModelHubMixin class PyTorchModelHubMixin(ModelHubMixin): (...) ``` 2. Implémentez la méthode `_save_pretrained`: ```py from huggingface_hub import ModelCard, ModelCardData class PyTorchModelHubMixin(ModelHubMixin): (...) def _save_pretrained(self, save_directory: Path): """Générez une model card et enregistrez les poids d'un modèle Pytroch vers un chemin local.""" model_card = ModelCard.from_template( card_data=ModelCardData( license='mit', library_name="pytorch", ... ), model_summary=..., model_type=..., ... ) (save_directory / "README.md").write_text(str(model)) torch.save(obj=self.module.state_dict(), f=save_directory / "pytorch_model.bin") ``` 3. Implémentez la méthode `_from_pretrained`: ```python class PyTorchModelHubMixin(ModelHubMixin): (...) @classmethod # Doit absolument être une méthode de clase ! 
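    # Remarque : les paramètres ci-dessous (revision, cache_dir, token, etc.) sont
    # transmis automatiquement par `from_pretrained` ; `map_location` et `strict`
    # sont des arguments supplémentaires propres à cette intégration PyTorch.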
def _from_pretrained( cls, *, model_id: str, revision: str, cache_dir: str, force_download: bool, proxies: Optional[Dict], resume_download: bool, local_files_only: bool, token: Union[str, bool, None], map_location: str = "cpu", # argument supplémentaire strict: bool = False, # argument supplémentaire **model_kwargs, ): """Chargez les poids pré-entrainés et renvoyez les au modèle chargé.""" if os.path.isdir(model_id): # Peut être un chemin local print("Loading weights from local directory") model_file = os.path.join(model_id, "pytorch_model.bin") else: # Ou un modèle du Hub model_file = hf_hub_download( # Téléchargez depuis le Hub, en passant le mêmes arguments d'entrée repo_id=model_id, filename="pytorch_model.bin", revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) # Chargez le modèle et reoutnez une logique personnalisée dépendant de votre framework model = cls(**model_kwargs) state_dict = torch.load(model_file, map_location=torch.device(map_location)) model.load_state_dict(state_dict, strict=strict) model.eval() return model ``` Et c'est fini ! Votre librairie permet maintenant aux utilisateurs d'upload et de télécharger des fichiers vers et depuis le Hub. ## Comparaison Résumons rapidement les deux approches que nous avons vu avec leurs avantages et leurs défauts. Le tableau ci-dessous est purement indicatif. Votre framework aura peut-êre des spécificités à prendre en compte. Ce guide est ici pour vous donner des indications et des idées sur comment gérer l'intégration. Dans tous les cas, n'hésitez pas à nous contacter si vous avez une question ! | Intégration | Utilisant des helpers | Utilisant [`ModelHubMixin`] | |:---:|:---:|:---:| | Expérience utilisateur | `model = load_from_hub(...)`
`push_to_hub(model, ...)` | `model = MyModel.from_pretrained(...)`
`model.push_to_hub(...)` | | Flexible | Très flexible.
Vous contrôlez complètement l'implémentation. | Moins flexible.
Votre framework doit avoir une classe de modèle. | | Maintenance | Plus de maintenance pour ajouter du support pour la configuration, et de nouvelles fonctionnalités. Peut aussi nécessiter de fixx des problèmes signalés par les utilisateurs.| Moins de maintenance vu que la plupart des intégrations avec le Hub sont implémentés dans `huggingface_hub` | | Documentation / Anotation de type| A écrire à la main | Géré partiellement par `huggingface_hub`. | huggingface_hub-0.31.1/docs/source/fr/index.md000066400000000000000000000100461500667546600212110ustar00rootroot00000000000000 # Un client Python pour le Hugging Face Hub La librairie `huggingface_hub` vous permet d'interagir avec le [Hugging Face Hub](https://hf.co), une plateforme de machine learning pour créer et collaborer. Découvrez des modèles pré- entrainés et des datasets pour vos projets ou jouez avec des centraines d'applications hébergées sur le Hub. Vous pouvez aussi créer et partager vos propres modèles et datasets avec la communauté. La librairie `huggingface_hub` offre une manière simple de faire toutes ces choses avec Python. Lisez le [guide d'introduction rapide](quick-start) pour vous lancer avec la librairie `huggingface_hub`. Vous apprendrez à télécharger des fichiers depuis le Hub, à créer un dépôt et upload des fichiers vers le Hub. Continuez à lire pour apprendre le management de dépôt sur le Hub, comment interagir avec les discussions ou même comment accéder à l'API d'inférence. ## Contributions Toutes les contributions au projet `huggingface_hub` sont les bienvenues et valorisées à la même hauteur. 🤗 En plus de l'ajout ou de la correction de bug dans le code, vous pouvez également aider à améliorer la documentation en vérifiant qu'elle est exacte et à jour, répondre à des questions sur des issues, et demander de nouvelles fonctionnalités qui amélioreront la librairie. Regardez le [guide de contribution](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md) pour en savoir plus sur comment commencer à travailler sur une issue, comment faire une pull request et comment tester vos contributions pour vérifier que vos modifications fonctionnent comme prévu. Les contributeurs doivent aussi respecter notre [code de conduite](https://github.com/huggingface/huggingface_hub/blob/main/CODE_OF_CONDUCT.md) (en anglais) pour créer un espace collaboratif inclusif et bienveillant envers tout le monde.huggingface_hub-0.31.1/docs/source/fr/installation.md000066400000000000000000000174541500667546600226150ustar00rootroot00000000000000 # Installation Avant de commencer, vous allez avoir besoin de préparer votre environnement en installant les packages appropriés. `huggingface_hub` est testée sur **Python 3.8+**. ## Installation avec pip Il est fortement recommandé d'installer `huggingface_hub` dans un [environnement virtuel](https://docs.python.org/3/library/venv.html). Si vous n'êtes pas familier avec les environnements virtuels Python, suivez ce [guide](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/). Un environnement virtuel sera utile lorsque vous devrez gérer des plusieurs projets en parallèle afin d'éviter les problèmes de compatibilité entre les différentes dépendances. 
Commencez par créer un environnement virtuel à l'emplacement de votre projet: ```bash python -m venv .env ``` Activez l'environnement virtuel sur Linux et macOS: ```bash source .env/bin/activate ``` Activez l'environnement virtuel sur Windows: ```bash .env/Scripts/activate ``` Maintenant, vous êtes prêts à installer `hugginface_hub` [depuis PyPi](https://pypi.org/project/huggingface-hub/): ```bash pip install --upgrade huggingface_hub ``` Une fois l'installation terminée, rendez-vous à la section [vérification](#verification-de-l-installation) pour s'assurer que tout fonctionne correctement. ### Installation des dépendances optionnelles Certaines dépendances de `huggingface_hub` sont [optionnelles](https://setuptools.pypa.io/en/latest/userguide/dependency_management.html#optional-dependencies) car elles ne sont pas nécessaire pour faire marcher les fonctionnalités principales de `huggingface_hub`. Toutefois, certaines fonctionnalités de `huggingface_hub` ne seront pas disponibles si les dépendances optionnelles ne sont pas installées Vous pouvez installer des dépendances optionnelles via `pip`: ```bash #Installation des dépendances pour les fonctionnalités spécifiques à Tensorflow. #/!\ Attention : cette commande n'est pas équivalente à `pip install tensorflow`. pip install 'huggingface_hub[tensorflow]' #Installation des dépendances spécifiques à Pytorch et au CLI. pip install 'huggingface_hub[cli,torch]' ``` Voici une liste des dépendances optionnelles dans `huggingface_hub`: - `cli` fournit une interface d'invite de commande plus pratique pour `huggingface_hub`. - `fastai`, `torch` et `tensorflow` sont des dépendances pour utiliser des fonctionnalités spécifiques à un framework. - `dev` permet de contribuer à la librairie. Cette dépendance inclut `testing` (pour lancer des tests), `typing` (pour lancer le vérifieur de type) et `quality` (pour lancer des linters). ### Installation depuis le code source Dans certains cas, il est intéressant d'installer `huggingface_hub` directement depuis le code source. Ceci vous permet d'utiliser la version `main`, contenant les dernières mises à jour, plutôt que d'utiliser la dernière version stable. La version `main` est utile pour rester à jour sur les derniers développements, par exemple si un bug est corrigé depuis la dernière version officielle mais que la nouvelle version n'a pas encore été faite. Toutefois, cela signifie que la version `main` peut ne pas être stable. Nous travaillons afin de rendre la version `main` aussi stable que possible, et la plupart des problèmes sont résolus en quelques heures ou jours. Si vous avez un problème, ouvrez une issue afin que nous puissions la régler au plus vite ! ```bash pip install git+https://github.com/huggingface/huggingface_hub ``` Lorsque vous installez depuis le code source, vous pouvez préciser la branche depuis laquelle installer. Cela permet de tester une nouvelle fonctionnalité ou un bug-fix qui n'a pas encore été merge: ```bash pip install git+https://github.com/huggingface/huggingface_hub@ma-branche ``` Une fois l'installation terminée, rendez-vous à la section [vérification](#verification-de-l-installation) pour s'assurer que tout fonctionne correctement. ### Installation éditable L'installation depuis le code source vous permet de mettre en place une [installation éditable](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs). 
Cette installation sert surtout si vous comptez contribuer à `huggingface_hub` et que vous avez besoin de tester rapidement des changements dans le code. Pour cela, vous devez cloner le projet `huggingface_hub` sur votre machine. ```bash # Commencez par cloner le dépôt en local git clone https://github.com/huggingface/huggingface_hub.git # Ensuite, installez-le avec le flag -e cd huggingface_hub pip install -e . ``` Python regardera maintenant à l'intérieur du dossier dans lequel vous avez cloné le dépôt en plus des chemins de librairie classiques. Par exemple, si vos packages Python sont installés dans `./.venv/lib/python3.13/site-packages/`, Python regardera aussi dans le dossier que vous avez cloné `./huggingface_hub/`. ## Installation avec conda Si vous avez plutôt l'habitude d'utiliser conda, vous pouvez installer `huggingface_hub` en utilisant le [channel conda-forge](https://anaconda.org/conda-forge/huggingface_hub): ```bash conda install -c conda-forge huggingface_hub ``` Une fois l'installation terminée, rendez-vous à la section [vérification](#verification-de-l-installation) pour s'assurer que tout fonctionne correctement. ## Vérification de l'installation Une fois installée, vérifiez que `huggingface_hub` marche correctement en lançant la commande suivante: ```bash python -c "from huggingface_hub import model_info; print(model_info('gpt2'))" ``` Cette commande va récupérer des informations sur le modèle [gpt2](https://huggingface.co/gpt2) depuis le Hub. La sortie devrait ressembler à ça: ```text Model Name: gpt2 Tags: ['pytorch', 'tf', 'jax', 'tflite', 'rust', 'safetensors', 'gpt2', 'text-generation', 'en', 'doi:10.57967/hf/0039', 'transformers', 'exbert', 'license:mit', 'has_space'] Task: text-generation ``` ## Les limitations Windows Afin de démocratiser le machine learning au plus grand nombre, nous avons développé `huggingface_hub` de manière cross-platform et en particulier, pour qu'elle fonctionne sur une maximum de systèmes d'exploitation différents. Toutefois `huggingface_hub` connaît dans certains cas des limitations sur Windows. Nous avons listé ci-dessous les problèmes connus. N'hésitez pas à nous signaler si vous rencontrez un problème non documenté en ouvrant une [issue sur Github](https://github.com/huggingface/huggingface_hub/issues/new/choose). - Le cache de `huggingface_hub` a besoin des symlinks pour mettre en cache les fichiers installés depuis le Hub. Sur windows, vous devez activer le mode développeur pour lancer ou lancer votre script en tant qu'administrateur afin de faire fonctionner les symlinks. S'ils ne sont pas activés, le système de cache fonctionnera toujours mais de manière sous-optimale. Consultez les [limitations du cache](./guides/manage-cache#limitations) pour plus de détails. - Les noms de fichiers sur le Hub peuvent avoir des caractères spéciaux (par exemple `"path/to?/my/file"`). Windows est plus restrictif sur les [caractères spéciaux](https://learn.microsoft.com/en-us/windows/win32/intl/character-sets-used-in-file-names) ce qui rend ces fichiers ininstallables sur Windows. Heureusement c'est un cas assez rare. Contactez le propriétaire du dépôt si vous pensez que c'est une erreur ou contactez nous pour que nous cherchions une solution. ## Prochaines étapes Une fois que `huggingface_hub` est installé correctement sur votre machine, vous aurez peut-être besoin de [configurer les variables d'environnement](package_reference/environment_variables) ou de [lire un de nos guides](guides/overview) pour vous lancer. 
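À titre d'illustration (chemin purement indicatif), certaines variables d'environnement comme `HF_HOME` peuvent être définies avant d'importer `huggingface_hub`, par exemple pour déplacer le cache :

```python
import os

# Exemple indicatif : personnaliser l'emplacement du cache (chemin fictif)
os.environ["HF_HOME"] = "/chemin/vers/mon/cache/huggingface"

from huggingface_hub import hf_hub_download

# Le fichier sera téléchargé puis mis en cache sous le nouvel emplacement
print(hf_hub_download(repo_id="gpt2", filename="config.json"))
```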
huggingface_hub-0.31.1/docs/source/fr/quick-start.md000066400000000000000000000145721500667546600223610ustar00rootroot00000000000000 # Démarrage rapide Le [Hugging Face Hub](https://huggingface.co/) est le meilleur endroit pour partager des modèles de machine learning, des démos, des datasets et des métriques. La librairie `huggingface_hub` vous aide à intéragir avec le Hub sans sortir de votre environnement de développement. Vous pouvez: créer et gérer des dépôts facilement, télécharger et upload des fichiers, et obtenir des modèles et des métadonnées depuis le Hub. ## Installation Pour commencer, installez la librairie `huggingface_hub`: ```bash pip install --upgrade huggingface_hub ``` Pour plus de détails, vérifiez le guide d'[installation](installation) ## Télécharger des fichiers Les dépôts sur le Hub utilisent le versioning Git, les utilisateurs peuvent télécharger un fichier, ou un dépôt entier. Vous pouvez utiliser la fonction [`hf_hub_download`] pour télécharger des fichiers. Cette fonction téléchargera et mettra dans le cache un fichier sur votre disque local. La prochaine fois que vous aurez besoin de ce fichier, il sera chargé depuis votre cache de façon à ce que vous n'ayez pas besoin de le retélécharger. Vous aurez besoin de l'id du dépôt et du nom du fichier que vous voulez télécharger. Par exemple, pour télécharger le fichier de configuration du modèle [Pegasus](https://huggingface.co/google/pegasus-xsum): ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download(repo_id="google/pegasus-xsum", filename="config.json") ``` Pour télécharger une version spécifique du fichier, utilisez le paramètre `revision` afin de spécifier le nom de la branche, le tag ou le hash de commit. Si vous décidez d'utiliser le hash de commit, vous devez renseigner le hash entier et pas le hash court de 7 caractères: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download( ... repo_id="google/pegasus-xsum", ... filename="config.json", ... revision="4d33b01d79672f27f001f6abade33f22d993b151" ... ) ``` Pour plus de détails et d'options, consultez la réference de l'API pour [`hf_hub_download`]. ## Connexion Dans la plupart des cas, vous devez être connectés avec un compte Hugging Face pour interagir avec le Hub: pour télécharger des dépôts privés, upload des fichiers, créer des pull requests... [Créez un compte](https://huggingface.co/join) si vous n'en avez pas déjà un et connectez vous pour obtenir votre [token d'authentification](https://huggingface.co/docs/hub/security-tokens) depuis vos [paramètres](https://huggingface.co/settings/tokens). Le token est utilisé pour authentifier votre identité au Hub. Une fois que vous avez votre token d'authentification, lancez la commande suivante dans votre terminal: ```bash huggingface-cli login # ou en utilisant une varible d'environnement: huggingface-cli login --token $HUGGINGFACE_TOKEN ``` Sinon, vous pouvez vous connecter en utilisant [`login`] dans un notebook ou un script: ```py >>> from huggingface_hub import login >>> login() ``` Il est aussi possible de se connecter automatiquement sans qu'on vous demande votre token en passant le token dans [`login`] de cette manière: `login(token="hf_xxx")`. Si vous choisissez cette méthode, faites attention lorsque vous partagez votre code source. Une bonne pratique est de charger le token depuis un trousseau sécurisé au lieu de l'enregistrer en clair dans votre codebase/notebook. Vous ne pouvez être connecté qu'à un seul compte à la fois. 
Si vous connectez votre machine à un autre compte, vous serez déconnecté du premier compte. Vérifiez toujours le compte que vous utilisez avec la commande `huggingface-cli whoami`. Si vous voulez gérer plusieurs compte dans le même script, vous pouvez passer votre token à chaque appel de méthode. C'est aussi utile si vous ne voulez pas sauvegarder de token sur votre machine. Une fois que vous êtes connectés, toutes les requêtes vers le Hub (même les méthodes qui ne nécessite pas explicitement d'authentification) utiliseront votre token d'authentification par défaut. Si vous voulez supprimer l'utilisation implicite de votre token, vous devez définir la variable d'environnement `HF_HUB_DISABLE_IMPLICIT_TOKEN`. ## Créer un dépôt Une fois que vous avez créé votre compte et que vous vous êtes connectés, vous pouvez créer un dépôt avec la fonction [`create_repo`]: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.create_repo(repo_id="super-cool-model") ``` Si vous voulez que votre dépôt soit privé, alors: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.create_repo(repo_id="super-cool-model", private=True) ``` Les dépôts privés ne seront visible que par vous. Pour créer un dépôt ou push du contenu sur le Hub, vous devez fournir un token d'authentification qui a les permissions `write`. Vous pouvez choisir la permission lorsque vous générez le token dans vos [paramètres](https://huggingface.co/settings/tokens). ## Upload des fichiers Utilisez la fonction [`upload_file`] pour ajouter un fichier à votre dépôt. Vous devez spécifier: 1. Le chemin du fichier à upload. 2. Le chemin du fichier dans le dépôt. 3. L'id du dépôt dans lequel vous voulez ajouter le fichier. ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.upload_file( ... path_or_fileobj="/home/lysandre/dummy-test/README.md", ... path_in_repo="README.md", ... repo_id="lysandre/test-model", ... ) ``` Pour upload plus d'un fichier à la fois, consultez le guide [upload](./guides/upload) qui détaille plusieurs méthodes pour upload des fichiers (avec ou sans Git). ## Prochaines étapes La librairie `huggingface_hub` permet à ses utilisateurs d'intéragir facilementavec le Hub via Python. Pour en apprendre plus sur comment gérer vos fichiers et vos dépôts sur le Hub, nous vous recommandons de lire notre [guide conceptuel](./guides/overview) pour : - [Gérer votre dépôt](./guides/repository). - [Télécharger](./guides/download) des fichiers depuis le Hub. - [Upload](./guides/upload) des fichiers vers le Hub. - [Faire des recherches dans le Hub](./guides/search) pour votre modèle ou dataset. 
- [Accédder à l'API d'inférence](./guides/inference) pour faire des inférences rapides.huggingface_hub-0.31.1/docs/source/hi/000077500000000000000000000000001500667546600175505ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/hi/_toctree.yml000066400000000000000000000003501500667546600220750ustar00rootroot00000000000000- title: "शुरू हो जाओ" sections: - local: index title: होम - local: quick-start title: जल्दी शुरू - local: installation title: इंस्टालेशन huggingface_hub-0.31.1/docs/source/hi/index.md000066400000000000000000000141131500667546600212010ustar00rootroot00000000000000 # 🤗 हब क्लाइंट लाइब्रेरी `huggingface_hub` लाइब्रेरी आपको [हगिंग फेस' के साथ काम करने की अनुमति देती है हब](https://hf.co), रचनाकारों और सहयोगियों के लिए एक मशीन लर्निंग प्लेटफॉर्म। अपनी परियोजनाओं के लिए पूर्व-प्रशिक्षित मॉडल और डेटासेट खोजें या सैकड़ों के साथ खेलें हब पर होस्ट किए गए मशीन लर्निंग ऐप्स। आप अपने स्वयं के मॉडल और डेटासेट्स भी बना सकते हैं और उन्हें समुदाय के साथ साझा कर सकते हैं। `huggingface_hub` लाइब्रेरी इसका एक आसान तरीका प्रदान करती है ये सभी चीजें Python के साथ करें। इसके साथ काम के लिए [quick-start] (क्विक-स्टार्ट) पढ़ें `huggingface_hub` लाइब्रेरी। आप सीखेंगे कि हब से फाइलें कैसे डाउनलोड करें, एक रिपॉजिटरी कैसे बनाएं, और हब पर फाइलें कैसे अपलोड करें। 🤗 हब पर अपनी रिपॉजिटरी कैसे प्रबंधित करें, चर्चाओं में कैसे भाग लें, या इन्फरेंस एपीआई तक कैसे पहुँचें, इसके बारे में अधिक जानने के लिए पढ़ते रहें। ## योगदान देना `huggingface_hub` में सभी योगदानों का स्वागत किया जाता है और समान रूप से महत्व दिया जाता है! 🤗 इसके अलावा कोड में मौजूदा समस्याओं को जोड़ने या ठीक करने से आप इसे बेहतर बनाने में भी मदद कर सकते हैं यह सुनिश्चित करके कि दस्तावेज़ीकरण सटीक और अद्यतित है, प्रश्नों के उत्तर देने में सहायता करें मुद्दे, और नई सुविधाओं का अनुरोध करें जो आपको लगता है कि लाइब्रेरी में सुधार करेगी। एक नया मुद्दा या सुविधा अनुरोध कैसे सबमिट करें, पुल अनुरोध कैसे सबमिट करें, और यह सुनिश्चित करने के लिए अपने योगदानों का परीक्षण कैसे करें कि सब कुछ अपेक्षा के अनुरूप काम करता है, इसके बारे में अधिक जानने के लिए [योगदान गाइड](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md) पर एक नज़र डालें। योगदानकर्ताओं को सभी के लिए एक समावेशी और स्वागत योग्य सहयोगी स्थान बनाने के लिए हमारे [आचार संहिता](https://github.com/huggingface/huggingface_hub/blob/main/CODE_OF_CONDUCT.md), का भी सम्मान करना चाहिए।huggingface_hub-0.31.1/docs/source/hi/installation.md000066400000000000000000000333251500667546600226010ustar00rootroot00000000000000 # स्थापना आरंभ करने से पहले, आपको उपयुक्त पैकेज स्थापित करके अपना परिवेश सेटअप करना होगा। `huggingface_hub` का परीक्षण **Python 3.8+** पर किया गया है। ## पिप के साथ स्थापित करें [वर्चुअल वातावरण](https://docs.python.org/3/library/venv.html) में `huggingface_hub` इंस्टॉल करने की अत्यधिक अनुशंसा की जाती है। यदि आप Python वर्चुअल वातावरण से अपरिचित हैं, तो इस [गाइड](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/) पर एक नज़र डालें। एक वर्चुअल वातावरण विभिन्न परियोजनाओं को प्रबंधित करना आसान बनाता है, और निर्भरताओं के बीच संगतता समस्याओं से बचाता है। अपनी प्रोजेक्ट निर्देशिका में एक वर्चुअल वातावरण बनाकर प्रारंभ करें: ```bash python -m venv .env ``` वर्चुअल वातावरण सक्रिय करें. 
Linux और macOS पर: ```bash source .env/bin/activate ``` वर्चुअल वातावरण सक्रिय करें Windows पर: ```bash .env/Scripts/activate ``` अब आप `huggingface_hub` [PyPi रजिस्ट्री से](https://pypi.org/project/huggingface-hub/), इंस्टॉल करने के लिए तैयार हैं: ```bash pip install --upgrade huggingface_hub ``` एक बार हो जाने के बाद [चेक इंस्टालेशन](#चेक-इंस्टॉलेशन), यह सुनिश्चित करने के लिए कि वह ठीक से काम कर रहा है। ### वैकल्पिक निर्भरताएँ स्थापित करें `huggingface_hub` की कुछ निर्भरताएं [वैकल्पिक](https://setuptools.pypa.io/en/latest/userguide/dependency_management.html#optional-dependencies) हैं क्योंकि उन्हें `huggingface_hub` की मुख्य विशेषताओं को चलाने की आवश्यकता नहीं है। हालाँकि, यदि वैकल्पिक निर्भरताएँ स्थापित नहीं हैं तो `huggingface_hub` की कुछ सुविधाएँ उपलब्ध नहीं हो सकती हैं। आप `pip` के माध्यम से वैकल्पिक निर्भरताएँ स्थापित कर सकते हैं: ```bash # Install dependencies for tensorflow-specific features # /!\ Warning: this is not equivalent to `pip install tensorflow` pip install 'huggingface_hub[tensorflow]' # Install dependencies for both torch-specific and CLI-specific features. pip install 'huggingface_hub[cli,torch]' ``` यहां `huggingface_hub` में वैकल्पिक निर्भरताओं की सूची दी गई है: - `cli`: `huggingface_hub` के लिए अधिक सुविधाजनक CLI इंटरफ़ेस प्रदान करें। - `fastai`, `torch`, `tensorflow`: फ्रेमवर्क-विशिष्ट सुविधाओं को चलाने के लिए निर्भरताएँ। - `dev`: lib में योगदान करने के लिए निर्भरताएँ। इसमें 'परीक्षण' (परीक्षण चलाने के लिए), 'टाइपिंग' (टाइप चेकर चलाने के लिए) और 'गुणवत्ता' (लिंटर चलाने के लिए) शामिल हैं। ### स्रोत से इंस्टॉल करें कुछ मामलों में, `huggingface_hub` को सीधे स्रोत से स्थापित करना दिलचस्प होता है। यह आपको नवीनतम स्थिर संस्करण के बजाय अत्याधुनिक `main` संस्करण का उपयोग करने की अनुमति देता है। `main` संस्करण नवीनतम विकास के साथ अद्यतित रहने के लिए उपयोगी है, उदाहरण के लिए यदि अंतिम आधिकारिक रिलीज के बाद से एक बग को ठीक किया गया है लेकिन अभी तक एक नई रिलीज शुरू नहीं की गई है। हालांकि, इसका मतलब है कि `main` संस्करण हमेशा स्थिर नहीं हो सकता है। हम `main` संस्करण को चालू रखने का प्रयास करते हैं, और अधिकांश समस्याएं आमतौर पर कुछ घंटों या एक दिन के भीतर हल हो जाती हैं। यदि आप किसी समस्या का सामना करते हैं, तो कृपया एक समस्या खोलें ताकि हम इसे और भी जल्दी ठीक कर सकें! ```bash pip install git+https://github.com/huggingface/huggingface_hub ``` स्रोत से इंस्टॉल करते समय, आप एक विशिष्ट शाखा भी निर्दिष्ट कर सकते हैं। यह तब उपयोगी होता है जब आप किसी नई सुविधा या नए बग-फिक्स का परीक्षण करना चाहते हैं जिसे अभी तक मर्ज नहीं किया गया है: ```bash pip install git+https://github.com/huggingface/huggingface_hub@my-feature-branch ``` एक बार हो जाने के बाद [चेक इंस्टालेशन](#चेक-इंस्टॉलेशन), यह सुनिश्चित करने के लिए कि वह ठीक से काम कर रहा है। ### संपादन योग्य इंस्टॉल स्रोत से इंस्टॉल करने से आपको एक [संपादन योग्य इंस्टॉल](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs) कर सकते हैं। यदि आप `huggingface_hub` में योगदान करने की योजना बना रहे हैं और कोड में परिवर्तनों का परीक्षण करने की आवश्यकता है, तो यह एक अधिक उन्नत इंस्टॉलेशन है। आपको अपनी मशीन पर `huggingface_hub` की एक स्थानीय प्रति क्लोन करने की आवश्यकता है। ```bash # First, clone repo locally git clone https://github.com/huggingface/huggingface_hub.git # Then, install with -e flag cd huggingface_hub pip install -e . 
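# Optionally, install the extra "dev" dependencies (tests, type checker, linters)
# pip install -e ".[dev]"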
``` ये कमांड उस फ़ोल्डर को लिंक करेंगे जिसमें आपने रिपॉजिटरी को क्लोन किया था और आपके Python लाइब्रेरी पथ। Python अब सामान्य लाइब्रेरी पथ के अलावा आपके द्वारा क्लोन किए गए फ़ोल्डर के अंदर भी देखेगा। उदाहरण के लिए, यदि आपके Python पैकेज आमतौर पर `./.venv/lib/python3.11/site-packages/` में स्थापित होते हैं, तो Python उस फ़ोल्डर को भी खोजेगा जिसे आपने `./huggingface_hub/` क्लोन किया था। ## कोंडा के साथ स्थापित करें यदि आप इससे अधिक परिचित हैं, तो आप [conda-forge चैनल](https://anaconda.org/conda-forge/huggingface_hub) का उपयोग करके `huggingface_hub` इंस्टॉल कर सकते हैं: ```bash conda install -c conda-forge huggingface_hub ``` एक बार हो जाने के बाद [चेक इंस्टालेशन](#चेक-इंस्टॉलेशन), यह सुनिश्चित करने के लिए कि वह ठीक से काम कर रहा है। ## स्थापना की जाँच करें एक बार इंस्टॉल हो जाने पर, निम्नलिखित कमांड चलाकर जांचें कि `huggingface_hub` ठीक से काम करता है: ```bash python -c "from huggingface_hub import model_info; print(model_info('gpt2'))" ``` यह कमांड हब से [gpt2](https://huggingface.co/gpt2) मॉडल के बारे में जानकारी प्राप्त करेगा। आउटपुट इस तरह दिखना चाहिए: ```text Model Name: gpt2 Tags: ['pytorch', 'tf', 'jax', 'tflite', 'rust', 'safetensors', 'gpt2', 'text-generation', 'en', 'doi:10.57967/hf/0039', 'transformers', 'exbert', 'license:mit', 'has_space'] Task: text-generation ``` ## विंडोज़ सीमाएँ हर जगह अच्छे एमएल को लोकतांत्रिक बनाने के हमारे लक्ष्य के साथ, हमने `huggingface_hub` को एक क्रॉस-प्लेटफ़ॉर्म लाइब्रेरी बनाने के लिए बनाया है और विशेष रूप से यूनिक्स-आधारित और विंडोज सिस्टम दोनों पर सही ढंग से काम करने के लिए। हालाँकि, ऐसे कुछ मामले हैं जहाँ विंडोज़ पर चलने पर `huggingface_hub` की कुछ सीमाएँ हैं। यहां ज्ञात मुद्दों की एक विस्तृत सूची दी गई है। यदि आप [Github](https://github.com/huggingface/huggingface_hub/issues/new/choose) पर एक समस्या खोलकर किसी अनिर्दिष्ट समस्या का सामना करते हैं तो कृपया हमें बताएं। - `huggingface_hub` का `cache` सिस्टम हब से डाउनलोड की गई फ़ाइलों को कुशलतापूर्वक `cache` करने के लिए सिमलिंक पर निर्भर करता है। विंडोज़ पर, आपको सिमलिंक को सक्षम करने के लिए डेवलपर मोड को सक्रिय करना होगा या अपने स्क्रिप्ट को व्यवस्थापक के रूप में चलाना होगा। यदि वे सक्रिय नहीं हैं, तो cache-सिस्टम अभी भी काम करता है लेकिन गैर-अनुकूलित तरीके से। अधिक जानकारी के लिए कृपया [cache सीमाएँ](./guides/manage-cache#limities) अनुभाग पढ़ें। - हब पर फ़ाइलपथ में विशेष वर्ण हो सकते हैं (उदा. 
`"path/to?/my/file"`)। विंडोज़ विशेष वर्णों पर अधिक प्रतिबंधात्मक है जिससे उन फ़ाइलों को विंडोज़ पर डाउनलोड करना असंभव हो जाता है। उम्मीद है कि यह एक दुर्लभ मामला है। अगर आपको लगता है कि यह एक गलती है तो कृपया रेपो मालिक से संपर्क करें या समाधान निकालने के लिए हमसे संपर्क करें। ## अगले कदम एक बार जब `huggingface_hub` आपकी मशीन पर ठीक से स्थापित हो जाता है, तो आप आरंभ करने के लिए [पर्यावरण चर कॉन्फ़िगर करें](package_reference/environment_variables) या [हमारे गाइडों में से एक की जांच करें](guides/overview)। huggingface_hub-0.31.1/docs/source/hi/quick-start.md000066400000000000000000000407041500667546600223460ustar00rootroot00000000000000 # जल्दी शुरू [Hugging Face Hub](https://huggingface.co/) मशीन लर्निंग मॉडल, डेमो, डेटासेट और मेट्रिक्स साझा करने के लिए सबसे उपयुक्त स्थान है। `huggingface_hub` लाइब्रेरी आपको अपने विकास परिवेश को छोड़े बिना हब के साथ इंटरैक्ट करने में मदद करती है। आप आसानी से रिपॉजिटरी बना और प्रबंधित कर सकते हैं, फ़ाइलें डाउनलोड और अपलोड कर सकते हैं, और हब से उपयोगी मॉडल और डेटासेट मेटाडेटा प्राप्त कर सकते हैं। ## इंस्टालेशन आरंभ करने के लिए, `huggingface_hub` लाइब्रेरी स्थापित करें: ```bash pip install --upgrade huggingface_hub ``` अधिक विवरण के लिए, [installation](इंस्टॉलेशन) गाइड देखें। ## फ़ाइलें डाउनलोड करें हब पर रिपॉजिटरी `git` वर्जन नियंत्रित हैं, और उपयोगकर्ता एक फ़ाइल या पूरी रिपॉजिटरी डाउनलोड कर सकते हैं। फ़ाइलों को डाउनलोड करने के लिए आप [`hf_hub_download`] फ़ंक्शन का उपयोग कर सकते हैं। यह फ़ंक्शन आपकी स्थानीय डिस्क पर एक फ़ाइल डाउनलोड और `cache` करेगा। अगली बार जब आपको उस फ़ाइल की आवश्यकता होगी, तो यह आपके `cache` से लोड हो जाएगी, इसलिए आपको इसे फिर से डाउनलोड करने की आवश्यकता नहीं है। आपको उस फ़ाइल की रिपॉजिटरी आईडी और फ़ाइल नाम की आवश्यकता होगी जिसे आप डाउनलोड करना चाहते हैं। उदाहरण के लिए, [Pegasus](https://huggingface.co/google/pegasus-xsum) मॉडल कॉन्फ़िगरेशन फ़ाइल डाउनलोड करने के लिए: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download(repo_id="google/pegasus-xsum", filename="config.json") ``` फ़ाइल के किसी विशिष्ट संस्करण को डाउनलोड करने के लिए, शाखा नाम, टैग या कमिट हैश निर्दिष्ट करने के लिए `revision` पैरामीटर का उपयोग करें। यदि आप कमिट हैश का उपयोग करना चुनते हैं, तो यह छोटे 7-वर्ण कमिट हैश के बजाय पूर्ण-लंबाई वाला हैश होना चाहिए: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download( ... repo_id="google/pegasus-xsum", ... filename="config.json", ... revision="4d33b01d79672f27f001f6abade33f22d993b151" ... ) ``` अधिक विवरण और विकल्पों के लिए, [`hf_hub_download`] के लिए एपीआई संदर्भ देखें। ## प्रमाणीकरण कई मामलों में, हब के साथ इंटरैक्ट करने के लिए आपको `Hugging Face` खाते से प्रमाणित होना होगा: निजी रेपो डाउनलोड करें, फ़ाइलें अपलोड करें, पीआर बनाएं,... 
[एक खाता बनाएं](https://huggingface.co/join), यदि आपके पास पहले से कोई खाता नहीं है| और फिर अपना [User Access Token](https://huggingface.co/docs/hub/security-tokens) प्राप्त करने के लिए साइन इन करें, आपके [सेटिंग्स पेज](https://huggingface.co/settings/tokens) से। उपयोगकर्ता एक्सेस टोकन का उपयोग हब पर आपकी पहचान को प्रमाणित करने के लिए किया जाता है। टोकन में `read` या `write` की अनुमतियाँ हो सकती हैं। यदि आप कोई रिपॉजिटरी बनाना या संपादित करना चाहते हैं तो सुनिश्चित करें कि आपके पास `write` का एक्सेस टोकन है। अन्यथा, अनजाने में आपके टोकन के लीक होने की स्थिति में जोखिम को कम करने के लिए `read` का टोकन जनरेट करना सबसे अच्छा है। ### लॉगिन कमांड प्रमाणित करने का सबसे आसान तरीका टोकन को अपनी मशीन पर सहेजना है। आप [`login`] कमांड का उपयोग करके टर्मिनल से ऐसा कर सकते हैं: ```bash huggingface-cli login ``` कमांड आपको बताएगा कि क्या आप पहले से लॉग इन हैं और आपसे आपके टोकन के लिए पूछेगा। फिर टोकन को मान्य किया जाता है और आपकी `HF_HOME` निर्देशिका (डिफ़ॉल्ट रूप से `~/.cache/huggingface/token`) में सहेजा जाता है। हब के साथ इंटरैक्ट करने वाला कोई भी स्क्रिप्ट या लाइब्रेरी अनुरोध भेजते समय इस टोकन का उपयोग करेगा। वैकल्पिक रूप से, आप किसी नोटबुक या स्क्रिप्ट में [`login`] का उपयोग करके प्रोग्रामेटिक रूप से लॉगिन कर सकते हैं: ```py >>> from huggingface_hub import login >>> login() ``` आप एक समय में केवल एक ही खाते में लॉग इन कर सकते हैं। नए खाते में लॉग इन करने से आप स्वचालित रूप से पिछले खाते से लॉग आउट हो जाएंगे। अपने वर्तमान में सक्रिय खाते को निर्धारित करने के लिए, बस `huggingface-cli whoami` कमांड चलाएँ। एक बार लॉग इन करने के बाद, हब के सभी अनुरोध - यहां तक ​​कि वे तरीके जिनके लिए आवश्यक रूप से प्रमाणीकरण की आवश्यकता नहीं होती है - डिफ़ॉल्ट रूप से आपके एक्सेस टोकन का उपयोग करेंगे। यदि आप अपने टोकन के निहित उपयोग को अक्षम करना चाहते हैं, तो आपको एक पर्यावरण चर के रूप में `HF_HUB_DISABLE_IMPLICIT_TOKEN=1` सेट करना चाहिए [देखें संदर्भ](../package_reference/environment_variables#hfhubdisableimplicittoken)। ### स्थानीय रूप से कई टोकन प्रबंधित करें आप प्रत्येक टोकन के साथ [`login`] कमांड से लॉग इन करके अपनी मशीन पर कई टोकन सहेज सकते हैं। यदि आपको इन टोकन के बीच स्थानीय रूप से स्विच करने की आवश्यकता है, तो आप [auth switch] कमांड का उपयोग कर सकते हैं: ```bash huggingface-cli auth switch ``` यह कमांड आपको सहेजे गए टोकन की सूची से उसके नाम से एक टोकन चुनने के लिए कहेगा। एक बार चुने जाने के बाद, चुना गया टोकन `_active_` टोकन बन जाता है, और इसका उपयोग हब के साथ सभी इंटरैक्शन के लिए किया जाएगा। आप `huggingface-cli auth list` के साथ अपनी मशीन पर उपलब्ध सभी एक्सेस टोकन सूचीबद्ध कर सकते हैं। ### पर्यावरण चर पर्यावरण चर `HF_TOKEN` का उपयोग स्वयं को प्रमाणित करने के लिए भी किया जा सकता है। यह एक ऐसे स्थान में विशेष रूप से उपयोगी है जहाँ आप `HF_TOKEN` को [Space Secret](https://huggingface.co/docs/hub/spaces-overview#managing-secrets) के रूप में सेट कर सकते हैं। **नया:** Google Colaboratory आपको अपनी नोटबुक के लिए [private keys](https://twitter.com/GoogleColab/status/1719798406195867814) परिभाषित करने देता है। स्वचालित रूप से प्रमाणित होने के लिए एक `HF_TOKEN` रहस्य परिभाषित करें! पर्यावरण चर या रहस्य के माध्यम से प्रमाणीकरण को आपकी मशीन पर संग्रहीत टोकन पर प्राथमिकता दी जाती है। ### मेथड पैरामीटर अंत में, `token` को पैरामीटर के रूप में स्वीकार करने वाली किसी भी विधि में अपना टोकन पास करके प्रमाणित करना भी संभव है। ``` from huggingface_hub import whoami user = whoami(token=...) 
``` सामान्यतः इसकी अनुशंसा नहीं की जाती है, सिवाय उन परिस्थितियों में जहाँ आप अपना टोकन स्थायी रूप से संग्रहीत नहीं करना चाहते हैं या यदि आपको एक साथ कई टोकन संभालने की आवश्यकता है। टोकन को पैरामीटर के रूप में पास करते समय कृपया सावधान रहें। अपने कोडबेस या नोटबुक में इसे हार्डकोड करने के बजाय टोकन को एक सुरक्षित वॉल्ट से लोड करना हमेशा सबसे अच्छा अभ्यास होता है। यदि आप अनजाने में अपना कोड साझा करते हैं तो हार्डकोडेड टोकन एक बड़ा रिसाव जोखिम पेश करते हैं। ## रिपॉजिटरी बनाएँ एक बार जब आप पंजीकृत हो जाते हैं और लॉग इन कर लेते हैं, तो [`create_repo`] फ़ंक्शन के साथ एक रिपॉजिटरी बनाएँ: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.create_repo(repo_id="super-cool-model") ``` यदि आप चाहते हैं कि आपकी रिपॉजिटरी निजी हो, तो: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.create_repo(repo_id="super-cool-model", private=True) ``` निजी रिपॉजिटरी आपके अलावा किसी और को दिखाई नहीं देंगी। रिपॉजिटरी बनाने या हब पर सामग्री पुश करने के लिए, आपको एक उपयोगकर्ता एक्सेस टोकन प्रदान करना होगा जिसके पास `write` की अनुमति हो। टोकन बनाते समय आप अपने [सेटिंग्स पेज](https://huggingface.co/settings/tokens) में अनुमति चुन सकते हैं। ## फाइलें अपलोड करें अपनी नव निर्मित रिपॉजिटरी में फ़ाइल जोड़ने के लिए [`upload_file`] फ़ंक्शन का उपयोग करें। आप निर्दिष्ट करने की आवश्यकता है: 1. अपलोड करने के लिए फ़ाइल का पथ. 2. रिपोजिटरी में फ़ाइल का पथ. 3. रिपॉजिटरी आईडी जहाँ आप फ़ाइल जोड़ना चाहते हैं। ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.upload_file( ... path_or_fileobj="/home/lysandre/dummy-test/README.md", ... path_in_repo="README.md", ... repo_id="lysandre/test-model", ... ) ``` एक समय में एक से अधिक फ़ाइल अपलोड करने के लिए, [अपलोड](./guides/upload) मार्गदर्शिका पर एक नज़र डालें जो आपको फ़ाइलें अपलोड करने के कई तरीकों से परिचित कराएगा (git के साथ या उसके बिना)। ## अगले कदम `huggingface_hub` लाइब्रेरी उपयोगकर्ताओं को हब के साथ बातचीत करने का एक आसान तरीका प्रदान करती है Python के साथ. 
हब पर आप अपनी फ़ाइलों और रिपॉजिटरी को कैसे प्रबंधित कर सकते हैं,, इसके बारे में अधिक जानने के लिए, हम अनुशंसा करते हैं कि आप हमारे [कैसे करें मार्गदर्शिकाएं](./guides/अवलोकन) पढ़ें: - [अपना भंडार प्रबंधित करें](./guides/repository)। - हब से [डाउनलोड](./guides/download) फ़ाइलें। - हब पर [अपलोड](./guides/upload) फ़ाइलें। - अपने इच्छित मॉडल या डेटासेट के लिए [हब खोजें](./guides/search)। - तेज अनुमान के लिए [अनुमान एपीआई तक पहुंचें](./guides/अनुमान)। huggingface_hub-0.31.1/docs/source/ko/000077500000000000000000000000001500667546600175615ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/ko/_toctree.yml000066400000000000000000000055611500667546600221170ustar00rootroot00000000000000- title: "시작하기" sections: - local: index title: 홈 - local: quick-start title: 둘러보기 - local: installation title: 설치 방법 - title: "How-to 가이드" sections: - local: guides/overview title: 개요 - local: guides/download title: 파일 다운로드하기 - local: guides/upload title: 파일 업로드하기 - local: guides/cli title: 명령줄 인터페이스(CLI) 사용하기 - local: guides/hf_file_system title: Hf파일시스템 - local: guides/repository title: 리포지토리 - local: guides/search title: Hub에서 검색하기 - local: guides/inference title: 추론 - local: guides/inference_endpoints title: 추론 엔드포인트 - local: guides/community title: 커뮤니티 - local: guides/collections title: Collections - local: guides/manage-cache title: 캐시 관리 - local: guides/model-cards title: 모델 카드 - local: guides/manage-spaces title: Space 관리 - local: guides/integrations title: 라이브러리 통합 - local: guides/webhooks_server title: 웹훅 서버 - title: "개념 가이드" sections: - local: concepts/git_vs_http title: Git 대 HTTP 패러다임 - title: "라이브러리 레퍼런스" sections: - local: package_reference/overview title: 개요 - local: package_reference/login title: 로그인 및 로그아웃 - local: package_reference/environment_variables title: 환경 변수 - local: package_reference/repository title: 로컬 및 온라인 리포지토리 관리 - local: package_reference/hf_api title: 허깅페이스 Hub API - local: package_reference/file_download title: 파일 다운로드하기 - local: package_reference/mixins title: 믹스인 & 직렬화 메소드 - local: package_reference/inference_types title: 추론 타입 - local: package_reference/inference_client title: 추론 클라이언트 - local: package_reference/inference_endpoints title: 추론 엔드포인트 - local: package_reference/hf_file_system title: Hf파일시스템 - local: package_reference/utilities title: 유틸리티 - local: package_reference/community title: Discussions 및 Pull Requests - local: package_reference/cache title: 캐시 시스템 참조 - local: package_reference/cards title: Repo Cards 와 Repo Card Data - local: package_reference/collections title: 컬렉션 관리 - local: package_reference/space_runtime title: Space 런타임 - local: package_reference/tensorboard title: TensorBoard 로거 - local: package_reference/webhooks_server title: 웹훅 서버 - local: package_reference/serialization title: 직렬화huggingface_hub-0.31.1/docs/source/ko/concepts/000077500000000000000000000000001500667546600213775ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/ko/concepts/git_vs_http.md000066400000000000000000000111371500667546600242560ustar00rootroot00000000000000 # Git 대 HTTP 패러다임 `huggingface_hub` 라이브러리는 git 기반의 저장소(Models, Datasets 또는 Spaces)로 구성된 Hugging Face Hub과 상호 작용하기 위한 라이브러리입니다. `huggingface_hub`를 사용하여 Hub에 접근하는 방법은 크게 두 가지입니다. 첫 번째 접근 방식인 소위 "git 기반" 접근 방식은 [`Repository`] 클래스가 주도합니다. 이 방법은 허브와 상호 작용하도록 특별히 설계된 추가 기능이 있는 `git` 명령에 랩퍼를 사용합니다. 두 번째 방법은 "HTTP 기반" 접근 방식이며, [`HfApi`] 클라이언트를 사용하여 HTTP 요청을 수행합니다. 각 방법의 장단점을 살펴보겠습니다. ## Repository: 역사적인 Git 기반 접근 방식 먼저, `huggingface_hub`는 주로 [`Repository`] 클래스를 기반으로 구축되었습니다. 
이 클래스는 `"git add"`, `"git commit"`, `"git push"`, `"git tag"`, `"git checkout"` 등과 같은 일반적인 `git` 명령에 대한 Python 랩퍼를 제공합니다. 이 라이브러리는 머신러닝 저장소에서 자주 사용되는 큰 파일을 추적하고 자격 증명을 설정하는 데 도움이 됩니다. 또한, 이 라이브러리는 백그라운드에서 메소드를 실행할 수 있어, 훈련 중에 데이터를 업로드할 때 유용합니다. 로컬 머신에 전체 저장소의 로컬 복사본을 유지할 수 있다는 것은 [`Repository`]를 사용하는 가장 큰 장점입니다. 하지만 동시에 로컬 복사본을 지속적으로 업데이트하고 유지해야 한다는 단점이 될 수도 있습니다. 이는 각 개발자가 자체 로컬 복사본을 유지하고 기능을 개발할 때 변경 사항을 push하는 전통적인 소프트웨어 개발과 유사합니다. 그러나 머신러닝의 경우, 사용자가 전체 저장소를 복제할 필요 없이 추론을 위해 가중치만 다운로드하거나 가중치를 한 형식에서 다른 형식으로 변환하기만 하면 되기 때문에 이런 방식이 항상 필요한 것은 아닙니다. [`Repository`]는 지원이 중단될 예정이므로 HTTP 기반 대안을 사용하는 것을 권장합니다. 기존 코드에서 널리 사용되기 때문에 [`Repository`]의 완전한 제거는 릴리스 `v1.0`에서 이루어질 예정입니다. ## HfApi: 유연하고 편리한 HTTP 클라이언트 [`HfApi`] 클래스는 특히 큰 모델이나 데이터셋을 처리할 때 유지하기 어려운 로컬 git 저장소의 대안으로 개발되었습니다. [`HfApi`] 클래스는 파일 다운로드 및 push, 브랜치 및 태그 생성과 같은 git 기반 접근 방식과 동일한 기능을 제공하지만, 동기화 상태를 유지해야 하는 로컬 폴더가 필요하지 않습니다. [`HfApi`] 클래스는 `git`이 제공하는 기능 외에도 추가적인 기능을 제공합니다. 저장소를 관리하고, 효율적인 재사용을 위해 캐싱을 사용하여 파일을 다운로드하고, Hub에서 저장소 및 메타데이터를 검색하고, 토론, PR 및 코멘트와 같은 커뮤니티 기능에 접근하고, Spaces 하드웨어 및 시크릿을 구성할 수 있습니다. ## 무엇을 사용해야 하나요? 언제 사용하나요? 전반적으로, **HTTP 기반 접근 방식은 모든 경우에** `huggingface_hub`를 사용하는 것이 좋습니다. [`HfApi`]를 사용하면 변경 사항을 pull하고 push하고, PR, 태그 및 브랜치로 작업하고, 토론과 상호 작용하는 등의 작업을 할 수 있습니다. `0.16` 릴리스부터는 [`Repository`] 클래스의 마지막 주요 장점이었던 http 기반 메소드도 백그라운드에서 실행할 수 있습니다. 그러나 모든 git 명령이 [`HfApi`]를 통해 사용 가능한 것은 아닙니다. 일부는 구현되지 않을 수도 있지만, 저희는 항상 개선하고 격차를 줄이기 위해 노력하고 있습니다. 사용 사례에 해당되지 않는 경우, [Github에서 이슈](https://github.com/huggingface/huggingface_hub)를 개설해 주세요! 사용자와 함께, 사용자를 위한 🤗 생태계를 구축하는 데 도움이 되는 피드백을 환영합니다. git 기반 [`Repository`]보다 http 기반 [`HfApi`]를 선호한다고 해서 Hugging Face Hub에서 git 버전 관리가 바로 사라지는 것은 아닙니다. 워크플로우 상 합당하다면 언제든 로컬에서 `git` 명령을 사용할 수 있습니다. huggingface_hub-0.31.1/docs/source/ko/guides/000077500000000000000000000000001500667546600210415ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/ko/guides/cli.md000066400000000000000000000617141500667546600221430ustar00rootroot00000000000000 # 명령줄 인터페이스 (CLI) [[command-line-interface]] `huggingface_hub` Python 패키지는 `huggingface-cli`라는 내장 CLI를 함께 제공합니다. 이 도구를 사용하면 터미널에서 Hugging Face Hub와 직접 상호 작용할 수 있습니다. 계정에 로그인하고, 리포지토리를 생성하고, 파일을 업로드 및 다운로드하는 등의 다양한 작업을 수행할 수 있습니다. 또한 머신을 구성하거나 캐시를 관리하는 데 유용한 기능도 제공합니다. 이 가이드는 CLI의 주요 기능과 사용 방법에 관해 설명합니다. ## 시작하기 [[getting-started]] 먼저, CLI를 설치해 보세요: ``` >>> pip install -U "huggingface_hub[cli]" ``` 위의 코드에서 사용자 경험을 높이기 위해 `[cli]` 추가 종속성을 포함하였습니다. 이는 `delete-cache` 명령을 사용할 때 특히 유용합니다. 설치가 완료되면, CLI가 올바르게 설정되었는지 확인할 수 있습니다: ``` >>> huggingface-cli --help usage: huggingface-cli [] positional arguments: {env,login,whoami,logout,repo,upload,download,lfs-enable-largefiles,lfs-multipart-upload,scan-cache,delete-cache} huggingface-cli command helpers env Print information about the environment. login Log in using a token from huggingface.co/settings/tokens whoami Find out which huggingface.co account you are logged in as. logout Log out repo {create} Commands to interact with your huggingface.co repos. upload Upload a file or a folder to a repo on the Hub download Download files from the Hub lfs-enable-largefiles Configure your repository to enable upload of files > 5GB. scan-cache Scan cache directory. delete-cache Delete revisions from the cache directory. options: -h, --help show this help message and exit ``` CLI가 제대로 설치되었다면 CLI에서 사용 가능한 모든 옵션 목록이 출력됩니다. `command not found: huggingface-cli`와 같은 오류 메시지가 표시된다면 [설치](../installation) 가이드를 확인하세요. `--help` 옵션을 사용하면 명령어에 대한 자세한 정보를 얻을 수 있습니다. 언제든지 사용 가능한 모든 옵션과 그 세부 사항을 확인할 수 있습니다. 
예를 들어 `huggingface-cli upload --help`는 CLI를 사용하여 파일을 업로드하는 구체적인 방법을 알려줍니다. ### 다른 방법으로 설치하기 [[alternative-install]] #### pkgx 사용하기 [[using-pkgx]] [Pkgx](https://pkgx.sh)는 다양한 플랫폼에서 빠르게 작동하는 패키지 매니저입니다. 다음과 같이 pkgx를 사용하여 huggingface-cli를 설치할 수 있습니다: ```bash >>> pkgx install huggingface-cli ``` 또는 pkgx를 통해 huggingface-cli를 직접 실행할 수도 있습니다: ```bash >>> pkgx huggingface-cli --help ``` pkgx huggingface에 대한 자세한 내용은 [여기](https://pkgx.dev/pkgs/huggingface.co/)에서 확인할 수 있습니다. #### Homebrew 사용하기 [[using-homebrew]] [Homebrew](https://brew.sh/)를 사용하여 CLI를 설치할 수도 있습니다: ```bash >>> brew install huggingface-cli ``` Homebrew huggingface에 대한 자세한 내용은 [여기](https://formulae.brew.sh/formula/huggingface-cli)에서 확인할 수 있습니다. ## huggingface-cli login [[huggingface-cli-login]] Hugging Face Hub에 접근하는 대부분의 작업(비공개 리포지토리 액세스, 파일 업로드, PR 제출 등)을 위해서는 Hugging Face 계정에 로그인해야 합니다. 로그인을 하기 위해서 [설정 페이지](https://huggingface.co/settings/tokens)에서 생성한 [사용자 액세스 토큰](https://huggingface.co/docs/hub/security-tokens)이 필요하며, 이 토큰은 Hub에서의 사용자 인증에 사용됩니다. 파일 업로드나 콘텐츠 수정을 위해선 쓰기 권한이 있는 토큰이 필요합니다. 토큰을 받은 후에 터미널에서 다음 명령을 실행하세요: ```bash >>> huggingface-cli login ``` 이 명령은 토큰을 입력하라는 메시지를 표시합니다. 토큰을 복사하여 붙여넣고 Enter 키를 입력합니다. 그런 다음 토큰을 git 자격 증명으로 저장할지 묻는 메시지가 표시됩니다. 로컬에서 `git`을 사용할 계획이라면 Enter 키를 입력합니다(기본값은 yes). 마지막으로 Hub에서 토큰의 유효성을 검증한 후 로컬에 저장합니다. ``` _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_| _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_| To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens . Token: Add token as git credential? (Y/n) Token is valid (permission: write). Your token has been saved in your configured git credential helpers (store). Your token has been saved to /home/wauplin/.cache/huggingface/token Login successful ``` 프롬프트를 거치지 않고 바로 로그인하고 싶다면, 명령줄에서 토큰을 직접 입력할 수도 있습니다. 하지만 보안을 더욱 강화하기 위해서는 명령 기록에 토큰을 남기지 않고, 환경 변수를 통해 토큰을 전달하는 방법이 바람직합니다. ```bash # Or using an environment variable >>> huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential Token is valid (permission: write). Your token has been saved in your configured git credential helpers (store). Your token has been saved to /home/wauplin/.cache/huggingface/token Login successful ``` [이 단락](../quick-start#authentication)에서 인증에 대한 더 자세한 내용을 확인할 수 있습니다. ## huggingface-cli whoami [[huggingface-cli-whoami]] 로그인 여부를 확인하기 위해 `huggingface-cli whoami` 명령어를 사용할 수 있습니다. 이 명령어는 옵션이 없으며, 간단하게 사용자 이름과 소속된 조직들을 출력합니다: ```bash huggingface-cli whoami Wauplin orgs: huggingface,eu-test,OAuthTesters,hf-accelerate,HFSmolCluster ``` 로그인하지 않은 경우 오류 메시지가 출력됩니다. ## huggingface-cli logout [[huggingface-cli-logout]] 이 명령어를 사용하여 로그아웃할 수 있습니다. 실제로는 컴퓨터에 저장된 토큰을 삭제합니다. 하지만 `HF_TOKEN` 환경 변수를 사용하여 로그인했다면, 이 명령어로는 로그아웃할 수 없습니다([참조]((../package_reference/environment_variables#hftoken))). 대신 컴퓨터의 환경 설정에서 `HF_TOKEN` 변수를 제거하면 됩니다. ## huggingface-cli download [[huggingface-cli-download]] `huggingface-cli download` 명령어를 사용하여 Hub에서 직접 파일을 다운로드할 수 있습니다. [다운로드](./download) 가이드에서 설명된 [`hf_hub_download`], [`snapshot_download`] 헬퍼 함수를 사용하여 반환된 경로를 터미널에 출력합니다. 우리는 아래 예시에서 가장 일반적인 사용 사례를 살펴볼 것입니다. 
사용 가능한 모든 옵션을 보려면 아래 명령어를 실행해보세요: ```bash huggingface-cli download --help ``` ### 파일 한 개 다운로드하기 [[download-a-single-file]] 리포지토리에서 파일 하나를 다운로드하고 싶다면, repo_id와 다운받고 싶은 파일명을 아래와 같이 입력하세요: ```bash >>> huggingface-cli download gpt2 config.json downloading https://huggingface.co/gpt2/resolve/main/config.json to /home/wauplin/.cache/huggingface/hub/tmpwrq8dm5o (…)ingface.co/gpt2/resolve/main/config.json: 100%|██████████████████████████████████| 665/665 [00:00<00:00, 2.49MB/s] /home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json ``` 이 명령어를 실행하면 항상 마지막 줄에 파일 경로를 출력합니다. ### 전체 리포지토리 다운로드하기 [[download-an-entire-repository]] 리포지토리의 모든 파일을 다운로드하고 싶을 때에는 repo id만 입력하면 됩니다: ```bash >>> huggingface-cli download HuggingFaceH4/zephyr-7b-beta Fetching 23 files: 0%| | 0/23 [00:00>> huggingface-cli download gpt2 config.json model.safetensors Fetching 2 files: 0%| | 0/2 [00:00>> huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0 --include "*.safetensors" --exclude "*.fp16.*"* Fetching 8 files: 0%| | 0/8 [00:00>> huggingface-cli download HuggingFaceH4/ultrachat_200k --repo-type dataset # https://huggingface.co/spaces/HuggingFaceH4/zephyr-chat >>> huggingface-cli download HuggingFaceH4/zephyr-chat --repo-type space ... ``` ### 특정 리비전 다운로드하기 [[download-a-specific-revision]] 따로 리비전을 지정하지 않는다면 기본적으로 main 브랜치의 최신 커밋에서 파일을 다운로드합니다. 특정 리비전(커밋 해시, 브랜치 이름 또는 태그)에서 다운로드하려면 `--revision` 옵션을 사용하세요: ```bash >>> huggingface-cli download bigcode/the-stack --repo-type dataset --revision v1.1 ... ``` ### 로컬 폴더에 다운로드하기 [[download-to-a-local-folder]] Hub에서 파일을 다운로드하는 권장되고 기본적인 방법은 캐시 시스템을 사용하는 것입니다. 그러나 특정한 경우에는 파일을 지정된 폴더로 다운로드하고 옮기고 싶을 수 있습니다. 이는 git 명령어와 유사한 워크플로우를 만드는데 도움이 됩니다. `--local_dir` 옵션을 사용하여 이 작업을 수행할 수 있습니다. 로컬 폴더에 다운로드하는 것에는 몇 가지 단점이 있습니다. `--local-dir` 명령어를 사용하기 전에 [다운로드](./download#download-files-to-local-folder) 가이드에서 해당 내용을 확인해보세요. ```bash >>> huggingface-cli download adept/fuyu-8b model-00001-of-00002.safetensors --local-dir . ... ./model-00001-of-00002.safetensors ``` ### 캐시 디렉터리 지정하기 [[specify-cache-directory]] 기본적으로 모든 파일은 `HF_HOME` [환경 변수](../package_reference/environment_variables#hfhome)에서 정의한 캐시 디렉터리에 다운로드됩니다. `--cache-dir`을 사용하여 직접 캐시 위치를 지정할 수 있습니다: ```bash >>> huggingface-cli download adept/fuyu-8b --cache-dir ./path/to/cache ... ./path/to/cache/models--adept--fuyu-8b/snapshots/ddcacbcf5fdf9cc59ff01f6be6d6662624d9c745 ``` ### 토큰 설정하기 [[specify-a-token]] 비공개 또는 접근이 제한된 리포지토리들에 접근하기 위해서는 토큰이 필요합니다. 기본적으로 로컬에 저장된 토큰(`huggingface-cli login`)이 사용됩니다. 직접 인증하고 싶다면 `--token` 옵션을 사용해보세요: ```bash >>> huggingface-cli download gpt2 config.json --token=hf_**** /home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json ``` ### 조용한 모드 [[quiet-mode]] `huggingface-cli download` 명령은 상세한 정보를 출력합니다. 경고 메시지, 다운로드된 파일 정보, 진행률 등이 포함됩니다. 이 모든 출력을 숨기려면 `--quiet` 옵션을 사용하세요. 이 옵션을 사용하면 다운로드된 파일의 경로가 표시되는 마지막 줄만 출력됩니다. 이 기능은 스크립트에서 다른 명령어로 출력을 전달하고자 할 때 유용할 수 있습니다. ```bash >>> huggingface-cli download gpt2 --quiet /home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10 ``` ## huggingface-cli upload [[huggingface-cli-upload]] `huggingface-cli upload` 명령어로 Hub에 직접 파일을 업로드할 수 있습니다. [업로드](./upload) 가이드에서 설명된 [`upload_file`], [`upload_folder`] 헬퍼 함수를 사용합니다. 우리는 아래 예시에서 가장 일반적인 사용 사례를 살펴볼 것입니다. 
사용 가능한 모든 옵션을 보려면 아래 명령어를 실행해보세요: ```bash >>> huggingface-cli upload --help ``` ### 전체 폴더 업로드하기 [[upload-an-entire-folder]] 이 명령어의 기본 사용법은 다음과 같습니다: ```bash # Usage: huggingface-cli upload [repo_id] [local_path] [path_in_repo] ``` 현재 디텍터리를 리포지토리의 루트 위치에 업로드하려면, 아래 명령어를 사용하세요: ```bash >>> huggingface-cli upload my-cool-model . . https://huggingface.co/Wauplin/my-cool-model/tree/main/ ``` 리포지토리가 아직 존재하지 않으면 자동으로 생성됩니다. 또한, 특정 폴더만 업로드하는 것도 가능합니다: ```bash >>> huggingface-cli upload my-cool-model ./models . https://huggingface.co/Wauplin/my-cool-model/tree/main/ ``` 마지막으로, 리포지토리의 특정 위치에 폴더를 업로드할 수 있습니다: ```bash >>> huggingface-cli upload my-cool-model ./path/to/curated/data /data/train https://huggingface.co/Wauplin/my-cool-model/tree/main/data/train ``` ### 파일 한 개 업로드하기 [[upload-a-single-file]] 컴퓨터에 있는 파일을 가리키도록 `local_path`를 설정함으로써 파일 한 개를 업로드할 수 있습니다. 이때, `path_in_repo`는 선택사항이며 로컬 파일 이름을 기본값으로 사용합니다: ```bash >>> huggingface-cli upload Wauplin/my-cool-model ./models/model.safetensors https://huggingface.co/Wauplin/my-cool-model/blob/main/model.safetensors ``` 파일 한 개를 특정 디렉터리에 업로드하고 싶다면, `path_in_repo`를 그에 맞게 설정하세요: ```bash >>> huggingface-cli upload Wauplin/my-cool-model ./models/model.safetensors /vae/model.safetensors https://huggingface.co/Wauplin/my-cool-model/blob/main/vae/model.safetensors ``` ### 여러 파일 업로드하기 [[upload-multiple-files]] 전체 폴더를 업로드하지 않고 한 번에 여러 파일을 업로드하려면 `--include`와 `--exclude` 옵션을 사용해보세요. 리포지토리에 있는 파일을 삭제하면서 새 파일을 업로드하는 `--delete` 옵션과 함께 사용할 수 있습니다. 아래 예시는 `/logs` 안의 파일을 제외한 모든 파일을 업로드하고 원격 파일들을 삭제함으로써 로컬 Space를 동기화하는 방법을 보여줍니다: ```bash # Sync local Space with Hub (upload new files except from logs/, delete removed files) >>> huggingface-cli upload Wauplin/space-example --repo-type=space --exclude="/logs/*" --delete="*" --commit-message="Sync local Space with Hub" ... ``` ### 데이터 세트 또는 Space에 업로드하기 [[upload-to-a-dataset-or-space]] 데이터 세트나 Space에 업로드하려면 `--repo-type` 옵션을 사용하세요: ```bash >>> huggingface-cli upload Wauplin/my-cool-dataset ./data /train --repo-type=dataset ... ``` ### 조직에 업로드하기 [[upload-to-an-organization]] 개인 리포지토리 대신 조직이 소유한 리포지토리에 파일을 업로드하려면 `repo_id`를 입력해야 합니다: ```bash >>> huggingface-cli upload MyCoolOrganization/my-cool-model . . https://huggingface.co/MyCoolOrganization/my-cool-model/tree/main/ ``` ### 특정 개정에 업로드하기 [[upload-to-a-specific-revision]] 기본적으로 파일은 `main` 브랜치에 업로드됩니다. 다른 브랜치나 참조에 파일을 업로드하려면 `--revision` 옵션을 사용하세요: ```bash # Upload files to a PR >>> huggingface-cli upload bigcode/the-stack . . --repo-type dataset --revision refs/pr/104 ... ``` **참고:** `revision`이 존재하지 않고 `--create-pr` 옵션이 설정되지 않은 경우, `main` 브랜치에서 자동으로 새 브랜치가 생성됩니다. ### 업로드 및 PR 생성하기 [[upload-and-create-a-pr]] 리포지토리에 푸시할 권한이 없다면, PR을 생성하여 작성자들에게 변경하고자 하는 내용을 알려야 합니다. 이를 위해서 `--create-pr` 옵션을 사용할 수 있습니다: ```bash # Create a PR and upload the files to it >>> huggingface-cli upload bigcode/the-stack . . --repo-type dataset --revision refs/pr/104 https://huggingface.co/datasets/bigcode/the-stack/blob/refs%2Fpr%2F104/ ``` ### 정기적으로 업로드하기 [[upload-at-regular-intervals]] 리포지토리에 정기적으로 업데이트하고 싶을 때, `--every` 옵션을 사용할 수 있습니다. 예를 들어, 모델을 훈련하는 중에 로그 폴더를 10분마다 업로드하고 싶다면 다음과 같이 사용할 수 있습니다: ```bash # Upload new logs every 10 minutes huggingface-cli upload training-model logs/ --every=10 ``` ### 커밋 메시지 지정하기 [[specify-a-commit-message]] `--commit-message`와 `--commit-description`을 사용하여 기본 메시지 대신 사용자 지정 메시지와 설명을 커밋에 설정하세요: ```bash >>> huggingface-cli upload Wauplin/my-cool-model ./models . --commit-message="Epoch 34/50" --commit-description="Val accuracy: 68%. Check tensorboard for more details." 
... https://huggingface.co/Wauplin/my-cool-model/tree/main ``` ### 토큰 지정하기 [[specify-a-token]] 파일을 업로드하려면 토큰이 필요합니다. 기본적으로 로컬에 저장된 토큰(`huggingface-cli login`)이 사용됩니다. 직접 인증하고 싶다면 `--token` 옵션을 사용해보세요: ```bash >>> huggingface-cli upload Wauplin/my-cool-model ./models . --token=hf_**** ... https://huggingface.co/Wauplin/my-cool-model/tree/main ``` ### 조용한 모드 [[quiet-mode]] 기본적으로 `huggingface-cli upload` 명령은 상세한 정보를 출력합니다. 경고 메시지, 업로드된 파일 정보, 진행률 등이 포함됩니다. 이 모든 출력을 숨기려면 `--quiet` 옵션을 사용하세요. 이 옵션을 사용하면 업로드된 파일의 URL이 표시되는 마지막 줄만 출력됩니다. 이 기능은 스크립트에서 다른 명령어로 출력을 전달하고자 할 때 유용할 수 있습니다. ```bash >>> huggingface-cli upload Wauplin/my-cool-model ./models . --quiet https://huggingface.co/Wauplin/my-cool-model/tree/main ``` ## huggingface-cli scan-cache [[huggingface-cli-scan-cache]] 캐시 디렉토리를 스캔하여 다운로드한 리포지토리가 무엇인지와 디스크에서 차지하는 공간을 알 수 있습니다. `huggingface-cli scan-cache` 명령어를 사용하여 이를 확인해보세요: ```bash >>> huggingface-cli scan-cache REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH --------------------------- --------- ------------ -------- ------------- ------------- ------------------- ------------------------------------------------------------------------- glue dataset 116.3K 15 4 days ago 4 days ago 2.4.0, main, 1.17.0 /home/wauplin/.cache/huggingface/hub/datasets--glue google/fleurs dataset 64.9M 6 1 week ago 1 week ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs Jean-Baptiste/camembert-ner model 441.0M 7 2 weeks ago 16 hours ago main /home/wauplin/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner bert-base-cased model 1.9G 13 1 week ago 2 years ago /home/wauplin/.cache/huggingface/hub/models--bert-base-cased t5-base model 10.1K 3 3 months ago 3 months ago main /home/wauplin/.cache/huggingface/hub/models--t5-base t5-small model 970.7M 11 3 days ago 3 days ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/models--t5-small Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G. Got 1 warning(s) while scanning. Use -vvv to print details. ``` 캐시 디렉토리 스캔에 대한 자세한 내용을 알고 싶다면, [캐시 관리](./manage-cache#scan-cache-from-the-terminal) 가이드를 확인해보세요. ## huggingface-cli delete-cache [[huggingface-cli-delete-cache]] 사용하지 않는 캐시를 삭제하고 싶다면 `huggingface-cli delete-cache`를 사용해보세요. 이는 디스크 공간을 절약하고 확보하는 데 유용합니다. 이에 대한 자세한 내용은 [캐시 관리](./manage-cache#clean-cache-from-the-terminal) 가이드에서 확인할 수 있습니다. ## huggingface-cli env [[huggingface-cli-env]] `huggingface-cli env` 명령어는 사용자의 컴퓨터 설정에 대한 상세한 정보를 보여줍니다. 이는 [GitHub](https://github.com/huggingface/huggingface_hub)에서 문제를 제출할 때, 관리자가 문제를 파악하고 해결하는 데 도움이 됩니다. ```bash >>> huggingface-cli env Copy-and-paste the text below in your GitHub issue. 
- huggingface_hub version: 0.19.0.dev0 - Platform: Linux-6.2.0-36-generic-x86_64-with-glibc2.35 - Python version: 3.10.12 - Running in iPython ?: No - Running in notebook ?: No - Running in Google Colab ?: No - Token path ?: /home/wauplin/.cache/huggingface/token - Has saved token ?: True - Who am I ?: Wauplin - Configured git credential helpers: store - FastAI: N/A - Tensorflow: 2.11.0 - Torch: 1.12.1 - Jinja2: 3.1.2 - Graphviz: 0.20.1 - Pydot: 1.4.2 - Pillow: 9.2.0 - hf_transfer: 0.1.3 - gradio: 4.0.2 - tensorboard: 2.6 - numpy: 1.23.2 - pydantic: 2.4.2 - aiohttp: 3.8.4 - ENDPOINT: https://huggingface.co - HF_HUB_CACHE: /home/wauplin/.cache/huggingface/hub - HF_ASSETS_CACHE: /home/wauplin/.cache/huggingface/assets - HF_TOKEN_PATH: /home/wauplin/.cache/huggingface/token - HF_HUB_OFFLINE: False - HF_HUB_DISABLE_TELEMETRY: False - HF_HUB_DISABLE_PROGRESS_BARS: None - HF_HUB_DISABLE_SYMLINKS_WARNING: False - HF_HUB_DISABLE_EXPERIMENTAL_WARNING: False - HF_HUB_DISABLE_IMPLICIT_TOKEN: False - HF_HUB_ENABLE_HF_TRANSFER: False - HF_HUB_ETAG_TIMEOUT: 10 - HF_HUB_DOWNLOAD_TIMEOUT: 10 ``` huggingface_hub-0.31.1/docs/source/ko/guides/collections.md000066400000000000000000000252671500667546600237150ustar00rootroot00000000000000 # Collections[[collections]] Collection은 Hub(모델, 데이터셋, Spaces, 논문)에 있는 관련 항목들의 그룹으로, 같은 페이지에 함께 구성되어 있습니다. Collections는 자신만의 포트폴리오를 만들거나, 카테고리별로 콘텐츠를 북마크 하거나, 공유하고 싶은 item들의 큐레이팅 된 목록을 제시하는 데 유용합니다. 여기 [가이드](https://huggingface.co/docs/hub/collections)를 확인하여 Collections가 무엇이고 Hub에서 어떻게 보이는지 자세히 알아보세요. 브라우저에서 직접 Collections를 관리할 수 있지만, 이 가이드에서는 프로그래밍 방식으로 Collection을 관리하는 방법에 초점을 맞추겠습니다. ## Collection 가져오기[[fetch-a-collection]] [`get_collection`]을 사용하여 자신의 Collections나 공개된 Collection을 가져올 수 있습니다. Collection을 가져오려면 Collection의 *slug*가 필요합니다. Slug는 제목과 고유한 ID를 기반으로 한 Collection의 식별자입니다. Collection 페이지의 URL에서 slug를 찾을 수 있습니다.
`"TheBloke/recent-models-64f9a55bb3115b4f513ec026"` Collection을 가져와 봅시다: ```py >>> from huggingface_hub import get_collection >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") >>> collection Collection( slug='TheBloke/recent-models-64f9a55bb3115b4f513ec026', title='Recent models', owner='TheBloke', items=[...], last_updated=datetime.datetime(2023, 10, 2, 22, 56, 48, 632000, tzinfo=datetime.timezone.utc), position=1, private=False, theme='green', upvotes=90, description="Models I've recently quantized. Please note that currently this list has to be updated manually, and therefore is not guaranteed to be up-to-date." ) >>> collection.items[0] CollectionItem( item_object_id='651446103cd773a050bf64c2', item_id='TheBloke/U-Amethyst-20B-AWQ', item_type='model', position=88, note=None ) ``` [`get_collection`]에 의해 반환된 [`Collection`] 객체에는 다음이 포함되어 있습니다: - 높은 수준의 메타데이터: `slug`, `owner`, `title`, `description` 등 - [`CollectionItem`] 객체의 목록; 각 항목은 모델, 데이터셋, Space 또는 논문을 나타냅니다. 모든 Collection 항목에는 다음이 보장됩니다: - 고유한 `item_object_id`: 데이터베이스에서 Collection 항목의 id - 기본 항목(모델, 데이터셋, Space, 논문)의 Hub에서의 `item_id`; 고유하지 않으며, `item_id`/`item_type` 쌍만 고유합니다. - `item_type`: 모델, 데이터셋, Space, 논문 - Collection에서 항목의 `position`으로, 이를 업데이트하여 Collection을 재구성할 수 있습니다(아래의 [`update_collection_item`] 참조) 각 항목에는 추가 정보(코멘트, 블로그 포스트 링크 등)를 위한 `note`도 첨부될 수 있습니다. 항목에 note가 없으면 해당 속성값은 `None`이 됩니다. 이러한 기본 속성 외에도, 반환된 항목은 유형에 따라 추가 속성(`author`, `private`, `lastModified`, `gated`, `title`, `likes`, `upvotes` 등)을 가질 수 있습니다. 그러나 이러한 속성이 반환된다는 보장은 없습니다. ## Collections 나열하기[[fetch-a-collection]] [`list_collections`]를 사용하여 Collections를 나열할 수도 있습니다. Collections는 몇 가지 매개변수를 사용하여 필터링할 수 있습니다. 사용자 [`teknium`](https://huggingface.co/teknium)의 모든 Collections를 나열해 봅시다. ```py >>> from huggingface_hub import list_collections >>> collections = list_collections(owner="teknium") ``` 이렇게 하면 `Collection` 객체의 반복 가능한 객체가 반환됩니다. 예를 들어 각 Collection의 upvotes 수를 출력하기 위해 반복할 수 있습니다. ```py >>> for collection in collections: ... print("Number of upvotes:", collection.upvotes) Number of upvotes: 1 Number of upvotes: 5 ``` Collections를 나열할 때, 각 Collection의 항목 목록은 최대 4개 항목으로 잘립니다. Collection의 모든 항목을 가져오려면 [`get_collection`]을 사용해야 합니다. 고급 필터링을 수행할 수 있습니다. 예를 들어 모델 [TheBloke/OpenHermes-2.5-Mistral-7B-GGUF](https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF)를 포함하는 트렌딩 순으로 정렬된 Collections를 5개까지만 가져올 수 있습니다. ```py >>> collections = list_collections(item="models/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF", sort="trending", limit=5): >>> for collection in collections: ... print(collection.slug) teknium/quantized-models-6544690bb978e0b0f7328748 AmeerH/function-calling-65560a2565d7a6ef568527af PostArchitekt/7bz-65479bb8c194936469697d8c gnomealone/need-to-test-652007226c6ce4cdacf9c233 Crataco/favorite-7b-models-651944072b4fffcb41f8b568 ``` `sort` 매개변수는 `"last_modified"`, `"trending"` 또는 `"upvotes"` 중 하나여야 합니다. `item` 매개변수는 특정 항목을 받습니다. 예를 들면 다음과 같습니다: * `"models/teknium/OpenHermes-2.5-Mistral-7B"` * `"spaces/julien-c/open-gpt-rhyming-robot"` * `"datasets/squad"` * `"papers/2311.12983"` 자세한 내용은 [`list_collections`] 참조를 확인하시기 바랍니다. ## 새 Collection 만들기[[fetch-a-collection]] 이제 [`Collection`]을 가져오는 방법을 알았으니 우리만의 Collection을 만들어봅시다! 제목과 설명을 사용하여 [`create_collection`]을 호출합니다. 조직 페이지에 Collection을 만들려면 Collection 생성 시 `namespace="my-cool-org"`를 전달합니다. 마지막으로 `private=True`를 전달하여 비공개 Collection을 만들 수도 있습니다. ```py >>> from huggingface_hub import create_collection >>> collection = create_collection( ... title="ICCV 2023", ... 
description="Portfolio of models, papers and demos I presented at ICCV 2023", ... ) ``` 이렇게 하면 (제목, 설명, 소유자 등의) 높은 수준의 메타데이터와 빈 항목 목록을 가진 [`Collection`] 객체가 반환됩니다. 이제 `slug`를 사용하여 이 Collection을 참조할 수 있습니다. ```py >>> collection.slug 'owner/iccv-2023-15e23b46cb98efca45' >>> collection.title "ICCV 2023" >>> collection.owner "username" >>> collection.url 'https://huggingface.co/collections/owner/iccv-2023-15e23b46cb98efca45' ``` ## Collection의 item 관리[[manage-items-in-a-collection]] 이제 [`Collection`]을 가지고 있으므로, 여기에 item을 추가하고 구성해봅시다. ### item 추가[[add-items]] item은 [`add_collection_item`]을 사용하여 하나씩 추가해야 합니다. `collection_slug`, `item_id`, `item_type`만 알면 됩니다. 또한 선택적으로 항목에 `note`를 추가할 수도 있습니다(최대 500자). ```py >>> from huggingface_hub import create_collection, add_collection_item >>> collection = create_collection(title="OS Week Highlights - Sept 18 - 24", namespace="osanseviero") >>> collection.slug "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" >>> add_collection_item(collection.slug, item_id="coqui/xtts", item_type="space") >>> add_collection_item( ... collection.slug, ... item_id="warp-ai/wuerstchen", ... item_type="model", ... note="Würstchen is a new fast and efficient high resolution text-to-image architecture and model" ... ) >>> add_collection_item(collection.slug, item_id="lmsys/lmsys-chat-1m", item_type="dataset") >>> add_collection_item(collection.slug, item_id="warp-ai/wuerstchen", item_type="space") # 동일한 item_id, 다른 item_type ``` Collection에 item이 이미 존재하는 경우(동일한 `item_id`/`item_type` 쌍), HTTP 409 오류가 발생합니다. `exists_ok=True`를 설정하면 이 오류를 무시할 수 있습니다. ### 기존 item에 메모 추가[[add-a-note-to-an-existing-item]] [`update_collection_item`]을 사용하여 기존 item을 수정하여 메모를 추가하거나 변경할 수 있습니다. 위의 예시를 다시 사용해 봅시다: ```py >>> from huggingface_hub import get_collection, update_collection_item # 새로 추가된 item과 함께 Collection 가져오기 >>> collection_slug = "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" >>> collection = get_collection(collection_slug) # `lmsys-chat-1m` 데이터셋에 메모 추가 >>> update_collection_item( ... collection_slug=collection_slug, ... item_object_id=collection.items[2].item_object_id, ... note="This dataset contains one million real-world conversations with 25 state-of-the-art LLMs.", ... ) ``` ### item 재정렬[[reorder-items]] Collection의 item은 순서가 있습니다. 이 순서는 각 item의 `position` 속성에 의해 결정됩니다. 기본적으로 item은 Collection의 끝에 추가되는 방식으로 순서가 지정됩니다. [`update_collection_item`]을 사용하여 메모를 추가하는 것과 같은 방식으로 순서를 업데이트할 수 있습니다. 위의 예시를 다시 사용해 봅시다: ```py >>> from huggingface_hub import get_collection, update_collection_item # Collection 가져오기 >>> collection_slug = "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" >>> collection = get_collection(collection_slug) # 두 개의 `Wuerstchen` item을 함께 배치하도록 재정렬 >>> update_collection_item( ... collection_slug=collection_slug, ... item_object_id=collection.items[3].item_object_id, ... position=2, ... ) ``` ### item 제거[[remove-items]] 마지막으로 [`delete_collection_item`]을 사용하여 item을 제거할 수도 있습니다. ```py >>> from huggingface_hub import get_collection, update_collection_item # Collection 가져오기 >>> collection_slug = "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" >>> collection = get_collection(collection_slug) # 목록에서 `coqui/xtts` Space 제거 >>> delete_collection_item(collection_slug=collection_slug, item_object_id=collection.items[0].item_object_id) ``` ## Collection 삭제[[delete-collection]] [`delete_collection`]을 사용하여 Collection을 삭제할 수 있습니다. 이 작업은 되돌릴 수 없습니다. 삭제된 Collection은 복구할 수 없습니다. 
```py >>> from huggingface_hub import delete_collection >>> collection = delete_collection("username/useless-collection-64f9a55bb3115b4f513ec026", missing_ok=True) ```huggingface_hub-0.31.1/docs/source/ko/guides/community.md000066400000000000000000000152201500667546600234070ustar00rootroot00000000000000 # Discussions 및 Pull Requests를 이용하여 상호작용하기[[interact-with-discussions-and-pull-requests]] `huggingface_hub` 라이브러리는 Hub의 Pull Requests 및 Discussions와 상호작용할 수 있는 Python 인터페이스를 제공합니다. [전용 문서 페이지](https://huggingface.co/docs/hub/repositories-pull-requests-discussions)를 방문하여 Hub의 Discussions와 Pull Requests가 무엇이고 어떻게 작동하는지 자세히 살펴보세요. ## Hub에서 Discussions 및 Pull Requests 가져오기[[retrieve-discussions-and-pull-requests-from-the-hub]] `HfApi` 클래스를 사용하면 지정된 리포지토리에 대한 Discussions 및 Pull Requests를 검색할 수 있습니다: ```python >>> from huggingface_hub import get_repo_discussions >>> for discussion in get_repo_discussions(repo_id="bigscience/bloom"): ... print(f"{discussion.num} - {discussion.title}, pr: {discussion.is_pull_request}") # 11 - Add Flax weights, pr: True # 10 - Update README.md, pr: True # 9 - Training languages in the model card, pr: True # 8 - Update tokenizer_config.json, pr: True # 7 - Slurm training script, pr: False [...] ``` `HfApi.get_repo_discussion`은 작성자, 유형(Pull Requests 또는 Discussion) 및 상태(`open` 또는 `closed`)별로 필터링을 지원합니다: ```python >>> from huggingface_hub import get_repo_discussions >>> for discussion in get_repo_discussions( ... repo_id="bigscience/bloom", ... author="ArthurZ", ... discussion_type="pull_request", ... discussion_status="open", ... ): ... print(f"{discussion.num} - {discussion.title} by {discussion.author}, pr: {discussion.is_pull_request}") # 19 - Add Flax weights by ArthurZ, pr: True ``` `HfApi.get_repo_discussions`는 [`Discussion`] 객체를 생성하는 [생성자](https://docs.python.org/3.7/howto/functional.html#generators)를 반환합니다. 모든 Discussions를 하나의 리스트로 가져오려면 다음을 실행합니다: ```python >>> from huggingface_hub import get_repo_discussions >>> discussions_list = list(get_repo_discussions(repo_id="bert-base-uncased")) ``` [`HfApi.get_repo_discussions`]가 반환하는 [`Discussion`] 객체에는 Discussions 또는 Pull Request에 대한 개략적인 개요가 포함되어 있습니다. [`HfApi.get_discussion_details`]를 사용하여 더 자세한 정보를 얻을 수도 있습니다: ```python >>> from huggingface_hub import get_discussion_details >>> get_discussion_details( ... repo_id="bigscience/bloom-1b3", ... discussion_num=2 ... ) DiscussionWithDetails( num=2, author='cakiki', title='Update VRAM memory for the V100s', status='open', is_pull_request=True, events=[ DiscussionComment(type='comment', author='cakiki', ...), DiscussionCommit(type='commit', author='cakiki', summary='Update VRAM memory for the V100s', oid='1256f9d9a33fa8887e1c1bf0e09b4713da96773a', ...), ], conflicting_files=[], target_branch='refs/heads/main', merge_commit_oid=None, diff='diff --git a/README.md b/README.md\nindex a6ae3b9294edf8d0eda0d67c7780a10241242a7e..3a1814f212bc3f0d3cc8f74bdbd316de4ae7b9e3 100644\n--- a/README.md\n+++ b/README.md\n@@ -132,7 +132,7 [...]', ) ``` [`HfApi.get_discussion_details`]는 Discussion 또는 Pull Request에 대한 자세한 정보가 포함된 [`Discussion`]의 하위 클래스인 [`DiscussionWithDetails`] 객체를 반환합니다. 해당 정보는 [`DiscussionWithDetails.events`]를 통해 Discussion의 모든 댓글, 상태 변경 및 이름 변경을 포함하고 있습니다. Pull Request의 경우, [`DiscussionWithDetails.diff`]를 통해 원시 git diff를 검색할 수 있습니다. Pull Request의 모든 커밋은 [`DiscussionWithDetails.events`]에 나열됩니다. ## 프로그래밍 방식으로 Discussion 또는 Pull Request를 생성하고 수정하기[[create-and-edit-a-discussion-or-pull-request-programmatically]] [`HfApi`] 클래스는 Discussions 및 Pull Requests를 생성하고 수정하는 방법도 제공합니다. 
Discussions와 Pull Requests를 만들고 편집하려면 [접근 토큰](https://huggingface.co/docs/hub/security-tokens)이 필요합니다. Hub의 리포지토리에 변경 사항을 제안하는 가장 간단한 방법은 [`create_commit`] API를 사용하는 것입니다. `create_pr` 매개변수를 `True`로 설정하기만 하면 됩니다. 이 매개변수는 [`create_commit`]을 래핑하는 다른 함수에서도 사용할 수 있습니다: * [`upload_file`] * [`upload_folder`] * [`delete_file`] * [`delete_folder`] * [`metadata_update`] ```python >>> from huggingface_hub import metadata_update >>> metadata_update( ... repo_id="username/repo_name", ... metadata={"tags": ["computer-vision", "awesome-model"]}, ... create_pr=True, ... ) ``` 리포지토리에 대한 Discussion(또는 Pull Request)을 만들려면 [`HfApi.create_discussion`](또는 [`HfApi.create_pull_request`])을 사용할 수도 있습니다. 이 방법으로 Pull Request를 열면 로컬에서 변경 작업을 해야 하는 경우에 유용할 수 있습니다. 이 방법으로 열린 Pull Request는 `"draft"` 모드가 됩니다. ```python >>> from huggingface_hub import create_discussion, create_pull_request >>> create_discussion( ... repo_id="username/repo-name", ... title="Hi from the huggingface_hub library!", ... token="", ... ) DiscussionWithDetails(...) >>> create_pull_request( ... repo_id="username/repo-name", ... title="Hi from the huggingface_hub library!", ... token="", ... ) DiscussionWithDetails(..., is_pull_request=True) ``` Pull Requests 및 Discussions 관리는 전적으로 [`HfApi`] 클래스로 할 수 있습니다. 예를 들어: * 댓글을 추가하려면 [`comment_discussion`] * 댓글을 수정하려면 [`edit_discussion_comment`] * Discussion 또는 Pull Request의 이름을 바꾸려면 [`rename_discussion`] * Discussion / Pull Request를 열거나 닫으려면 [`change_discussion_status`] * Pull Request를 병합하려면 [`merge_pull_request`]를 사용합니다. 사용 가능한 모든 메소드에 대한 전체 참조는 [`HfApi`] 문서 페이지를 참조하세요. ## Pull Request에 변경 사항 푸시[[push-changes-to-a-pull-request]] *곧 공개됩니다!* ## 참고 항목[[see-also]] 더 자세한 내용은 [Discussions 및 Pull Requests](../package_reference/community)와 [hf_api](../package_reference/hf_api) 문서 페이지를 참조하세요. huggingface_hub-0.31.1/docs/source/ko/guides/download.md000066400000000000000000000346711500667546600232050ustar00rootroot00000000000000 # Hub에서 파일 다운로드하기[[download-files-from-the-hub]] `huggingface_hub` 라이브러리는 Hub의 저장소에서 파일을 다운로드하는 기능을 제공합니다. 이 기능은 함수로 직접 사용할 수 있고, 사용자가 만든 라이브러리에 통합하여 Hub와 쉽게 상호 작용할 수 있도록 할 수 있습니다. 이 가이드에서는 다음 내용을 다룹니다: * 파일 하나를 다운로드하고 캐시하는 방법 * 리포지토리 전체를 다운로드하고 캐시하는 방법 * 로컬 폴더에 파일을 다운로드하는 방법 ## 파일 하나만 다운로드하기[[download-a-single-file]] [`hf_hub_download`] 함수를 사용하면 Hub에서 파일을 다운로드할 수 있습니다. 이 함수는 원격 파일을 다운로드하여 (버전별로) 디스크에 캐시하고, 로컬 파일 경로를 반환합니다. 반환된 파일 경로는 HF 로컬 캐시의 위치를 가리킵니다. 그러므로 캐시가 손상되지 않도록 파일을 수정하지 않는 것이 좋습니다. 캐시가 어떻게 작동하는지 자세히 알고 싶으시면 [캐싱 가이드](./manage-cache)를 참조하세요. ### 최신 버전에서 파일 다운로드하기[[from-latest-version]] 다운로드할 파일을 선택하기 위해 `repo_id`, `repo_type`, `filename` 매개변수를 사용합니다. `repo_type` 매개변수를 생략하면 파일은 `model` 리포의 일부라고 간주됩니다. ```python >>> from huggingface_hub import hf_hub_download >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json") '/root/.cache/huggingface/hub/models--lysandre--arxiv-nlp/snapshots/894a9adde21d9a3e3843e6d5aeaaf01875c7fade/config.json' # 데이터세트의 경우 >>> hf_hub_download(repo_id="google/fleurs", filename="fleurs.py", repo_type="dataset") '/root/.cache/huggingface/hub/datasets--google--fleurs/snapshots/199e4ae37915137c555b1765c01477c216287d34/fleurs.py' ``` ### 특정 버전에서 파일 다운로드하기[[from-specific-version]] 기본적으로 `main` 브랜치의 최신 버전의 파일이 다운로드됩니다. 그러나 특정 버전의 파일을 다운로드하고 싶을 수도 있습니다. 예를 들어, 특정 브랜치, 태그, 커밋 해시 등에서 파일을 다운로드하고 싶을 수 있습니다. 
이 경우 `revision` 매개변수를 사용하여 원하는 버전을 지정할 수 있습니다: ```python # `v1.0` 태그에서 다운로드하기 >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="v1.0") # `test-branch` 브랜치에서 다운로드하기 >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="test-branch") # PR #3에서 다운로드하기 >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="refs/pr/3") # 특정 커밋 해시에서 다운로드하기 >>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="877b84a8f93f2d619faa2a6e514a32beef88ab0a") ``` **참고**: 커밋 해시를 사용할 때는 7자리의 짧은 커밋 해시가 아니라 전체 길이의 커밋 해시를 사용해야 합니다. ### 다운로드 URL 만들기[[construct-a-download-url]] 리포지토리에서 파일을 다운로드하는 데 사용할 URL을 만들고 싶은 경우 [`hf_hub_url`] 함수를 사용하여 URL을 반환받을 수 있습니다. 이 함수는 [`hf_hub_download`] 함수가 내부적으로 사용하는 URL을 생성한다는 점을 알아두세요. ## 전체 리포지토리 다운로드하기[[download-an-entire-repository]] [`snapshot_download`] 함수는 특정 버전의 전체 리포지토리를 다운로드합니다. 이 함수는 내부적으로 [`hf_hub_download`] 함수를 사용하므로, 다운로드한 모든 파일은 로컬 디스크에 캐시되어 저장됩니다. 다운로드는 여러 파일을 동시에 받아오기 때문에 빠르게 진행됩니다. 전체 리포지토리를 다운로드하려면 `repo_id`와 `repo_type`을 인자로 넘겨주면 됩니다: ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="lysandre/arxiv-nlp") '/home/lysandre/.cache/huggingface/hub/models--lysandre--arxiv-nlp/snapshots/894a9adde21d9a3e3843e6d5aeaaf01875c7fade' # 또는 데이터세트의 경우 >>> snapshot_download(repo_id="google/fleurs", repo_type="dataset") '/home/lysandre/.cache/huggingface/hub/datasets--google--fleurs/snapshots/199e4ae37915137c555b1765c01477c216287d34' ``` [`snapshot_download`] 함수는 기본적으로 최신 버전의 리포지토리를 다운로드합니다. 특정 버전의 리포지토리를 다운로드하고 싶은 경우, `revision` 매개변수에 원하는 버전을 지정하면 됩니다: ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="lysandre/arxiv-nlp", revision="refs/pr/1") ``` ### 다운로드할 파일 선택하기[[filter-files-to-download]] [`snapshot_download`] 함수는 리포지토리를 쉽게 다운로드할 수 있도록 해줍니다. 그러나 리포지토리의 모든 내용을 다운로드하고 싶지 않을 수도 있습니다. 예를 들어, `.safetensors` 가중치만 사용하고 싶다면, 모든 `.bin` 파일을 다운로드하지 않도록 할 수 있습니다. `allow_pattern`과 `ignore_pattern` 매개변수를 사용하여 원하는 파일만 다운로드할 수 있습니다. 이 매개변수들은 하나의 패턴이나 패턴의 리스트를 받을 수 있습니다. 패턴은 [여기](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm)에서 설명된 것처럼 표준 와일드카드(글로빙 패턴)입니다. 패턴 매칭은 [`fnmatch`](https://docs.python.org/3/library/fnmatch.html)에 기반합니다. 예를 들어, `allow_patterns`를 사용하여 JSON 구성 파일만 다운로드하는 방법은 다음과 같습니다: ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="lysandre/arxiv-nlp", allow_patterns="*.json") ``` 반대로 `ignore_patterns`는 특정 파일을 다운로드에서 제외시킬 수 있습니다. 다음 예제는 `.msgpack`과 `.h5` 파일 확장자를 무시하는 방법입니다: ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="lysandre/arxiv-nlp", ignore_patterns=["*.msgpack", "*.h5"]) ``` 마지막으로, 두 가지 매개변수를 함께 사용하여 다운로드를 정확하게 선택할 수 있습니다. 다음은 `vocab.json`을 제외한 모든 json 및 마크다운 파일을 다운로드하는 예제입니다. ```python >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="gpt2", allow_patterns=["*.md", "*.json"], ignore_patterns="vocab.json") ``` ## 로컬 폴더에 파일 다운로드하기[[download-files-to-local-folder]] Hub에서 파일을 다운로드하는 가장 좋은 (그리고 기본적인) 방법은 [캐시 시스템](./manage-cache)을 사용하는 것입니다. 캐시 위치는 `cache_dir` 매개변수로 설정하여 지정할 수 있습니다([`hf_hub_download`]과 [`snapshot_download`]에서 모두 사용 가능). 그러나 파일을 다운로드하여 특정 폴더에 넣고 싶은 경우도 있습니다. 이 기능은 `git` 명령어가 제공하는 기능과 비슷한 워크플로우를 만들 수 있습니다. 이 경우 `local_dir`과 `local_dir_use_symlinks` 매개변수를 사용하여 원하는 대로 파일을 넣을 수 있습니다: - `local_dir`은 시스템 내의 폴더 경로입니다. 다운로드한 파일은 리포지토리에 있는 것과 같은 파일 구조를 유지합니다. 
예를 들어 `filename="data/train.csv"`와 `local_dir="path/to/folder"`라면, 반환된 파일 경로는 `"path/to/folder/data/train.csv"`가 됩니다.
- `local_dir_use_symlinks`는 파일을 로컬 폴더에 어떻게 넣을지 정의합니다.
  - 기본 동작(`"auto"`)은 작은 파일(5MB 이하)은 복사하고 큰 파일은 심볼릭 링크를 사용하는 것입니다. 심볼릭 링크를 사용하면 대역폭과 디스크 공간을 모두 절약할 수 있습니다. 그러나 심볼릭 링크된 파일을 직접 수정하면 캐시가 손상될 수 있으므로 작은 파일에 대해서는 복사를 사용합니다. 5MB 임계값은 `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD` 환경 변수로 설정할 수 있습니다.
  - `local_dir_use_symlinks=True`로 설정하면 디스크 공간을 최대한 절약하기 위해 모든 파일이 심볼릭 링크됩니다. 이는 예를 들어 수천 개의 작은 파일로 이루어진 대용량 데이터 세트를 다운로드할 때 유용합니다.
  - 마지막으로 심볼릭 링크를 전혀 사용하지 않으려면 심볼릭 링크를 비활성화하면 됩니다(`local_dir_use_symlinks=False`). 캐시 디렉토리는 파일이 이미 캐시되었는지 여부를 확인하는 데 계속 사용됩니다. 이미 캐시된 경우 파일이 캐시에서 **복사**됩니다(즉, 대역폭은 절약되지만 디스크 공간이 증가합니다). 파일이 아직 캐시되지 않은 경우 파일을 다운로드하여 로컬 디렉터리에 바로 넣습니다. 즉, 나중에 다른 곳에서 다시 사용하려면 **다시 다운로드**해야 합니다.

다음은 다양한 옵션을 요약한 표입니다. 이 표를 참고하여 자신의 사용 사례에 가장 적합한 매개변수를 선택하세요.

| 파라미터 | 캐시되었는지 여부 | 반환된 파일 경로 | 열람 권한 | 수정 권한 | 대역폭의 효율적인 사용 | 디스크의 효율적인 접근 |
|---|:---:|:---:|:---:|:---:|:---:|:---:|
| `local_dir=None` |  | 캐시 속 심볼릭 링크 | ✅ | ❌<br>_(저장하면 캐시가 손상됩니다)_ | ✅ | ✅ |
| `local_dir="path/to/folder"`<br>`local_dir_use_symlinks="auto"` |  | 폴더 속 파일 또는 심볼릭 링크 | ✅ | ✅ _(소규모 파일의 경우)_<br>⚠️ _(대규모 파일의 경우 저장하기 전에 경로를 생성하지 마세요)_ | ✅ | ✅ |
| `local_dir="path/to/folder"`<br>`local_dir_use_symlinks=True` |  | 폴더 속 심볼릭 링크 | ✅ | ⚠️<br>_(저장하기 전에 경로를 생성하지 마세요)_ | ✅ | ✅ |
| `local_dir="path/to/folder"`<br>`local_dir_use_symlinks=False` | 아니오 | 폴더 속 파일 | ✅ | ✅ | ❌<br>_(다시 실행하면 파일도 다시 다운로드됩니다)_ | ⚠️<br>_(여러 폴더에서 실행하면 그만큼 복사본이 생깁니다)_ |
| `local_dir="path/to/folder"`<br>`local_dir_use_symlinks=False` | 예 | 폴더 속 파일 | ✅ | ✅ | ⚠️<br>_(파일이 캐시되어 있어야 합니다)_ | ❌<br>_(파일이 중복됩니다)_ |
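표에 정리된 동작을 파이썬 API로 사용하면 다음과 같습니다. 아래 코드는 앞서 CLI 예시에서 사용한 리포지토리와 경로를 가정한 간단한 스케치입니다:

```python
from huggingface_hub import hf_hub_download, snapshot_download

# 파일 한 개를 현재 폴더로 다운로드합니다.
# 기본 동작("auto")은 작은 파일은 복사하고 큰 파일은 심볼릭 링크를 사용합니다.
path = hf_hub_download(
    repo_id="adept/fuyu-8b",
    filename="model-00001-of-00002.safetensors",
    local_dir=".",
)
print(path)  # ./model-00001-of-00002.safetensors

# 리포지토리 전체를 심볼릭 링크 없이 실제 파일로 받습니다.
snapshot_download(repo_id="gpt2", local_dir="./models/gpt2", local_dir_use_symlinks=False)
```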
**참고**: Windows 컴퓨터를 사용하는 경우 심볼릭 링크를 사용하려면 개발자 모드를 켜거나 관리자 권한으로 `huggingface_hub`를 실행해야 합니다. 자세한 내용은 [캐시 제한](../guides/manage-cache#limitations) 섹션을 참조하세요.

## CLI에서 파일 다운로드하기[[download-from-the-cli]]

터미널에서 `huggingface-cli download` 명령어를 사용하면 Hub에서 파일을 바로 다운로드할 수 있습니다. 이 명령어는 내부적으로 앞서 설명한 [`hf_hub_download`]과 [`snapshot_download`] 함수를 사용하고, 다운로드한 파일의 로컬 경로를 터미널에 출력합니다:

```bash
>>> huggingface-cli download gpt2 config.json
/home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json
```

기본적으로 (`huggingface-cli login` 명령으로) 로컬에 저장된 토큰을 사용합니다. 직접 인증하고 싶다면, `--token` 옵션을 사용하세요:

```bash
>>> huggingface-cli download gpt2 config.json --token=hf_****
/home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json
```

여러 파일을 한 번에 다운로드하면 진행률 표시줄이 보이고, 파일이 있는 스냅샷 경로가 반환됩니다:

```bash
>>> huggingface-cli download gpt2 config.json model.safetensors
Fetching 2 files: 100%|████████████████████████████████████████████| 2/2 [00:00<00:00, 23831.27it/s]
/home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10
```

진행률 표시줄이나 잠재적 경고가 필요 없다면 `--quiet` 옵션을 사용하세요. 이 옵션은 스크립트에서 다른 명령어로 출력을 넘겨주려는 경우에 유용할 수 있습니다.

```bash
>>> huggingface-cli download gpt2 config.json model.safetensors --quiet
/home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10
```

기본적으로 파일은 `HF_HOME` 환경 변수에 정의된 캐시 디렉터리(또는 지정하지 않은 경우 `~/.cache/huggingface/hub`)에 다운로드됩니다. 캐시 디렉터리는 `--cache-dir` 옵션으로 변경할 수 있습니다:

```bash
>>> huggingface-cli download gpt2 config.json --cache-dir=./cache
./cache/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json
```

캐시 디렉터리 구조를 따르지 않고 로컬 폴더에 파일을 다운로드하려면 `--local-dir` 옵션을 사용하세요. 로컬 폴더로 다운로드하면 이 [표](https://huggingface.co/docs/huggingface_hub/guides/download#download-files-to-local-folder)에 나열된 제한 사항이 있습니다.

```bash
>>> huggingface-cli download gpt2 config.json --local-dir=./models/gpt2
./models/gpt2/config.json
```

다른 리포지토리 유형이나 버전에서 파일을 다운로드하거나 glob 패턴을 사용하여 다운로드할 파일을 선택하거나 제외하도록 지정할 수 있는 인수들이 더 있습니다:

```bash
>>> huggingface-cli download bigcode/the-stack --repo-type=dataset --revision=v1.2 --include="data/python/*" --exclude="*.json" --exclude="*.zip"
Fetching 206 files: 100%|████████████████████████████████████████████| 206/206 [02:31<2:31, ?it/s]
/home/wauplin/.cache/huggingface/hub/datasets--bigcode--the-stack/snapshots/9ca8fa6acdbc8ce920a0cb58adcdafc495818ae7
```

인수들의 전체 목록을 보려면 다음 명령어를 실행하세요:

```bash
huggingface-cli download --help
```
huggingface_hub-0.31.1/docs/source/ko/guides/hf_file_system.md000066400000000000000000000124221500667546600243640ustar00rootroot00000000000000
# Hugging Face Hub에서 파일 시스템 API를 통해 상호작용하기[[interact-with-the-hub-through-the-filesystem-api]]

`huggingface_hub` 라이브러리는 [`HfApi`] 외에도 Hugging Face Hub에 대한 파이써닉한 [fsspec-compatible](https://filesystem-spec.readthedocs.io/en/latest/) 파일 인터페이스인 [`HfFileSystem`]을 제공합니다. [`HfFileSystem`]은 [`HfApi`]을 기반으로 구축되며, `cp`, `mv`, `ls`, `du`, `glob`, `get_file` 및 `put_file`과 같은 일반적인 파일 시스템 스타일 작업을 제공합니다.
## 사용법[[usage]] ```python >>> from huggingface_hub import HfFileSystem >>> fs = HfFileSystem() >>> # 디렉터리의 모든 파일 나열하기 >>> fs.ls("datasets/my-username/my-dataset-repo/data", detail=False) ['datasets/my-username/my-dataset-repo/data/train.csv', 'datasets/my-username/my-dataset-repo/data/test.csv'] >>> # 저장소(repo)에서 ".csv" 파일 모두 나열하기 >>> fs.glob("datasets/my-username/my-dataset-repo/**.csv") ['datasets/my-username/my-dataset-repo/data/train.csv', 'datasets/my-username/my-dataset-repo/data/test.csv'] >>> # 원격 파일 읽기 >>> with fs.open("datasets/my-username/my-dataset-repo/data/train.csv", "r") as f: ... train_data = f.readlines() >>> # 문자열로 원격 파일의 내용 읽기 >>> train_data = fs.read_text("datasets/my-username/my-dataset-repo/data/train.csv", revision="dev") >>> # 원격 파일 쓰기 >>> with fs.open("datasets/my-username/my-dataset-repo/data/validation.csv", "w") as f: ... f.write("text,label") ... f.write("Fantastic movie!,good") ``` 선택적 `revision` 인수를 전달하여 브랜치, 태그 이름 또는 커밋 해시와 같은 특정 커밋에서 작업을 실행할 수 있습니다. 파이썬에 내장된 `open`과 달리 `fsspec`의 `open`은 바이너리 모드 `"rb"`로 기본 설정됩니다. 이것은 텍스트 모드에서 읽기 위해 `"r"`, 쓰기 위해 `"w"`로 모드를 명시적으로 설정해야 함을 의미합니다. 파일에 추가하기(모드 `"a"` 및 `"ab"`)는 아직 지원되지 않습니다. ## 통합[[integrations]] [`HfFileSystem`]은 URL이 다음 구문을 따르는 경우 `fsspec`을 통합하는 모든 라이브러리에서 사용할 수 있습니다. ``` hf://[][@]/ ``` 여기서 `repo_type_prefix`는 Datasets의 경우 `datasets/`, Spaces의 경우 `spaces/`이며, 모델에는 URL에 접두사가 필요하지 않습니다. [`HfFileSystem`]이 Hub와의 상호작용을 단순화하는 몇 가지 흥미로운 통합 사례는 다음과 같습니다: * Hub 저장소에서 [Pandas](https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#reading-writing-remote-files) DataFrame 읽기/쓰기: ```python >>> import pandas as pd >>> # 원격 CSV 파일을 데이터프레임으로 읽기 >>> df = pd.read_csv("hf://datasets/my-username/my-dataset-repo/train.csv") >>> # 데이터프레임을 원격 CSV 파일로 쓰기 >>> df.to_csv("hf://datasets/my-username/my-dataset-repo/test.csv") ``` 동일한 워크플로우를 [Dask](https://docs.dask.org/en/stable/how-to/connect-to-remote-data.html) 및 [Polars](https://pola-rs.github.io/polars/py-polars/html/reference/io.html) DataFrame에도 사용할 수 있습니다. * [DuckDB](https://duckdb.org/docs/guides/python/filesystems)를 사용하여 (원격) Hub 파일 쿼리: ```python >>> from huggingface_hub import HfFileSystem >>> import duckdb >>> fs = HfFileSystem() >>> duckdb.register_filesystem(fs) >>> # 원격 파일을 쿼리하고 결과를 데이터프레임으로 가져오기 >>> fs_query_file = "hf://datasets/my-username/my-dataset-repo/data_dir/data.parquet" >>> df = duckdb.query(f"SELECT * FROM '{fs_query_file}' LIMIT 10").df() ``` * [Zarr](https://zarr.readthedocs.io/en/stable/tutorial.html#io-with-fsspec)를 사용하여 Hub를 배열 저장소로 사용: ```python >>> import numpy as np >>> import zarr >>> embeddings = np.random.randn(50000, 1000).astype("float32") >>> # 저장소(repo)에 배열 쓰기 >>> with zarr.open_group("hf://my-username/my-model-repo/array-store", mode="w") as root: ... foo = root.create_group("embeddings") ... foobar = foo.zeros('experiment_0', shape=(50000, 1000), chunks=(10000, 1000), dtype='f4') ... foobar[:] = embeddings >>> # 저장소(repo)에서 배열 읽기 >>> with zarr.open_group("hf://my-username/my-model-repo/array-store", mode="r") as root: ... first_row = root["embeddings/experiment_0"][0] ``` ## 인증[[authentication]] 대부분의 경우 Hub와 상호작용하려면 Hugging Face 계정에 로그인해야 합니다. Hub에서 인증 방법에 대해 자세히 알아보려면 문서의 [인증](../quick-start#authentication) 섹션을 참조하세요. 또한 [`HfFileSystem`]에 `token`을 인수로 전달하여 프로그래밍 방식으로 로그인할 수 있습니다: ```python >>> from huggingface_hub import HfFileSystem >>> fs = HfFileSystem(token=token) ``` 이렇게 로그인하는 경우 소스 코드를 공유할 때 토큰이 실수로 누출되지 않도록 주의해야 합니다! 
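토큰을 소스 코드에 직접 적지 않으려면, 예를 들어 환경 변수에서 읽어오는 방식을 사용할 수 있습니다. 아래는 `HF_TOKEN` 환경 변수에 토큰이 저장되어 있다고 가정한 간단한 예시입니다:

```python
import os

from huggingface_hub import HfFileSystem

# 토큰을 코드에 하드코딩하는 대신 환경 변수에서 읽어옵니다.
fs = HfFileSystem(token=os.environ.get("HF_TOKEN"))
```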
huggingface_hub-0.31.1/docs/source/ko/guides/inference.md000066400000000000000000000413151500667546600233250ustar00rootroot00000000000000 # 서버에서 추론 진행하기[[run-inference-on-servers]] 추론은 훈련된 모델을 사용하여 새 데이터에 대한 예측을 수행하는 과정입니다. 이 과정은 계산이 많이 필요할 수 있으므로, 전용 서버에서 실행하는 것이 좋은 방안이 될 수 있습니다. `huggingface_hub` 라이브러리는 호스팅된 모델에 대한 추론을 실행하는 서비스를 호출하는 간편한 방법을 제공합니다. 다음과 같은 여러 서비스에 연결할 수 있습니다: - [추론 API](https://huggingface.co/docs/api-inference/index): Hugging Face의 인프라에서 가속화된 추론을 실행할 수 있는 서비스로 무료로 제공됩니다. 이 서비스는 추론을 시작하고 다양한 모델을 테스트하며 AI 제품의 프로토타입을 만드는 빠른 방법입니다. - [추론 엔드포인트](https://huggingface.co/docs/inference-endpoints/index): 모델을 제품 환경에 쉽게 배포할 수 있는 제품입니다. 사용자가 선택한 클라우드 환경에서 완전 관리되는 전용 인프라에서 Hugging Face를 통해 추론이 실행됩니다. 이러한 서비스들은 [`InferenceClient`] 객체를 사용하여 호출할 수 있습니다. 이는 이전의 [`InferenceApi`] 클라이언트를 대체하는 역할을 하며, 작업에 대한 특별한 지원을 추가하고 [추론 API](https://huggingface.co/docs/api-inference/index) 및 [추론 엔드포인트](https://huggingface.co/docs/inference-endpoints/index)에서 추론 작업을 처리합니다. 새 클라이언트로의 마이그레이션에 대한 자세한 내용은 [레거시 InferenceAPI 클라이언트](#legacy-inferenceapi-client) 섹션을 참조하세요. [`InferenceClient`]는 API에 HTTP 호출을 수행하는 Python 클라이언트입니다. HTTP 호출을 원하는 툴을 이용하여 직접 사용하려면 (curl, postman 등) [추론 API](https://huggingface.co/docs/api-inference/index) 또는 [추론 엔드포인트](https://huggingface.co/docs/inference-endpoints/index) 문서 페이지를 참조하세요. 웹 개발을 위해 [JS 클라이언트](https://huggingface.co/docs/huggingface.js/inference/README)가 출시되었습니다. 게임 개발에 관심이 있다면 [C# 프로젝트](https://github.com/huggingface/unity-api)를 살펴보세요. ## 시작하기[[getting-started]] text-to-image 작업을 시작해보겠습니다. ```python >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> image = client.text_to_image("An astronaut riding a horse on the moon.") >>> image.save("astronaut.png") ``` 우리는 기본 매개변수로 [`InferenceClient`]를 초기화했습니다. 수행하고자 하는 [작업](#supported-tasks)만 알면 됩니다. 기본적으로 클라이언트는 추론 API에 연결하고 작업을 완료할 모델을 선택합니다. 예제에서는 텍스트 프롬프트에서 이미지를 생성했습니다. 반환된 값은 파일로 저장할 수 있는 `PIL.Image` 객체입니다. API는 간단하게 설계되었습니다. 모든 매개변수와 옵션이 사용 가능하거나 설명되어 있는 것은 아닙니다. 각 작업에서 사용 가능한 모든 매개변수에 대해 자세히 알아보려면 [이 페이지](https://huggingface.co/docs/api-inference/detailed_parameters)를 확인하세요. ### 특정 모델 사용하기[[using-a-specific-model]] 특정 모델을 사용하고 싶다면 어떻게 해야 할까요? 매개변수로 직접 지정하거나 인스턴스 수준에서 직접 지정할 수 있습니다: ```python >>> from huggingface_hub import InferenceClient # 특정 모델을 위한 클라이언트를 초기화합니다. >>> client = InferenceClient(model="prompthero/openjourney-v4") >>> client.text_to_image(...) # 또는 일반적인 클라이언트를 사용하되 모델을 인수로 전달하세요. >>> client = InferenceClient() >>> client.text_to_image(..., model="prompthero/openjourney-v4") ``` Hugging Face Hub에는 20만 개가 넘는 모델이 있습니다! [`InferenceClient`]의 각 작업에는 추천되는 모델이 포함되어 있습니다. HF의 추천은 사전 고지 없이 시간이 지남에 따라 변경될 수 있음을 유의하십시오. 따라서 모델을 결정한 후에는 명시적으로 모델을 설정하는 것이 좋습니다. 또한 대부분의 경우 자신의 필요에 맞는 모델을 직접 찾고자 할 것입니다. 허브의 [모델](https://huggingface.co/models) 페이지를 방문하여 찾아보세요. ### 특정 URL 사용하기[[using-a-specific-url]] 위에서 본 예제들은 서버리스 추론 API를 사용합니다. 이는 빠르게 프로토타입을 정하고 테스트할 때 매우 유용합니다. 모델을 프로덕션 환경에 배포할 준비가 되면 전용 인프라를 사용해야 합니다. 그것이 [추론 엔드포인트](https://huggingface.co/docs/inference-endpoints/index)가 필요한 이유입니다. 이를 사용하면 모든 모델을 배포하고 개인 API로 노출시킬 수 있습니다. 한 번 배포되면 이전과 완전히 동일한 코드를 사용하여 연결할 수 있는 URL을 얻게 됩니다. 
`model` 매개변수만 변경하면 됩니다: ```python >>> from huggingface_hub import InferenceClient >>> client = InferenceClient(model="https://uu149rez6gw9ehej.eu-west-1.aws.endpoints.huggingface.cloud/deepfloyd-if") # 또는 >>> client = InferenceClient() >>> client.text_to_image(..., model="https://uu149rez6gw9ehej.eu-west-1.aws.endpoints.huggingface.cloud/deepfloyd-if") ``` ### 인증[[authentication]] [`InferenceClient`]로 수행된 호출은 [사용자 액세스 토큰](https://huggingface.co/docs/hub/security-tokens)을 사용하여 인증할 수 있습니다. 기본적으로 로그인한 경우 기기에 저장된 토큰을 사용합니다 ([인증 방법](https://huggingface.co/docs/huggingface_hub/quick-start#authentication)을 확인하세요). 로그인하지 않은 경우 인스턴스 매개변수로 토큰을 전달할 수 있습니다. ```python >>> from huggingface_hub import InferenceClient >>> client = InferenceClient(token="hf_***") ``` 추론 API를 사용할 때 인증은 필수가 아닙니다. 그러나 인증된 사용자는 서비스를 이용할 수 있는 더 높은 무료 티어를 받습니다. 토큰은 개인 모델이나 개인 엔드포인트에서 추론을 실행하려면 필수입니다. ## 지원되는 작업[[supported-tasks]] [`InferenceClient`]의 목표는 Hugging Face 모델에서 추론을 실행하기 위한 가장 쉬운 인터페이스를 제공하는 것입니다. 이는 가장 일반적인 작업들을 지원하는 간단한 API를 가지고 있습니다. 현재 지원되는 작업 목록은 다음과 같습니다: | 도메인 | 작업 | 지원 여부 | 문서 | |--------|--------------------------------|--------------|------------------------------------| | 오디오 | [오디오 분류](https://huggingface.co/tasks/audio-classification) | ✅ | [`~InferenceClient.audio_classification`] | | 오디오 | [오디오 투 오디오](https://huggingface.co/tasks/audio-to-audio) | ✅ | [`~InferenceClient.audio_to_audio`] | | | [자동 음성 인식](https://huggingface.co/tasks/automatic-speech-recognition) | ✅ | [`~InferenceClient.automatic_speech_recognition`] | | | [텍스트 투 스피치](https://huggingface.co/tasks/text-to-speech) | ✅ | [`~InferenceClient.text_to_speech`] | | 컴퓨터 비전 | [이미지 분류](https://huggingface.co/tasks/image-classification) | ✅ | [`~InferenceClient.image_classification`] | | | [이미지 분할](https://huggingface.co/tasks/image-segmentation) | ✅ | [`~InferenceClient.image_segmentation`] | | | [이미지 투 이미지](https://huggingface.co/tasks/image-to-image) | ✅ | [`~InferenceClient.image_to_image`] | | | [이미지 투 텍스트](https://huggingface.co/tasks/image-to-text) | ✅ | [`~InferenceClient.image_to_text`] | | | [객체 탐지](https://huggingface.co/tasks/object-detection) | ✅ | [`~InferenceClient.object_detection`] | | | [텍스트 투 이미지](https://huggingface.co/tasks/text-to-image) | ✅ | [`~InferenceClient.text_to_image`] | | | [제로샷 이미지 분류](https://huggingface.co/tasks/zero-shot-image-classification) | ✅ | [`~InferenceClient.zero_shot_image_classification`] | | 멀티모달 | [문서 질의 응답](https://huggingface.co/tasks/document-question-answering) | ✅ | [`~InferenceClient.document_question_answering`] | | | [시각적 질의 응답](https://huggingface.co/tasks/visual-question-answering) | ✅ | [`~InferenceClient.visual_question_answering`] | | 자연어 처리 | [대화형](https://huggingface.co/tasks/conversational) | ✅ | [`~InferenceClient.conversational`] | | | [특성 추출](https://huggingface.co/tasks/feature-extraction) | ✅ | [`~InferenceClient.feature_extraction`] | | | [마스크 채우기](https://huggingface.co/tasks/fill-mask) | ✅ | [`~InferenceClient.fill_mask`] | | | [질의 응답](https://huggingface.co/tasks/question-answering) | ✅ | [`~InferenceClient.question_answering`] | | | [문장 유사도](https://huggingface.co/tasks/sentence-similarity) | ✅ | [`~InferenceClient.sentence_similarity`] | | | [요약](https://huggingface.co/tasks/summarization) | ✅ | [`~InferenceClient.summarization`] | | | [테이블 질의 응답](https://huggingface.co/tasks/table-question-answering) | ✅ | [`~InferenceClient.table_question_answering`] | | | [텍스트 분류](https://huggingface.co/tasks/text-classification) | ✅ | [`~InferenceClient.text_classification`] | | | [텍스트 
생성](https://huggingface.co/tasks/text-generation) | ✅ | [`~InferenceClient.text_generation`] | | | [토큰 분류](https://huggingface.co/tasks/token-classification) | ✅ | [`~InferenceClient.token_classification`] | | | [번역](https://huggingface.co/tasks/translation) | ✅ | [`~InferenceClient.translation`] | | | [제로샷 분류](https://huggingface.co/tasks/zero-shot-classification) | ✅ | [`~InferenceClient.zero_shot_classification`] | | 타블로 | [타블로 작업 분류](https://huggingface.co/tasks/tabular-classification) | ✅ | [`~InferenceClient.tabular_classification`] | | | [타블로 회귀](https://huggingface.co/tasks/tabular-regression) | ✅ | [`~InferenceClient.tabular_regression`] | 각 작업에 대해 더 자세히 알고 싶거나 사용 방법 및 각 작업에 대한 가장 인기 있는 모델을 알아보려면 [Tasks](https://huggingface.co/tasks) 페이지를 확인하세요. ## 비동기 클라이언트[[async-client]] `asyncio`와 `aiohttp`를 기반으로 한 클라이언트의 비동기 버전도 제공됩니다. `aiohttp`를 직접 설치하거나 `[inference]` 추가 옵션을 사용할 수 있습니다: ```sh pip install --upgrade huggingface_hub[inference] # 또는 # pip install aiohttp ``` 설치 후 모든 비동기 API 엔드포인트는 [`AsyncInferenceClient`]를 통해 사용할 수 있습니다. 초기화 및 API는 동기 전용 버전과 완전히 동일합니다. ```py # 코드는 비동기 asyncio 라이브러리 동시성 컨텍스트에서 실행되어야 합니다. # $ python -m asyncio >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> image = await client.text_to_image("An astronaut riding a horse on the moon.") >>> image.save("astronaut.png") >>> async for token in await client.text_generation("The Huggingface Hub is", stream=True): ... print(token, end="") a platform for sharing and discussing ML-related content. ``` `asyncio` 모듈에 대한 자세한 정보는 [공식 문서](https://docs.python.org/3/library/asyncio.html)를 참조하세요. ## 고급 팁[[advanced-tips]] 위 섹션에서는 [`InferenceClient`]의 주요 측면을 살펴보았습니다. 이제 몇 가지 고급 팁에 대해 자세히 알아보겠습니다. ### 타임아웃[[timeout]] 추론을 수행할 때 타임아웃이 발생하는 주요 원인은 두 가지입니다: - 추론 프로세스가 완료되는 데 오랜 시간이 걸리는 경우 - 모델이 사용 불가능한 경우, 예를 들어 Inference API를 처음으로 가져오는 경우 [`InferenceClient`]에는 이 두 가지를 처리하기 위한 전역 `timeout` 매개변수가 있습니다. 기본값은 `None`으로 설정되어 있으며, 클라이언트가 추론이 완료될 때까지 무기한으로 기다리게 합니다. 워크플로우에서 더 많은 제어를 원하는 경우 초 단위의 특정한 값으로 설정할 수 있습니다. 타임아웃 딜레이가 만료되면 [`InferenceTimeoutError`]가 발생합니다. 이를 코드에서 처리할 수 있습니다: ```python >>> from huggingface_hub import InferenceClient, InferenceTimeoutError >>> client = InferenceClient(timeout=30) >>> try: ... client.text_to_image(...) ... except InferenceTimeoutError: ... print("Inference timed out after 30s.") ``` ### 이진 입력[[binary-inputs]] 일부 작업에는 이미지 또는 오디오 파일을 처리할 때와 같이 이진 입력이 필요한 경우가 있습니다. 이 경우 [`InferenceClient`]는 최대한 다양한 유형을 융통성 있게 허용합니다: - 원시 `bytes` - 이진으로 열린 파일과 유사한 객체 (`with open("audio.flac", "rb") as f: ...`) - 로컬 파일을 가리키는 경로 (`str` 또는 `Path`) - 원격 파일을 가리키는 URL (`str`) (예: `https://...`). 이 경우 파일은 Inference API로 전송되기 전에 로컬로 다운로드됩니다. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") [{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...] ``` ## 레거시 InferenceAPI 클라이언트[[legacy-inferenceapi-client]] [`InferenceClient`]는 레거시 [`InferenceApi`] 클라이언트를 대체하여 작동합니다. 특정 작업에 대한 지원을 제공하고 [추론 API](https://huggingface.co/docs/api-inference/index) 및 [추론 엔드포인트](https://huggingface.co/docs/inference-endpoints/index)에서 추론을 처리합니다. 아래는 [`InferenceApi`]에서 [`InferenceClient`]로 마이그레이션하는 데 도움이 되는 간단한 가이드입니다. 
### 초기화[[initialization]] 변경 전: ```python >>> from huggingface_hub import InferenceApi >>> inference = InferenceApi(repo_id="bert-base-uncased", token=API_TOKEN) ``` 변경 후: ```python >>> from huggingface_hub import InferenceClient >>> inference = InferenceClient(model="bert-base-uncased", token=API_TOKEN) ``` ### 특정 작업에서 실행하기[[run-on-a-specific-task]] 변경 전: ```python >>> from huggingface_hub import InferenceApi >>> inference = InferenceApi(repo_id="paraphrase-xlm-r-multilingual-v1", task="feature-extraction") >>> inference(...) ``` 변경 후: ```python >>> from huggingface_hub import InferenceClient >>> inference = InferenceClient() >>> inference.feature_extraction(..., model="paraphrase-xlm-r-multilingual-v1") ``` 위의 방법은 코드를 [`InferenceClient`]에 맞게 조정하는 권장 방법입니다. 이렇게 하면 `feature_extraction`과 같이 작업에 특화된 메소드를 활용할 수 있습니다. ### 사용자 정의 요청 실행[[run-custom-request]] 변경 전: ```python >>> from huggingface_hub import InferenceApi >>> inference = InferenceApi(repo_id="bert-base-uncased") >>> inference(inputs="The goal of life is [MASK].") [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}] ``` ### 매개변수와 함께 실행하기[[run-with-parameters]] 변경 전: ```python >>> from huggingface_hub import InferenceApi >>> inference = InferenceApi(repo_id="typeform/distilbert-base-uncased-mnli") >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!" >>> params = {"candidate_labels":["refund", "legal", "faq"]} >>> inference(inputs, params) {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]} ``` huggingface_hub-0.31.1/docs/source/ko/guides/inference_endpoints.md000066400000000000000000000355551500667546600254210ustar00rootroot00000000000000# 추론 엔드포인트[[inference-endpoints]] 추론 엔드포인트는 Hugging Face가 관리하는 전용 및 자동 확장 인프라에 `transformers`, `sentence-transformers` 및 `diffusers` 모델을 쉽게 배포할 수 있는 안전한 프로덕션 솔루션을 제공합니다. 추론 엔드포인트는 [Hub](https://huggingface.co/models)의 모델로 구축됩니다. 이 가이드에서는 `huggingface_hub`를 사용하여 프로그래밍 방식으로 추론 엔드포인트를 관리하는 방법을 배웁니다. 추론 엔드포인트 제품 자체에 대한 자세한 내용은 [공식 문서](https://huggingface.co/docs/inference-endpoints/index)를 참조하세요. 이 가이드에서는 `huggingface_hub`가 올바르게 설치 및 로그인되어 있다고 가정합니다. 아직 그렇지 않은 경우 [빠른 시작 가이드](https://huggingface.co/docs/huggingface_hub/quick-start#quickstart)를 참조하세요. 추론 엔드포인트 API를 지원하는 최소 버전은 `v0.19.0`입니다. ## 추론 엔드포인트 생성[[create-an-inference-endpoint]] 첫 번째 단계는 [`create_inference_endpoint`]를 사용하여 추론 엔드포인트를 생성하는 것입니다: ```py >>> from huggingface_hub import create_inference_endpoint >>> endpoint = create_inference_endpoint( ... "my-endpoint-name", ... repository="gpt2", ... framework="pytorch", ... task="text-generation", ... accelerator="cpu", ... vendor="aws", ... region="us-east-1", ... type="protected", ... instance_size="x2", ... instance_type="intel-icl" ... ) ``` 예시에서는 `"my-endpoint-name"`라는 `protected` 추론 엔드포인트를 생성하여 `text-generation`을 위한 [gpt2](https://huggingface.co/gpt2)를 제공합니다. `protected` 추론 엔드포인트 API에 액세스하려면 토큰이 필요합니다. 또한 벤더, 지역, 액셀러레이터, 인스턴스 유형, 크기와 같은 하드웨어 요구 사항을 구성하기 위한 추가 정보를 제공해야 합니다. 사용 가능한 리소스 목록은 [여기](https://api.endpoints.huggingface.cloud/#/v2%3A%3Aprovider/list_vendors)에서 확인할 수 있습니다. 또한 [웹 인터페이스](https://ui.endpoints.huggingface.co/new)를 사용하여 편리하게 수동으로 추론 엔드포인트를 생성할 수 있습니다. 
고급 설정 및 사용법에 대한 자세한 내용은 [이 가이드](https://huggingface.co/docs/inference-endpoints/guides/advanced)를 참조하세요. [`create_inference_endpoint`]에서 반환된 값은 [`InferenceEndpoint`] 개체입니다: ```py >>> endpoint InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='pending', url=None) ``` 이것은 엔드포인트에 대한 정보를 저장하는 데이터클래스입니다. `name`, `repository`, `status`, `task`, `created_at`, `updated_at` 등과 같은 중요한 속성에 접근할 수 있습니다. 필요한 경우 `endpoint.raw`를 통해 서버로부터의 원시 응답에도 접근할 수 있습니다. 추론 엔드포인트가 생성되면 [개인 대시보드](https://ui.endpoints.huggingface.co/)에서 확인할 수 있습니다. ![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/huggingface_hub/inference_endpoints_created.png) #### 사용자 정의 이미지 사용[[using-a-custom-image]] 기본적으로 추론 엔드포인트는 Hugging Face에서 제공하는 도커 이미지로 구축됩니다. 그러나 `custom_image` 매개변수를 사용하여 모든 도커 이미지를 지정할 수 있습니다. 일반적인 사용 사례는 [text-generation-inference](https://github.com/huggingface/text-generation-inference) 프레임워크를 사용하여 LLM을 실행하는 것입니다. 다음과 같이 수행할 수 있습니다: ```python # TGI에서 Zephyr-7b-beta를 실행하는 추론 엔드포인트 시작하기 >>> from huggingface_hub import create_inference_endpoint >>> endpoint = create_inference_endpoint( ... "aws-zephyr-7b-beta-0486", ... repository="HuggingFaceH4/zephyr-7b-beta", ... framework="pytorch", ... task="text-generation", ... accelerator="gpu", ... vendor="aws", ... region="us-east-1", ... type="protected", ... instance_size="x1", ... instance_type="nvidia-a10g", ... custom_image={ ... "health_route": "/health", ... "env": { ... "MAX_BATCH_PREFILL_TOKENS": "2048", ... "MAX_INPUT_LENGTH": "1024", ... "MAX_TOTAL_TOKENS": "1512", ... "MODEL_ID": "/repository" ... }, ... "url": "ghcr.io/huggingface/text-generation-inference:1.1.0", ... }, ... ) ``` `custom_image`에 전달할 값은 도커 컨테이너의 URL과 이를 실행하기 위한 구성이 포함된 딕셔너리입니다. 자세한 내용은 [Swagger 문서](https://api.endpoints.huggingface.cloud/#/v2%3A%3Aendpoint/create_endpoint)를 참조하세요. ### 기존 추론 엔드포인트 가져오기 또는 리스트 조회[[get-or-list-existing-inference-endpoints]] 경우에 따라 이전에 생성한 추론 엔드포인트를 관리해야 할 수 있습니다. 이름을 알고 있는 경우 [`get_inference_endpoint`]를 사용하여 [`InferenceEndpoint`] 개체를 가져올 수 있습니다. 또는 [`list_inference_endpoints`]를 사용하여 모든 추론 엔드포인트 리스트를 검색할 수 있습니다. 두 메소드 모두 선택적 `namespace` 매개변수를 허용합니다. 속해 있는 조직의 `namespace`를 설정할 수 있습니다. 그렇지 않으면 기본적으로 사용자 이름이 사용됩니다. ```py >>> from huggingface_hub import get_inference_endpoint, list_inference_endpoints # 엔드포인트 개체 가져오기 >>> get_inference_endpoint("my-endpoint-name") InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='pending', url=None) # 조직의 모든 추론 엔드포인트 나열 >>> list_inference_endpoints(namespace="huggingface") [InferenceEndpoint(name='aws-starchat-beta', namespace='huggingface', repository='HuggingFaceH4/starchat-beta', status='paused', url=None), ...] # 사용자가 속해있는 모든 조직의 엔드포인트 나열 >>> list_inference_endpoints(namespace="*") [InferenceEndpoint(name='aws-starchat-beta', namespace='huggingface', repository='HuggingFaceH4/starchat-beta', status='paused', url=None), ...] ``` ## 배포 상태 확인[[check-deployment-status]] 이 가이드의 나머지 부분에서는 `endpoint`라는 이름의 [`InferenceEndpoint`] 객체를 가지고 있다고 가정합니다. 엔드포인트에 `status` 속성이 [`InferenceEndpointStatus`] 유형이라는 것을 알 수 있었습니다. 추론 엔드포인트가 배포되고 접근 가능하면 상태가 `"running"`이 되고 `url` 속성이 설정됩니다: ```py >>> endpoint InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='running', url='https://jpj7k2q4j805b727.us-east-1.aws.endpoints.huggingface.cloud') ``` `추론 엔드포인트가 "running"` 상태에 도달하기 전에 일반적으로 `"initializing"` 또는 `"pending"` 단계를 거칩니다. [`~InferenceEndpoint.fetch`]를 실행하여 엔드포인트의 새로운 상태를 가져올 수 있습니다. 
[`InferenceEndpoint`]의 다른 메소드와 마찬가지로 이 메소드는 서버에 요청을 하며, `endpoint`의 내부 속성이 변경됩니다: ```py >>> endpoint.fetch() InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='pending', url=None) ``` 추론 엔드포인트가 실행될 때까지 기다리면서 상태를 가져오는 대신 [`~InferenceEndpoint.wait`]를 직접 호출할 수 있습니다. 이 헬퍼는 `timeout`과 `fetch_every` 매개변수를 입력으로 받아 (초 단위) 추론 엔드포인트가 배포될 때까지 스레드를 차단합니다. 기본값은 각각 `None`(제한 시간 없음)과 `5`초입니다. ```py # 엔드포인트 보류 >>> endpoint InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='pending', url=None) # 10초 대기 => InferenceEndpointTimeoutError 발생 >>> endpoint.wait(timeout=10) raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.") huggingface_hub._inference_endpoints.InferenceEndpointTimeoutError: Timeout while waiting for Inference Endpoint to be deployed. # 추가 대기 >>> endpoint.wait() InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='running', url='https://jpj7k2q4j805b727.us-east-1.aws.endpoints.huggingface.cloud') ``` `timeout`이 설정되어 있고 추론 엔드포인트를 불러오는 데 너무 오래 걸리면, [`InferenceEndpointTimeoutError`] 제한 시간 초과 오류가 발생합니다. ## 추론 실행[[run-inference]] 추론 엔드포인트가 실행되면, 마침내 추론을 실행할 수 있습니다! [`InferenceEndpoint`]에는 각각 [`InferenceClient`]와 [`AsyncInferenceClient`]를 반환하는 `client`와 `async_client` 속성이 있습니다. ```py # 텍스트 생성 작업 실행: >>> endpoint.client.text_generation("I am") ' not a fan of the idea of a "big-budget" movie. I think it\'s a' # 비동기 컨텍스트에서도 마찬가지로 실행: >>> await endpoint.async_client.text_generation("I am") ``` 추론 엔드포인트가 실행 중이 아니면 [`InferenceEndpointError`] 오류가 발생합니다: ```py >>> endpoint.client huggingface_hub._inference_endpoints.InferenceEndpointError: Cannot create a client for this Inference Endpoint as it is not yet deployed. Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again. ``` [`InferenceClient`]를 사용하는 방법에 대한 자세한 내용은 [추론 가이드](../guides/inference)를 참조하세요. ## 라이프사이클 관리[[manage-lifecycle]] 이제 추론 엔드포인트를 생성하고 추론을 실행하는 방법을 살펴보았으니, 라이프사이클을 관리하는 방법을 살펴봅시다. 이 섹션에서는 [`~InferenceEndpoint.pause`], [`~InferenceEndpoint.resume`], [`~InferenceEndpoint.scale_to_zero`], [`~InferenceEndpoint.update`] 및 [`~InferenceEndpoint.delete`] 등의 메소드를 살펴볼 것입니다. 모든 메소드는 편의를 위해 [`InferenceEndpoint`]에 추가된 별칭입니다. 원한다면 `HfApi`에 정의된 일반 메소드 [`pause_inference_endpoint`], [`resume_inference_endpoint`], [`scale_to_zero_inference_endpoint`], [`update_inference_endpoint`] 및 [`delete_inference_endpoint`]를 사용할 수도 있습니다. ### 일시 중지 또는 0으로 확장[[pause-or-scale-to-zero]] 추론 엔드포인트를 사용하지 않을 때 비용을 절감하기 위해 [`~InferenceEndpoint.pause`]를 사용하여 일시 중지하거나 [`~InferenceEndpoint.scale_to_zero`]를 사용하여 0으로 스케일링할 수 있습니다. *일시 중지* 또는 *0으로 스케일링*된 추론 엔드포인트는 비용이 들지 않습니다. 이 두 가지의 차이점은 *일시 중지* 엔드포인트는 [`~InferenceEndpoint.resume`]를 사용하여 명시적으로 *재개*해야 한다는 것입니다. 반대로 *0으로 스케일링*된 엔드포인트는 추론 호출이 있으면 추가 콜드 스타트 지연과 함께 자동으로 시작됩니다. 추론 엔드포인트는 일정 기간 비활성화된 후 자동으로 0으로 스케일링되도록 구성할 수도 있습니다. ```py # 엔드포인트 일시중지 및 재시작 >>> endpoint.pause() InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='paused', url=None) >>> endpoint.resume() InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='pending', url=None) >>> endpoint.wait().client.text_generation(...) ... # 0으로 스케일링 >>> endpoint.scale_to_zero() InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2', status='scaledToZero', url='https://jpj7k2q4j805b727.us-east-1.aws.endpoints.huggingface.cloud') # 엔드포인트는 'running'은 아니지만 URL�을 가지고 있으며 첫 번째 호출 시 다시 시작됩니다. 
``` ### 모델 또는 하드웨어 요구 사항 업데이트[[update-model-or-hardware-requirements]] 경우에 따라 새로운 엔드포인트를 생성하지 않고 추론 엔드포인트를 업데이트하고 싶을 수 있습니다. 호스팅된 모델이나 모델 실행에 필요한 하드웨어 요구 사항을 업데이트할 수 있습니다. 이렇게 하려면 [`~InferenceEndpoint.update`]를 사용합니다: ```py # 타겟 모델 변경 >>> endpoint.update(repository="gpt2-large") InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2-large', status='pending', url=None) # 복제본 갯수 업데이트 >>> endpoint.update(min_replica=2, max_replica=6) InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2-large', status='pending', url=None) # 더 큰 인스턴스로 업데이트 >>> endpoint.update(accelerator="cpu", instance_size="x4", instance_type="intel-icl") InferenceEndpoint(name='my-endpoint-name', namespace='Wauplin', repository='gpt2-large', status='pending', url=None) ``` ### 엔드포인트 삭제[[delete-the-endpoint]] 마지막으로 더 이상 추론 엔드포인트를 사용하지 않을 경우, [`~InferenceEndpoint.delete()`]를 호출하기만 하면 됩니다. 이것은 돌이킬 수 없는 작업이며, 구성, 로그 및 사용 메트릭을 포함한 엔드포인트를 완전히 제거합니다. 삭제된 추론 엔드포인트는 복원할 수 없습니다. ## 엔드 투 엔드 예제[an-end-to-end-example] 추론 엔드포인트의 일반적인 사용 사례는 한 번에 여러 개의 작업을 처리하여 인프라 비용을 제한하는 것입니다. 이 가이드에서 본 것을 사용하여 이 프로세스를 자동화할 수 있습니다: ```py >>> import asyncio >>> from huggingface_hub import create_inference_endpoint # 엔드포인트 시작 + 초기화될 때까지 대기 >>> endpoint = create_inference_endpoint(name="batch-endpoint",...).wait() # 추론 실행 >>> client = endpoint.client >>> results = [client.text_generation(...) for job in jobs] # 비동기 추론 실행 >>> async_client = endpoint.async_client >>> results = asyncio.gather(*[async_client.text_generation(...) for job in jobs]) # 엔드포인트 중지 >>> endpoint.pause() ``` 또는 추론 엔드포인트가 이미 존재하고 일시 중지된 경우: ```py >>> import asyncio >>> from huggingface_hub import get_inference_endpoint # 엔드포인트 가져오기 + 초기화될 때까지 대기 >>> endpoint = get_inference_endpoint("batch-endpoint").resume().wait() # 추론 실행 >>> async_client = endpoint.async_client >>> results = asyncio.gather(*[async_client.text_generation(...) for job in jobs]) # 엔드포인트 중지 >>> endpoint.pause() ``` huggingface_hub-0.31.1/docs/source/ko/guides/integrations.md000066400000000000000000000604401500667546600240750ustar00rootroot00000000000000 # Hub와 어떤 머신 러닝 프레임워크든 통합[[integrate-any-ml-framework-with-the-hub]] Hugging Face Hub는 커뮤니티와 모델을 공유하는 것을 쉽게 만들어줍니다. 이는 오픈소스 생태계의 [수십 가지 라이브러리](https://huggingface.co/docs/hub/models-libraries)를 지원합니다. 저희는 항상 협업적인 머신 러닝을 발전시키기 위해 이 라이브러리를 확대하고자 노력하고 있습니다. `huggingface_hub` 라이브러리는 어떤 Python 스크립트든지 쉽게 파일을 업로드하고 가져올 수 있는 중요한 역할을 합니다. 라이브러리를 Hub와 통합하는 네 가지 주요 방법이 있습니다: 1. **Hub에 업로드하기**: 모델을 Hub에 업로드하는 메소드를 구현합니다. 이에는 모델 가중치뿐만 아니라 [모델 카드](https://huggingface.co/docs/huggingface_hub/how-to-model-cards) 및 모델 실행에 필요한 다른 관련 정보나 데이터(예: 훈련 로그)가 포함됩니다. 이 메소드는 일반적으로 `push_to_hub()`라고 합니다. 2. **Hub에서 다운로드하기**: Hub에서 모델을 가져오는 메소드를 구현합니다. 이 메소드는 모델 구성/가중치를 다운로드하고 모델을 가져와야 합니다. 이 메소드는 일반적으로 `from_pretrained` 또는 `load_from_hub()`라고 합니다. 3. **추론 API**: 라이브러리에서 지원하는 모델에 대해 무료로 추론을 실행할 수 있도록 당사 서버를 사용합니다. 4. **위젯**: Hub의 모델 랜딩 페이지에 위젯을 표시합니다. 이를 통해 사용자들은 브라우저에서 빠르게 모델을 시도할 수 있습니다. 이 가이드에서는 앞의 두 가지 주제에 중점을 둘 것입니다. 우리는 라이브러리를 통합하는 데 사용할 수 있는 두 가지 주요 방법을 소개하고 각각의 장단점을 설명할 것입니다. 두 가지 중 어떤 것을 선택할지에 대한 도움이 되도록 끝 부분에 내용이 요약되어 있습니다. 이는 단지 가이드라는 것을 명심하고 상황에 맞게 적응시킬 수 있는 가이드라는 점을 유념하십시오. 추론 및 위젯에 관심이 있는 경우 [이 가이드](https://huggingface.co/docs/hub/models-adding-libraries#set-up-the-inference-api)를 참조할 수 있습니다. 양쪽 모두에서 라이브러리를 Hub와 통합하고 [문서](https://huggingface.co/docs/hub/models-libraries)에 목록에 게시하고자 하는 경우에는 언제든지 연락하실 수 있습니다. 
## 유연한 접근 방식: 도우미(helper)[[a-flexible-approach-helpers]] 라이브러리를 Hub에 통합하는 첫 번째 접근 방법은 실제로 `push_to_hub` 및 `from_pretrained` 메소드를 직접 구현하는 것입니다. 이를 통해 업로드/다운로드할 파일 및 입력을 처리하는 방법에 대한 완전한 유연성을 제공받을 수 있습니다. 이를 위해 [파일 업로드](./upload) 및 [파일 다운로드](./download) 가이드를 참조하여 자세히 알아볼 수 있습니다. 예를 들어 FastAI 통합이 구현된 방법을 보면 됩니다 ([`push_to_hub_fastai`] 및 [`from_pretrained_fastai`]를 참조). 라이브러리마다 구현 방식은 다를 수 있지만, 워크플로우는 일반적으로 비슷합니다. ### from_pretrained[[frompretrained]] 일반적으로 `from_pretrained` 메소드는 다음과 같은 형태를 가집니다: ```python def from_pretrained(model_id: str) -> MyModelClass: # Hub로부터 모델을 다운로드 cached_model = hf_hub_download( repo_id=repo_id, filename="model.pkl", library_name="fastai", library_version=get_fastai_version(), ) # 모델 가져오기 return load_model(cached_model) ``` ### push_to_hub[[pushtohub]] `push_to_hub` 메소드는 종종 리포지토리 생성, 모델 카드 생성 및 가중치 저장을 처리하기 위해 조금 더 복잡한 접근 방식이 필요합니다. 일반적으로 모든 이러한 파일을 임시 폴더에 저장한 다음 업로드하고 나중에 삭제하는 방식이 흔히 사용됩니다. ```python def push_to_hub(model: MyModelClass, repo_name: str) -> None: api = HfApi() # 해당 리포지토리가 아직 없다면 리포지토리를 생성하고 관련된 리포지토리 ID를 가져옵니다. repo_id = api.create_repo(repo_name, exist_ok=True) # 모든 파일을 임시 디렉토리에 저장하고 이를 단일 커밋으로 푸시합니다. with TemporaryDirectory() as tmpdir: tmpdir = Path(tmpdir) # 가중치 저장 save_model(model, tmpdir / "model.safetensors") # model card 생성 card = generate_model_card(model) (tmpdir / "README.md").write_text(card) # 로그 저장 # 설정 저장 # 평가 지표를 저장 # ... # Hub에 푸시 return api.upload_folder(repo_id=repo_id, folder_path=tmpdir) ``` 물론 이는 단순한 예시에 불과합니다. 더 복잡한 조작(원격 파일 삭제, 가중치를 실시간으로 업로드, 로컬로 가중치를 유지 등)에 관심이 있다면 [파일 업로드](./upload) 가이드를 참조해 주세요. ### 제한 사항[[limitations]] 이러한 방식은 유연성을 가지고 있지만, 유지보수 측면에서 일부 단점을 가지고 있습니다. Hugging Face 사용자들은 `huggingface_hub`와 함께 작업할 때 추가 기능에 익숙합니다. 예를 들어, Hub에서 파일을 로드할 때 다음과 같은 매개변수를 제공하는 것이 일반적입니다: - `token`: 개인 리포지토리에서 다운로드하기 위한 토큰 - `revision`: 특정 브랜치에서 다운로드하기 위한 리비전 - `cache_dir`: 특정 디렉터리에 파일을 캐시하기 위한 디렉터리 - `force_download`/`resume_download`/`local_files_only`: 캐시를 재사용할 것인지 여부를 결정하는 매개변수 - `proxies`: HTTP 세션 구성 모델을 푸시할 때는 유사한 매개변수가 지원됩니다: - `commit_message`: 사용자 정의 커밋 메시지 - `private`: 개인 리포지토리를 만들어야 할 경우 - `create_pr`: `main`에 푸시하는 대신 PR을 만드는 경우 - `branch`: `main` 브랜치 대신 브랜치에 푸시하는 경우 - `allow_patterns/ignore_patterns`: 업로드할 파일을 필터링하는 매개변수 - `token` - ... 이러한 매개변수는 위에서 본 구현에 추가하여 `huggingface_hub` 메소드로 전달할 수 있습니다. 그러나 매개변수가 변경되거나 새로운 기능이 추가되는 경우에는 패키지를 업데이트해야 합니다. 이러한 매개변수를 지원하는 것은 유지 관리할 문서가 더 많아진다는 것을 의미합니다. 이러한 제한 사항을 완화할 수 있는 방법을 보려면 다음 섹션인 **클래스 상속**으로 이동해 보겠습니다. ## 더욱 복잡한 접근법: 클래스 상속[[a-more-complex-approach-class-inheritance]] 위에서 보았듯이 Hub와 통합하기 위해 라이브러리에 포함해야 할 주요 메소드는 파일을 업로드 (`push_to_hub`) 와 파일 다운로드 (`from_pretrained`)입니다. 이러한 메소드를 직접 구현할 수 있지만, 이에는 몇 가지 주의할 점이 있습니다. 이를 해결하기 위해 `huggingface_hub`은 클래스 상속을 사용하는 도구를 제공합니다. 이 도구가 어떻게 작동하는지 살펴보겠습니다! 많은 경우에 라이브러리는 이미 Python 클래스를 사용하여 모델을 구현합니다. 이 클래스에는 모델의 속성 및 로드, 실행, 훈련 및 평가하는 메소드가 포함되어 있습니다. 접근 방식은 믹스인을 사용하여 이 클래스를 확장하여 업로드 및 다운로드 기능을 포함하는 것입니다. [믹스인(Mixin)](https://stackoverflow.com/a/547714)은 기존 클래스에 여러 상속을 통해 특정 기능을 확장하기 위해 설계된 클래스입니다. `huggingface_hub`은 자체 믹스인인 [`ModelHubMixin`]을 제공합니다. 여기서 핵심은 동작과 이를 사용자 정의하는 방법을 이해하는 것입니다. [`ModelHubMixin`] 클래스는 세 개의 *공개* 메소드(`push_to_hub`, `save_pretrained`, `from_pretrained`)를 구현합니다. 이 메소드들은 사용자가 라이브러리를 사용하여 모델을 로드/저장할 때 호출하는 메소드입니다. 또한 [`ModelHubMixin`]은 두 개의 *비공개* 메소드(`_save_pretrained` 및 `_from_pretrained`)를 정의합니다. 라이브러리를 통합하려면 이 메소드들을 구현해야 합니다. : 1. 모델 클래스를 [`ModelHubMixin`]에서 상속합니다. 2. 비공개 메소드를 구현합니다: - [`~ModelHubMixin._save_pretrained`]: 디렉터리 경로를 입력으로 받아 모델을 해당 디렉터리에 저장하는 메소드입니다. 
이 메소드에는 모델 카드, 모델 가중치, 구성 파일, 훈련 로그 및 그림 등 해당 모델에 대한 모든 관련 정보를 저장하기 위한 로직을 작성해야 합니다. [모델 카드](https://huggingface.co/docs/hub/model-cards)는 모델을 설명하는 데 특히 중요합니다. 더 자세한 내용은 [구현 가이드](./model-cards)를 확인하세요. - [`~ModelHubMixin._from_pretrained`]: `model_id`를 입력으로 받아 인스턴스화된 모델을 반환하는 **클래스 메소드**입니다. 이 메소드는 관련 파일을 다운로드하고 가져와야 합니다. 3. 완료했습니다! [`ModelHubMixin`]의 장점은 파일의 직렬화/로드에만 신경을 쓰면 되기 때문에 즉시 사용할 수 있다는 것입니다. 리포지토리 생성, 커밋, PR 또는 리비전과 같은 사항에 대해 걱정할 필요가 없습니다. [`ModelHubMixin`]은 또한 공개 메소드가 문서화되고 타입에 주석이 달려있는지를 확인하며, Hub 모델의 다운로드 수를 볼 수 있도록 합니다. 이 모든 것은 [`ModelHubMixin`]에 의해 처리되며 사용자에게 제공됩니다. ### 자세한 예시: PyTorch[[a-concrete-example-pytorch]] 위에서 언급한 내용의 좋은 예시는 Pytorch 프레임워크를 통합한 [`PyTorchModelHubMixin`]입니다. 바로 사용 가능할 수 있는 메소드입니다. #### 어떻게 사용하나요?[[how-to-use-it]] 다음은 Hub에서 PyTorch 모델을 로드/저장하는 방법입니다: ```python >>> import torch >>> import torch.nn as nn >>> from huggingface_hub import PyTorchModelHubMixin # PyTorch 모델을 여러분이 흔히 사용하는 방식과 완전히 동일하게 정의하세요. >>> class MyModel( ... nn.Module, ... PyTorchModelHubMixin, # 다중 상속 ... library_name="keras-nlp", ... tags=["keras"], ... repo_url="https://github.com/keras-team/keras-nlp", ... docs_url="https://keras.io/keras_nlp/", ... # ^ 모델 카드를 생성하는 데 선택적인 메타데이터입니다. ... ): ... def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4): ... super().__init__() ... self.param = nn.Parameter(torch.rand(hidden_size, vocab_size)) ... self.linear = nn.Linear(output_size, vocab_size) ... def forward(self, x): ... return self.linear(x + self.param) # 1. 모델 생성 >>> model = MyModel(hidden_size=128) # 설정은 입력 및 기본값을 기반으로 자동으로 생성됩니다. >>> model.param.shape[0] 128 # 2. (선택사항) 모델을 로컬 디렉터리에 저장합니다. >>> model.save_pretrained("path/to/my-awesome-model") # 3. 모델 가중치를 Hub에 푸시합니다. >>> model.push_to_hub("my-awesome-model") # 4. Hub로부터 모델을 초기화합니다. => 이때 설정은 보존됩니다. >>> model = MyModel.from_pretrained("username/my-awesome-model") >>> model.param.shape[0] 128 # 모델 카드가 올바르게 작성되었습니다. >>> from huggingface_hub import ModelCard >>> card = ModelCard.load("username/my-awesome-model") >>> card.data.tags ["keras", "pytorch_model_hub_mixin", "model_hub_mixin"] >>> card.data.library_name "keras-nlp" ``` #### 구현[[implementation]] 실제 구현은 매우 간단합니다. 전체 구현은 [여기](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/hub_mixin.py)에서 찾을 수 있습니다. 1. 클래스를 `ModelHubMixin`으로부터 상속하세요: ```python from huggingface_hub import ModelHubMixin class PyTorchModelHubMixin(ModelHubMixin): (...) ``` 2. `_save_pretrained` 메소드를 구현하세요: ```py from huggingface_hub import ModelHubMixin class PyTorchModelHubMixin(ModelHubMixin): (...) def _save_pretrained(self, save_directory: Path) -> None: """PyTorch 모델의 가중치를 로컬 디렉터리에 저장합니다.""" save_model_as_safetensor(self.module, str(save_directory / SAFETENSORS_SINGLE_FILE)) ``` 3. `_from_pretrained` 메소드를 구현하세요: ```python class PyTorchModelHubMixin(ModelHubMixin): (...) @classmethod # 반드시 클래스 메소드여야 합니다! 
def _from_pretrained( cls, *, model_id: str, revision: str, cache_dir: str, force_download: bool, proxies: Optional[Dict], resume_download: bool, local_files_only: bool, token: Union[str, bool, None], map_location: str = "cpu", # 추가 인자 strict: bool = False, # 추가 인자 **model_kwargs, ): """PyTorch의 사전 학습된 가중치와 모델을 반환합니다.""" model = cls(**model_kwargs) if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE) return cls._load_as_safetensor(model, model_file, map_location, strict) model_file = hf_hub_download( repo_id=model_id, filename=SAFETENSORS_SINGLE_FILE, revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) return cls._load_as_safetensor(model, model_file, map_location, strict) ``` 이게 전부입니다! 이제 라이브러리를 통해 Hub로부터 파일을 업로드하고 다운로드할 수 있습니다. ### 고급 사용법[[advanced-usage]] 위의 섹션에서는 [`ModelHubMixin`]이 어떻게 작동하는지 간단히 살펴보았습니다. 이번 섹션에서는 Hugging Face Hub와 라이브러리 통합을 개선하기 위한 더 고급 기능 중 일부를 살펴보겠습니다. #### 모델 카드[[model-card]] [`ModelHubMixin`]은 모델 카드를 자동으로 생성합니다. 모델 카드는 모델과 함께 제공되는 중요한 정보를 제공하는 파일입니다. 모델 카드는 추가 메타데이터가 포함된 간단한 Markdown 파일입니다. 모델 카드는 발견 가능성, 재현성 및 공유를 위해 중요합니다! 더 자세한 내용은 [모델 카드 가이드](https://huggingface.co/docs/hub/model-cards)를 확인하세요. 모델 카드를 반자동으로 생성하는 것은 라이브러리로 푸시된 모든 모델이 `library_name`, `tags`, `license`, `pipeline_tag` 등과 같은 공통 메타데이터를 공유하도록 하는 좋은 방법입니다. 이를 통해 모든 모델이 Hub에서 쉽게 검색 가능하게 되고, Hub에 접속한 사용자에게 일부 리소스 링크를 제공합니다. [`ModelHubMixin`]을 상속할 때 메타데이터를 직접 정의할 수 있습니다: ```py class UniDepthV1( nn.Module, PyTorchModelHubMixin, library_name="unidepth", repo_url="https://github.com/lpiccinelli-eth/UniDepth", docs_url=..., pipeline_tag="depth-estimation", license="cc-by-nc-4.0", tags=["monocular-metric-depth-estimation", "arxiv:1234.56789"] ): ... ``` 기본적으로는 제공된 정보로 일반적인 모델 카드가 생성됩니다(예: [pyp1/VoiceCraft_giga830M](https://huggingface.co/pyp1/VoiceCraft_giga830M)). 그러나 사용자 정의 모델 카드 템플릿을 정의할 수도 있습니다! 이 예에서는 `VoiceCraft` 클래스로 푸시된 모든 모델에 자동으로 인용 부분과 라이선스 세부 정보가 포함됩니다. 모델 카드 템플릿을 정의하는 방법에 대한 자세한 내용은 [모델 카드 가이드](./model-cards)를 참조하세요. ```py MODEL_CARD_TEMPLATE = """ --- # For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 # Doc / guide: https://huggingface.co/docs/hub/model-cards {{ card_data }} --- This is a VoiceCraft model. For more details, please check out the official Github repo: https://github.com/jasonppy/VoiceCraft. This model is shared under a Attribution-NonCommercial-ShareAlike 4.0 International license. ## Citation @article{peng2024voicecraft, author = {Peng, Puyuan and Huang, Po-Yao and Li, Daniel and Mohamed, Abdelrahman and Harwath, David}, title = {VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild}, journal = {arXiv}, year = {2024}, } """ class VoiceCraft( nn.Module, PyTorchModelHubMixin, library_name="voicecraft", model_card_template=MODEL_CARD_TEMPLATE, ... ): ... ``` 마지막으로, 모델 카드 생성 프로세스를 동적 값으로 확장하려면 [`~ModelHubMixin.generate_model_card`] 메소드를 재정의할 수 있습니다: ```py from huggingface_hub import ModelCard, PyTorchModelHubMixin class UniDepthV1(nn.Module, PyTorchModelHubMixin, ...): (...) def generate_model_card(self, *args, **kwargs) -> ModelCard: card = super().generate_model_card(*args, **kwargs) card.data.metrics = ... # 메타데이터에 메트릭 추가 card.text += ... # 모델 카드에 섹션 추가 return card ``` #### 구성[[config]] [`ModelHubMixin`]은 모델 구성을 처리합니다. 모델을 인스턴스화할 때 입력 값들을 자동으로 확인하고 이를 `config.json` 파일에 직렬화합니다. 
이렇게 함으로써 두 가지 이점이 제공됩니다: 1. 사용자는 정확히 동일한 매개변수로 모델을 다시 가져올 수 있습니다. 2. `config.json` 파일이 자동으로 생성되면 Hub에서 분석이 가능해집니다(즉, "다운로드" 횟수가 기록됩니다). 하지만 이것이 실제로 어떻게 작동하는 걸까요? 사용자 관점에서 프로세스가 가능한 매끄럽도록 하기 위해 여러 규칙이 존재합니다: - 만약 `__init__` 메소드가 `config` 입력을 기대한다면, 이는 자동으로 `config.json`으로 저장됩니다. - 만약 `config` 입력 매개변수에 데이터 클래스 유형(예: `config: Optional[MyConfigClass] = None`)의 어노테이션이 있다면, config 값은 올바르게 역직렬화됩니다. - 초기화할 때 전달된 모든 값들도 구성 파일에 저장됩니다. 이는 `config` 입력을 기대하지 않더라도 이점을 얻을 수 있다는 것을 의미합니다. 예시: ```py class MyModel(ModelHubMixin): def __init__(value: str, size: int = 3): self.value = value self.size = size (...) # _save_pretrained / _from_pretrained 구현 model = MyModel(value="my_value") model.save_pretrained(...) # config.json 파일에는 전달된 값과 기본 값이 모두 포함됩니다. {"value": "my_value", "size": 3} ``` 그러나 값이 JSON으로 직렬화될 수 없는 경우, 기본적으로 구성 파일을 저장할 때 해당 값은 무시됩니다. 그러나 경우에 따라 라이브러리가 이미 직렬화할 수 없는 사용자 정의 객체를 예상하고 있고 해당 유형을 업데이트하고 싶지 않은 경우가 있습니다. 그렇다면 [`ModelHubMixin`]을 상속할 때 어떤 유형에 대한 사용자 지정 인코더/디코더를 전달할 수 있습니다. 이는 조금 더 많은 작업이 필요하지만 내부 로직을 변경하지 않고도 라이브러리를 Hub에 통합할 수 있도록 보장합니다. 여기서 `argparse.Namespace` 구성을 입력으로 받는 클래스의 구체적인 예가 있습니다: ```py class VoiceCraft(nn.Module): def __init__(self, args): self.pattern = self.args.pattern self.hidden_size = self.args.hidden_size ... ``` 한 가지 해결책은 `__init__` 시그니처를 `def __init__(self, pattern: str, hidden_size: int)`로 업데이트하고 클래스를 인스턴스화하는 모든 스니펫을 업데이트하는 것입니다. 이 방법은 유효한 방법이지만, 라이브러리를 사용하는 하위 응용 프로그램을 망가뜨릴 수 있습니다. 다른 해결책은 `argparse.Namespace`를 사전으로 변환하는 간단한 인코더/디코더를 제공하는 것입니다. ```py from argparse import Namespace class VoiceCraft( nn.Module, PyTorchModelHubMixin, # 믹스인을 상속합니다. coders={ Namespace: ( lambda x: vars(x), # Encoder: `Namespace`를 유효한 JSON 형태로 변환하는 방법은 무엇인가요? lambda data: Namespace(**data), # Decoder: 딕셔너리에서 Namespace를 재구성하는 방법은 무엇인가요? ) } ): def __init__(self, args: Namespace): # `args`에 주석을 답니다. self.pattern = self.args.pattern self.hidden_size = self.args.hidden_size ... ``` 위의 코드 스니펫에서는 클래스의 내부 로직과 `__init__` 시그니처가 변경되지 않았습니다. 이는 기존의 모든 코드 스니펫이 여전히 작동한다는 것을 의미합니다. 이를 달성하기 위해 다음 과정을 수행하면 됩니다: 1. 믹스인(`PytorchModelHubMixin`)으로부터 상속합니다. 2. 상속 시 `coders` 매개변수를 전달합니다. 이는 키가 처리하려는 사용자 지정 유형이고, 값은 튜플 `(인코더, 디코더)`입니다. - 인코더는 지정된 유형의 객체를 입력으로 받아서 jsonable 값으로 반환합니다. 이는 `save_pretrained`로 모델을 저장할 때 사용됩니다. - 디코더는 원시 데이터(일반적으로 딕셔너리 타입)를 입력으로 받아서 초기 객체를 재구성합니다. 이는 `from_pretrained`로 모델을 로드할 때 사용됩니다. - `__init__` 시그니처에 유형 주석을 추가합니다. 이는 믹스인에게 클래스가 기대하는 유형과, 따라서 어떤 디코더를 사용해야 하는지를 알려주는 데 중요합니다. 위의 예제는 간단한 예시이기 때문에 인코더/디코더 함수는 견고하지 않습니다. 구체적인 구현을 위해서는 코너 케이스를 적절하게 처리해야 할 것입니다. ## 빠른 비교[[quick-comparison]] 두 가지 접근 방법에 대한 장단점을 간단히 정리해보겠습니다. 아래 표는 단순히 예시일 뿐입니다. 각자 다른 프레임워크에는 고려해야 할 특정 사항이 있을 수 있습니다. 이 가이드는 통합을 다루는 아이디어와 지침을 제공하기 위한 것입니다. 언제든지 궁금한 점이 있으면 문의해 주세요! | 통합 | helpers 사용 시 | [`ModelHubMixin`] 사용 시 | |:---:|:---:|:---:| | 사용자 경험 | `model = load_from_hub(...)`
`push_to_hub(model, ...)` | `model = MyModel.from_pretrained(...)`
`model.push_to_hub(...)` | | 유연성 | 매우 유연합니다.
구현을 완전히 제어합니다. | 유연성이 떨어집니다.
프레임워크에는 모델 클래스가 있어야 합니다. | | 유지 관리 | 구성 및 새로운 기능에 대한 지원을 추가하기 위한 유지 관리가 더 필요합니다. 사용자가 보고한 문제를 해결해야할 수도 있습니다. | Hub와의 대부분의 상호 작용이 `huggingface_hub`에서 구현되므로 유지 관리가 줄어듭니다. | | 문서화 / 타입 주석 | 수동으로 작성해야 합니다. | `huggingface_hub`에서 부분적으로 처리됩니다. | | 다운로드 횟수 표시기 | 수동으로 처리해야 합니다. | 클래스에 `config` 속성이 있다면 기본적으로 활성화됩니다. | | 모델 카드 | 수동으로 처리해야 합니다. | library_name, tags 등을 활용하여 기본적으로 생성됩니다. | huggingface_hub-0.31.1/docs/source/ko/guides/manage-cache.md000066400000000000000000000713031500667546600236600ustar00rootroot00000000000000 # `huggingface_hub` 캐시 시스템 관리하기[[manage-huggingfacehub-cache-system]] ## 캐싱 이해하기[[understand-caching]] Hugging Face Hub 캐시 시스템은 Hub에 의존하는 라이브러리 간에 공유되는 중앙 캐시로 설계되었습니다. v0.8.0에서 수정한 파일 간에 다시 다운로드하는 것을 방지하도록 업데이트되었습니다. 캐시 시스템은 다음과 같이 설계되었습니다: ``` ├─ ├─ ├─ ``` ``는 보통 사용자의 홈 디렉토리입니다. 그러나 모든 메소드에서 `cache_dir` 인수를 사용하거나 `HF_HOME` 또는 `HF_HUB_CACHE` 환경 변수를 지정하여 사용자 정의할 수 있습니다. 모델, 데이터셋, 스페이스는 공통된 루트를 공유합니다. 각 리포지토리는 리포지토리 유형과 네임스페이스(조직 또는 사용자 이름이 있을 경우), 리포지토리 이름을 포함합니다: ``` ├─ models--julien-c--EsperBERTo-small ├─ models--lysandrejik--arxiv-nlp ├─ models--bert-base-cased ├─ datasets--glue ├─ datasets--huggingface--DataMeasurementsFiles ├─ spaces--dalle-mini--dalle-mini ``` Hub로부터 모든 파일이 이 폴더들 안에 다운로드됩니다. 캐싱은 파일이 이미 존재하고 업데이트되지 않은 경우, 파일을 두 번 다운로드하지 않도록 해줍니다. 하지만 파일이 업데이트되었고 최신 파일을 요청하면, 최신 파일을 다운로드합니다 (이전 파일은 그대로 유지되어 필요할 때 다시 사용할 수 있습니다). 이를 위해 모든 폴더는 동일한 구조를 가집니다: ``` ├─ datasets--glue │ ├─ refs │ ├─ blobs │ ├─ snapshots ... ``` 각 폴더는 다음과 같은 내용을 포함하도록 구성되었습니다: ### Refs[[refs]] `refs` 폴더에는 주어진 참조의 최신 수정 버전을 나타내는 파일이 포함되어 있습니다. 예를 들어, 이전에 리포지토리의 `main` 브랜치에서 파일을 가져온 경우, `refs` 폴더에는 `main`이라는 이름의 파일이 포함되며, 이 파일 자체에는 현재 헤드의 커밋 식별자가 들어 있습니다. 만약 `main`의 최신 커밋 식별자가 `aaaaaa`라면, 그 파일에는 `aaaaaa`가 들어 있습니다. 같은 브랜치가 새로운 커밋으로 업데이트되어 `bbbbbb`라는 식별자를 갖게 되면, 해당 참조에서 파일을 다시 다운로드할 때 `refs/main` 파일은 `bbbbbb`로 업데이트됩니다. ### Blobs[[blobs]] `blobs` 폴더에는 실제로 다운로드된 파일이 포함되어 있습니다. 각 파일의 이름은 해당 파일의 해시값입니다. ### Snapshots[[snapshots]] `snapshots` 폴더에는 위에서 언급한 blobs에 대한 심볼릭 링크가 포함되어 있습니다. 이 폴더는 여러 개의 하위 폴더로 구성되어 있으며, 각 폴더는 알려진 수정 버전을 나타냅니다. 위 설명에서, 처음에 `aaaaaa` 버전에서 파일을 가져왔고, 그 후에 `bbbbbb` 버전에서 파일을 가져왔습니다. 이 상황에서 `snapshots` 폴더에는 `aaaaaa`와 `bbbbbb`라는 두 개의 폴더가 있습니다. 이 폴더들 각각에는 다운로드한 파일의 이름을 가진 심볼릭 링크가 있습니다. 예를 들어, `aaaaaa` 버전에서 `README.md` 파일을 다운로드했다면, 다음과 같은 경로가 생깁니다: ``` //snapshots/aaaaaa/README.md ``` 그 `README.md` 파일은 실제로 해당 파일의 해시를 가진 blob에 대한 심볼릭 링크입니다. 이와 같은 구조를 생성함으로써 파일 공유 메커니즘이 열리게 됩니다. 동일한 파일을 `bbbbbb` 버전에서 가져온 경우, 동일한 해시를 가지게 되어 파일을 다시 다운로드할 필요가 없습니다. ### .no_exist (advanced)[[noexist-advanced]] `blobs`, `refs`, `snapshots` 폴더 외에도 캐시에서 `.no_exist` 폴더를 찾을 수 있습니다. 이 폴더는 한 번 다운로드하려고 시도했지만 Hub에 존재하지 않는 파일을 기록합니다. 이 폴더의 구조는 `snapshots` 폴더와 동일하며, 알려진 각 수정 버전에 대해 하나의 하위 폴더를 갖습니다: ``` //.no_exist/aaaaaa/config_that_does_not_exist.json ``` `snapshots` 폴더와 달리, 파일은 단순히 빈 파일입니다 (심볼릭 링크가 아님). 이 예에서 `"config_that_does_not_exist.json"` 파일은 `"aaaaaa"` 버전에 대해 Hub에 존재하지 않습니다. 빈 파일만 저장하므로, 이 폴더는 디스크 사용량을 크게 차지하지 않기에 무시할 수 있습니다. 그렇다면 이제 여러분은 왜 이 정보가 관련이 있는지 궁금해 할지도 모릅니다. 몇몇 경우에서는 프레임워크가 모델에 대한 옵션 파일들을 불러오려고 시도합니다. 존재하지 않는 옵션 파일들을 저장하면 가능한 옵션 파일당 1개의 HTTP 호출을 절약할 수 있어 모델을 더 빠르게 불러올 수 있습니다. 이는 예를 들어 각 토크나이저가 추가 파일을 지원하는 `transformers`에서 발생합니다. 처음으로 토크나이저를 로드할 때, 다음 초기화를 위해 로딩 시간을 더 빠르게 하기 위해 옵션 파일이 존재하는지 여부를 캐시합니다. HTTP 요청을 만들지 않고 로컬로 캐시된 파일이 있는지 테스트하려면, [`try_to_load_from_cache`] 헬퍼를 사용할 수 있습니다. 이것은 파일이 존재하고 캐시된 경우에는 파일 경로를, 존재하지 않음이 캐시된 경우에는 `_CACHED_NO_EXIST` 객체를, 알 수 없는 경우에는 `None`을 반환합니다. 
```python from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST filepath = try_to_load_from_cache() if isinstance(filepath, str): # 파일이 존재하고 캐시됩니다 ... elif filepath is _CACHED_NO_EXIST: # 파일의 존재여부가 캐시됩니다 ... else: # 파일은 캐시되지 않습니다 ... ``` ### 캐시 구조 예시[[in-practice]] 실제로는 캐시는 다음과 같은 트리 구조를 가질 것입니다: ```text [ 96] . └── [ 160] models--julien-c--EsperBERTo-small ├── [ 160] blobs │ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd │ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e │ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812 ├── [ 96] refs │ └── [ 40] main └── [ 128] snapshots ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd └── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48 ├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd ``` ### 제한사항[[limitations]] 효율적인 캐시 시스템을 갖기 위해 `huggingface-hub`은 심볼릭 링크를 사용합니다. 그러나 모든 기기에서 심볼릭 링크를 지원하지는 않습니다. 특히 Windows에서 이러한 한계가 있다는 것이 알려져 있습니다. 이런 경우에는 `huggingface_hub`이 `blobs/` 디렉터리를 사용하지 않고 대신 파일을 직접 `snapshots/` 디렉터리에 저장합니다. 이 해결책을 통해 사용자는 Hub에서 파일을 다운로드하고 캐시하는 방식을 정확히 동일하게 사용할 수 있습니다. 캐시를 검사하고 삭제하는 도구들도 지원됩니다. 그러나 캐시 시스템은 동일한 리포지토리의 여러 수정 버전을 다운로드하는 경우 같은 파일이 여러 번 다운로드될 수 있기 때문에 효율적이지 않을 수 있습니다. Windows 기기에서 심볼릭 링크 기반 캐시 시스템의 이점을 누리려면, [개발자 모드를 활성화](https://docs.microsoft.com/ko-kr/windows/apps/get-started/enable-your-device-for-development)하거나 Python을 관리자 권한으로 실행해야 합니다. 심볼릭 링크가 지원되지 않는 경우, 사용자에게 캐시 시스템의 낮은 버전을 사용 중임을 알리는 경고 메시지가 표시됩니다. 이 경고는 `HF_HUB_DISABLE_SYMLINKS_WARNING` 환경 변수를 true로 설정하여 비활성화할 수 있습니다. ## 캐싱 자산[[caching-assets]] Hub에서 파일을 캐시하는 것 외에도, 하위 라이브러리들은 종종 `huggingface_hub`에 직접 처리되지 않는 HF와 관련된 다른 파일을 캐시해야 할 때가 있습니다 (예: GitHub에서 다운로드한 파일, 전처리된 데이터, 로그 등). 이러한 파일, 즉 '자산(assets)'을 캐시하기 위해 [`cached_assets_path`]를 사용할 수 있습니다. 이 헬퍼는 요청한 라이브러리의 이름과 선택적으로 네임스페이스 및 하위 폴더 이름을 기반으로 HF 캐시의 경로를 통일된 방식으로 생성합니다. 목표는 모든 하위 라이브러리가 자산을 자체 방식대로(예: 구조에 대한 규칙 없음) 관리할 수 있도록 하는 것입니다. 그러나 올바른 자산 폴더 내에 있어야 합니다. 그러한 라이브러리는 `huggingface_hub`의 도구를 활용하여 캐시를 관리할 수 있으며, 특히 CLI 명령을 통해 자산의 일부를 스캔하고 삭제할 수 있습니다. ```py from huggingface_hub import cached_assets_path assets_path = cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="download") something_path = assets_path / "something.json" # 자산 폴더에서 원하는 대로 작업하세요! ``` [`cached_assets_path`]는 자산을 저장하는 권장 방법이지만 필수는 아닙니다. 이미 라이브러리가 자체 캐시를 사용하는 경우 해당 캐시를 자유롭게 사용하세요! ### 자산 캐시 구조 예시[[assets-in-practice]] 실제로는 자산 캐시는 다음과 같은 트리 구조를 가질 것입니다: ```text assets/ └── datasets/ │ ├── SQuAD/ │ │ ├── downloaded/ │ │ ├── extracted/ │ │ └── processed/ │ ├── Helsinki-NLP--tatoeba_mt/ │ ├── downloaded/ │ ├── extracted/ │ └── processed/ └── transformers/ ├── default/ │ ├── something/ ├── bert-base-cased/ │ ├── default/ │ └── training/ hub/ └── models--julien-c--EsperBERTo-small/ ├── blobs/ │ ├── (...) │ ├── (...) ├── refs/ │ └── (...) └── [ 128] snapshots/ ├── 2439f60ef33a0d46d85da5001d52aeda5b00ce9f/ │ ├── (...) └── bbc77c8132af1cc5cf678da3f1ddf2de43606d48/ └── (...) ``` ## 캐시 스캔하기[[scan-your-cache]] 현재 캐시된 파일은 로컬 디렉토리에서 삭제되지 않습니다. 브랜치의 새로운 수정 버전을 다운로드할 때 이전 파일은 다시 필요할 경우를 대비하여 보관됩니다. 따라서 디스크 공간을 많이 차지하는 리포지토리와 수정 버전을 파악하기 위해 캐시 디렉토리를 스캔하는 것이 유용할 수 있습니다. `huggingface_hub`은 이를 수행할 수 있는 헬퍼를 제공하며, `huggingface-cli`를 통해 또는 Python 스크립트에서 사용할 수 있습니다. 
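스캔을 시작하기 전에 어떤 디렉터리가 스캔 대상인지 확인하고 싶다면, 아래와 같이 기본 캐시 경로를 출력해 볼 수 있습니다. `HF_HOME`/`HF_HUB_CACHE` 환경 변수를 재정의하지 않았다고 가정한 간단한 예시이며, 실제 경로는 환경에 따라 다릅니다.

```py
>>> from huggingface_hub import constants
>>> constants.HF_HUB_CACHE
'/home/wauplin/.cache/huggingface/hub'
```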
### 터미널에서 캐시 스캔하기[[scan-cache-from-the-terminal]] HF 캐시 시스템을 스캔하는 가장 쉬운 방법은 `huggingface-cli` 도구의 `scan-cache` 명령을 사용하는 것입니다. 이 명령은 캐시를 스캔하고 리포지토리 ID, 리포지토리 유형, 디스크 사용량, 참조 및 전체 로컬 경로와 같은 정보가 포함된 보고서를 출력합니다. 아래 코드 조각은 4개의 모델과 2개의 데이터셋이 캐시된 폴더에서의 스캔 보고서를 보여줍니다. ```text ➜ huggingface-cli scan-cache REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH --------------------------- --------- ------------ -------- ------------- ------------- ------------------- ------------------------------------------------------------------------- glue dataset 116.3K 15 4 days ago 4 days ago 2.4.0, main, 1.17.0 /home/wauplin/.cache/huggingface/hub/datasets--glue google/fleurs dataset 64.9M 6 1 week ago 1 week ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs Jean-Baptiste/camembert-ner model 441.0M 7 2 weeks ago 16 hours ago main /home/wauplin/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner bert-base-cased model 1.9G 13 1 week ago 2 years ago /home/wauplin/.cache/huggingface/hub/models--bert-base-cased t5-base model 10.1K 3 3 months ago 3 months ago main /home/wauplin/.cache/huggingface/hub/models--t5-base t5-small model 970.7M 11 3 days ago 3 days ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/models--t5-small Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G. Got 1 warning(s) while scanning. Use -vvv to print details. ``` 더 자세한 보고서를 얻으려면 `--verbose` 옵션을 사용하세요. 각 리포지토리에 대해 다운로드된 모든 수정 버전의 목록을 얻게 됩니다. 위에서 설명한대로, 2개의 수정 버전 사이에 변경되지 않는 파일들은 심볼릭 링크를 통해 공유됩니다. 이는 디스크 상의 리포지토리 크기가 각 수정 버전의 크기의 합보다 작을 것으로 예상됨을 의미합니다. 예를 들어, 여기서 `bert-base-cased`는 1.4G와 1.5G의 두 가지 수정 버전이 있지만 총 디스크 사용량은 단 1.9G입니다. ```text ➜ huggingface-cli scan-cache -v REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH --------------------------- --------- ---------------------------------------- ------------ -------- ------------- ----------- ---------------------------------------------------------------------------------------------------------------------------- glue dataset 9338f7b671827df886678df2bdd7cc7b4f36dffd 97.7K 14 4 days ago main, 2.4.0 /home/wauplin/.cache/huggingface/hub/datasets--glue/snapshots/9338f7b671827df886678df2bdd7cc7b4f36dffd glue dataset f021ae41c879fcabcf823648ec685e3fead91fe7 97.8K 14 1 week ago 1.17.0 /home/wauplin/.cache/huggingface/hub/datasets--glue/snapshots/f021ae41c879fcabcf823648ec685e3fead91fe7 google/fleurs dataset 129b6e96cf1967cd5d2b9b6aec75ce6cce7c89e8 25.4K 3 2 weeks ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs/snapshots/129b6e96cf1967cd5d2b9b6aec75ce6cce7c89e8 google/fleurs dataset 24f85a01eb955224ca3946e70050869c56446805 64.9M 4 1 week ago main /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs/snapshots/24f85a01eb955224ca3946e70050869c56446805 Jean-Baptiste/camembert-ner model dbec8489a1c44ecad9da8a9185115bccabd799fe 441.0M 7 16 hours ago main /home/wauplin/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner/snapshots/dbec8489a1c44ecad9da8a9185115bccabd799fe bert-base-cased model 378aa1bda6387fd00e824948ebe3488630ad8565 1.5G 9 2 years ago /home/wauplin/.cache/huggingface/hub/models--bert-base-cased/snapshots/378aa1bda6387fd00e824948ebe3488630ad8565 bert-base-cased model a8d257ba9925ef39f3036bfc338acf5283c512d9 1.4G 9 3 days ago main /home/wauplin/.cache/huggingface/hub/models--bert-base-cased/snapshots/a8d257ba9925ef39f3036bfc338acf5283c512d9 t5-base model 23aa4f41cb7c08d4b05c8f327b22bfa0eb8c7ad9 10.1K 3 1 week ago main 
/home/wauplin/.cache/huggingface/hub/models--t5-base/snapshots/23aa4f41cb7c08d4b05c8f327b22bfa0eb8c7ad9 t5-small model 98ffebbb27340ec1b1abd7c45da12c253ee1882a 726.2M 6 1 week ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/98ffebbb27340ec1b1abd7c45da12c253ee1882a t5-small model d0a119eedb3718e34c648e594394474cf95e0617 485.8M 6 4 weeks ago /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d0a119eedb3718e34c648e594394474cf95e0617 t5-small model d78aea13fa7ecd06c29e3e46195d6341255065d5 970.7M 9 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d78aea13fa7ecd06c29e3e46195d6341255065d5 Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G. Got 1 warning(s) while scanning. Use -vvv to print details. ``` #### Grep 예시[[grep-example]] 출력이 테이블 형식으로 되어 있기 때문에 `grep`과 유사한 도구를 사용하여 항목을 필터링할 수 있습니다. 여기에는 Unix 기반 머신에서 "t5-small" 모델의 수정 버전만 필터링하는 예제가 있습니다. ```text ➜ eval "huggingface-cli scan-cache -v" | grep "t5-small" t5-small model 98ffebbb27340ec1b1abd7c45da12c253ee1882a 726.2M 6 1 week ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/98ffebbb27340ec1b1abd7c45da12c253ee1882a t5-small model d0a119eedb3718e34c648e594394474cf95e0617 485.8M 6 4 weeks ago /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d0a119eedb3718e34c648e594394474cf95e0617 t5-small model d78aea13fa7ecd06c29e3e46195d6341255065d5 970.7M 9 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d78aea13fa7ecd06c29e3e46195d6341255065d5 ``` ### 파이썬에서 캐시 스캔하기[[scan-cache-from-python]] 보다 고급 기능을 사용하려면, CLI 도구에서 호출되는 파이썬 유틸리티인 [`scan_cache_dir`]을 사용할 수 있습니다. 이를 사용하여 4가지 데이터 클래스를 중심으로 구조화된 자세한 보고서를 얻을 수 있습니다: - [`HFCacheInfo`]: [`scan_cache_dir`]에 의해 반환되는 완전한 보고서 - [`CachedRepoInfo`]: 캐시된 리포지토리에 관한 정보 - [`CachedRevisionInfo`]: 리포지토리 내의 캐시된 수정 버전(예: "snapshot")에 관한 정보 - [`CachedFileInfo`]: 스냅샷 내의 캐시된 파일에 관한 정보 다음은 간단한 사용 예시입니다. 자세한 내용은 참조 문서를 참고하세요. ```py >>> from huggingface_hub import scan_cache_dir >>> hf_cache_info = scan_cache_dir() HFCacheInfo( size_on_disk=3398085269, repos=frozenset({ CachedRepoInfo( repo_id='t5-small', repo_type='model', repo_path=PosixPath(...), size_on_disk=970726914, nb_files=11, last_accessed=1662971707.3567169, last_modified=1662971107.3567169, revisions=frozenset({ CachedRevisionInfo( commit_hash='d78aea13fa7ecd06c29e3e46195d6341255065d5', size_on_disk=970726339, snapshot_path=PosixPath(...), # 수정 버전 간에 blobs가 공유되기 때문에 `last_accessed`가 없습니다. last_modified=1662971107.3567169, files=frozenset({ CachedFileInfo( file_name='config.json', size_on_disk=1197 file_path=PosixPath(...), blob_path=PosixPath(...), blob_last_accessed=1662971707.3567169, blob_last_modified=1662971107.3567169, ), CachedFileInfo(...), ... }), ), CachedRevisionInfo(...), ... }), ), CachedRepoInfo(...), ... }), warnings=[ CorruptedCacheException("Snapshots dir doesn't exist in cached repo: ..."), CorruptedCacheException(...), ... ], ) ``` ## 캐시 정리하기[[clean-your-cache]] 캐시를 스캔하는 것은 흥미로울 수 있지만 실제로 해야 할 다음 작업은 일반적으로 드라이브의 일부 공간을 확보하기 위해 일부를 삭제하는 것입니다. 이는 `delete-cache` CLI 명령을 사용하여 가능합니다. 또한 캐시를 스캔할 때 반환되는 [`HFCacheInfo`] 객체에서 [`~HFCacheInfo.delete_revisions`] 헬퍼를 사용하여 프로그래밍 방식으로도 사용할 수 있습니다. ### 전략적으로 삭제하기[[delete-strategy]] 캐시를 삭제하려면 삭제할 수정 버전 목록을 전달해야 합니다. 이 도구는 이 목록을 기반으로 공간을 확보하기 위한 전략을 정의합니다. 이는 어떤 파일과 폴더가 삭제될지를 설명하는 [`DeleteCacheStrategy`] 객체를 반환합니다. [`DeleteCacheStrategy`]를 통해 사용 가능한 공간을 확보 할 수 있습니다. 삭제에 동의하면 삭제를 실행하여 삭제를 유효하게 만들어야 합니다. 불일치를 피하기 위해 전략 객체를 수동으로 편집할 수 없습니다. 
수정 버전을 삭제하기 위한 전략은 다음과 같습니다: - 수정 버전 심볼릭 링크가 있는 `snapshot` 폴더가 삭제됩니다. - 삭제할 수정 버전에만 대상이 되는 blobs 파일도 삭제됩니다. - 수정 버전이 1개 이상의 `refs`에 연결되어 있는 경우, 참조가 삭제됩니다. - 리포지토리의 모든 수정 버전이 삭제되는 경우 전체 캐시된 리포지토리가 삭제됩니다. 수정 버전 해시는 모든 리포지토리를 통틀어 고유합니다. 이는 수정 버전을 제거할 때 `repo_id`나 `repo_type`을 제공할 필요가 없음을 의미합니다. 캐시에서 수정 버전을 찾을 수 없는 경우 무시됩니다. 또한 삭제 중에 파일 또는 폴더를 찾을 수 없는 경우 경고가 기록되지만 오류가 발생하지 않습니다. [`DeleteCacheStrategy`] 객체에 포함된 다른 경로에 대해 삭제가 계속됩니다. ### 터미널에서 캐시 정리하기[[clean-cache-from-the-terminal]] HF 캐시 시스템에서 일부 수정 버전을 삭제하는 가장 쉬운 방법은 `huggingface-cli` 도구의 `delete-cache` 명령을 사용하는 것입니다. 이 명령에는 두 가지 모드가 있습니다. 기본적으로 사용자에게 삭제할 수정 버전을 선택하도록 TUI(터미널 사용자 인터페이스)가 표시됩니다. 이 TUI는 현재 베타 버전으로, 모든 플랫폼에서 테스트되지 않았습니다. 만약 TUI가 작동하지 않는다면 `--disable-tui` 플래그를 사용하여 비활성화할 수 있습니다. #### TUI 사용하기[[using-the-tui]] 이것은 기본 모드입니다. 이를 사용하려면 먼저 다음 명령을 실행하여 추가 종속성을 설치해야 합니다: ``` pip install huggingface_hub["cli"] ``` 그러고 명령어를 실행합니다: ``` huggingface-cli delete-cache ``` 이제 선택/해제할 수 있는 수정 버전 목록이 표시됩니다:
사용방법: - `up` 및 `down` 키를 사용하여 커서를 이동합니다. - `space` 키를 눌러 항목을 토글(선택/해제)합니다. - 수정 버전이 선택된 경우 첫 번째 줄이 업데이트되어 얼마나 많은 공간이 해제될지 표시됩니다. - 선택을 확인하려면 `enter` 키를 누릅니다. - 작업을 취소하고 종료하려면 첫 번째 항목("None of the following")을 선택합니다. 이 항목이 선택된 경우, 다른 항목이 선택되었는지 여부에 관계없이 삭제 프로세스가 취소됩니다. 그렇지 않으면 `ctrl+c` 를 눌러 TUI를 종료할 수도 있습니다. 삭제할 수정 버전을 선택하고 `enter` 를 누르면 마지막 확인 메시지가 표시됩니다. 다시 `enter` 를 누르면 삭제됩니다. 취소하려면 `n` 을 입력하세요. ```txt ✗ huggingface-cli delete-cache --dir ~/.cache/huggingface/hub ? Select revisions to delete: 2 revision(s) selected. ? 2 revisions selected counting for 3.1G. Confirm deletion ? Yes Start deletion. Done. Deleted 1 repo(s) and 0 revision(s) for a total of 3.1G. ``` #### TUI 없이 작업하기[[without-tui]] 위에서 언급한대로, TUI 모드는 현재 베타 버전이며 선택 사항입니다. 사용 중인 기기에서 작동하지 않을 수도 있거나 편리하지 않을 수 있습니다. 다른 방법은 `--disable-tui` 플래그를 사용하는 것입니다. 이 프로세스는 TUI 모드와 매우 유사하게 삭제할 수정 버전 목록을 수동으로 검토하라는 요청이 표시됩니다. 그러나 이 수동 단계는 터미널에서 직접 발생하는 것이 아니라 임시 파일에 자동으로 생성되며, 이를 수동으로 편집할 수 있습니다. 이 파일에는 헤더에 필요한 모든 사용방법이 포함되어 있습니다. 텍스트 편집기에서 이 파일을 열어 `#`으로 주석 처리/해제하면 수정 버전을 쉽게 선택/해제 할 수 있습니다. 검토를 완료하고 파일 편집이 완료되었다면 터미널로 돌아가 ``를 눌러 파일을 저장하세요. 기본적으로 업데이트된 수정 버전 목록으로 확보될 공간의 양을 계산합니다. 파일을 계속 편집할 수도 있고, `"y"`를 눌러 변경 사항을 확정할 수 있습니다. ```sh huggingface-cli delete-cache --disable-tui ``` Example of command file: ```txt # INSTRUCTIONS # ------------ # This is a temporary file created by running `huggingface-cli delete-cache` with the # `--disable-tui` option. It contains a set of revisions that can be deleted from your # local cache directory. # # Please manually review the revisions you want to delete: # - Revision hashes can be commented out with '#'. # - Only non-commented revisions in this file will be deleted. # - Revision hashes that are removed from this file are ignored as well. # - If `CANCEL_DELETION` line is uncommented, the all cache deletion is cancelled and # no changes will be applied. # # Once you've manually reviewed this file, please confirm deletion in the terminal. This # file will be automatically removed once done. # ------------ # KILL SWITCH # ------------ # Un-comment following line to completely cancel the deletion process # CANCEL_DELETION # ------------ # REVISIONS # ------------ # Dataset chrisjay/crowd-speech-africa (761.7M, used 5 days ago) ebedcd8c55c90d39fd27126d29d8484566cd27ca # Refs: main # modified 5 days ago # Dataset oscar (3.3M, used 4 days ago) # 916f956518279c5e60c63902ebdf3ddf9fa9d629 # Refs: main # modified 4 days ago # Dataset wikiann (804.1K, used 2 weeks ago) 89d089624b6323d69dcd9e5eb2def0551887a73a # Refs: main # modified 2 weeks ago # Dataset z-uo/male-LJSpeech-italian (5.5G, used 5 days ago) # 9cfa5647b32c0a30d0adfca06bf198d82192a0d1 # Refs: main # modified 5 days ago ``` ### 파이썬에서 캐시 정리하기[[clean-cache-from-python]] 더 유연하게 사용하려면, 프로그래밍 방식으로 [`~HFCacheInfo.delete_revisions`] 메소드를 사용할 수도 있습니다. 간단한 예제를 살펴보겠습니다. 자세한 내용은 참조 문서를 확인하세요. ```py >>> from huggingface_hub import scan_cache_dir >>> delete_strategy = scan_cache_dir().delete_revisions( ... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa" ... "e2983b237dccf3ab4937c97fa717319a9ca1a96d", ... "6c0e6080953db56375760c0471a8c5f2929baf11", ... ) >>> print("Will free " + delete_strategy.expected_freed_size_str) Will free 8.6G >>> delete_strategy.execute() Cache deletion done. Saved 8.6G. 
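# (가정한 예시) 삭제가 끝난 뒤 캐시를 다시 스캔하면 줄어든 전체 사용량(바이트 단위)을 확인할 수 있습니다.
>>> scan_cache_dir().size_on_disk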
``` huggingface_hub-0.31.1/docs/source/ko/guides/manage-spaces.md000066400000000000000000000345361500667546600241020ustar00rootroot00000000000000 # Space 관리하기[[manage-your-space]] 이 가이드에서는 `huggingface_hub`를 사용하여 Space 런타임([보안 정보](https://huggingface.co/docs/hub/spaces-overview#managing-secrets), [하드웨어](https://huggingface.co/docs/hub/spaces-gpus) 및 [저장소](https://huggingface.co/docs/hub/spaces-storage#persistent-storage))를 관리하는 방법을 살펴보겠습니다. ## 간단한 예제: 보안 정보 및 하드웨어 구성하기.[[a-simple-example-configure-secrets-and-hardware]] 다음은 Hub에서 Space를 생성하고 설정하는 통합 예시입니다. **1. Hub에 Space 생성하기.** ```py >>> from huggingface_hub import HfApi >>> repo_id = "Wauplin/my-cool-training-space" >>> api = HfApi() # Gradio SDK 예제 >>> api.create_repo(repo_id=repo_id, repo_type="space", space_sdk="gradio") ``` **1. (bis) Space 복제하기.** 기존의 Space에서부터 시작하는 대신 새로운 Space를 구축하고 싶을 때 유용할 수 있습니다. 또한 공개된 Space의 구성/설정을 제어하고 싶을 때도 유용합니다. 자세한 내용은 [`duplicate_space`]를 참조하세요. ```py >>> api.duplicate_space("multimodalart/dreambooth-training") ``` **2. 선호하는 솔루션을 사용하여 코드 업로드하기.** 다음은 로컬 폴더 `src/`를 사용자의 컴퓨터에서 Space로 업로드하는 예시입니다: ```py >>> api.upload_folder(repo_id=repo_id, repo_type="space", folder_path="src/") ``` 이 단계에서는 앱이 이미 무료로 Hub에서 실행 중이어야 합니다! 그러나 더 많은 보안 정보와 업그레이드된 하드웨어를 이용하여 추가적으로 구성할 수 있습니다. **3. 보안 정보와 변수 설정하기** Space에서 작동하려면 일부 보안 키, 토큰 또는 변수가 필요할 수 있습니다. 자세한 내용은 [문서](https://huggingface.co/docs/hub/spaces-overview#managing-secrets)를 참조하세요. Space에서 생성된 HF 토큰으로 이미지 데이터 세트를 Hub에 업로드하는 경우를 예로 들어봅시다. ```py >>> api.add_space_secret(repo_id=repo_id, key="HF_TOKEN", value="hf_api_***") >>> api.add_space_variable(repo_id=repo_id, key="MODEL_REPO_ID", value="user/repo") ``` 보안 정보와 변수는 삭제할 수도 있습니다: ```py >>> api.delete_space_secret(repo_id=repo_id, key="HF_TOKEN") >>> api.delete_space_variable(repo_id=repo_id, key="MODEL_REPO_ID") ``` Space 내에서 보안 정보는 환경 변수로 사용할 수 있습니다 (Streamlit를 사용하는 경우 Streamlit Secrets를 사용). API를 통해 가져올 필요가 없습니다! Space 구성(보안 정보 또는 하드웨어)이 변경되면 앱이 다시 시작됩니다. **보너스: Space 생성 또는 복제 시 보안 정보와 변수 설정하기!** Space를 생성하거나 복제할 때 보안 정보와 변수를 설정할 수 있습니다: ```py >>> api.create_repo( ... repo_id=repo_id, ... repo_type="space", ... space_sdk="gradio", ... space_secrets=[{"key"="HF_TOKEN", "value"="hf_api_***"}, ...], ... space_variables=[{"key"="MODEL_REPO_ID", "value"="user/repo"}, ...], ... ) ``` ```py >>> api.duplicate_space( ... from_id=repo_id, ... secrets=[{"key"="HF_TOKEN", "value"="hf_api_***"}, ...], ... variables=[{"key"="MODEL_REPO_ID", "value"="user/repo"}, ...], ... ) ``` **4. 하드웨어 구성** 기본적으로 Space는 무료로 CPU 환경에서 실행됩니다. GPU에서 실행하기 위해 하드웨어를 업그레이드 할 수도 있습니다. 하드웨어를 업그레이드하려면 결제 카드 또는 커뮤니티 그랜트가 필요합니다. 자세한 내용은 [문서](https://huggingface.co/docs/hub/spaces-gpus)를 참조하세요. ```py # `SpaceHardware` enum 사용 >>> from huggingface_hub import SpaceHardware >>> api.request_space_hardware(repo_id=repo_id, hardware=SpaceHardware.T4_MEDIUM) # 또는 간단히 문자열 값 전달 >>> api.request_space_hardware(repo_id=repo_id, hardware="t4-medium") ``` Space가 서버에서 다시 로드되어야 하기 때문에 하드웨어 업데이트는 즉시 이루어지지 않습니다. Space가 어떤 하드웨어에서 실행되고 있는지 언제든지 확인하여 요청이 충족되었는지 확인할 수 있습니다. ```py >>> runtime = api.get_space_runtime(repo_id=repo_id) >>> runtime.stage "RUNNING_BUILDING" >>> runtime.hardware "cpu-basic" >>> runtime.requested_hardware "t4-medium" ``` 이제 완전히 구성된 Space를 가지게 되었습니다. 사용이 끝난 후에는 Space를 "cpu-classic"으로 다운그레이드하는 것을 잊지 마세요. **보너스: Space를 생성하거나 복제할 때 하드웨어 요청하기!** Space가 구축되면 업그레이드된 하드웨어가 자동으로 할당됩니다. ```py >>> api.create_repo( ... repo_id=repo_id, ... repo_type="space", ... space_sdk="gradio" ... space_hardware="cpu-upgrade", ... space_storage="small", ... 
space_sleep_time="7200", # 2시간을 초로 환산 ... ) ``` ```py >>> api.duplicate_space( ... from_id=repo_id, ... hardware="cpu-upgrade", ... storage="small", ... sleep_time="7200", # 2시간을 초로 환산 ... ) ``` **5. Space 일시 중지 및 다시 시작** 기본적으로 Space가 업그레이드된 하드웨어에서 실행 중이면 절대로 중단되지 않습니다. 그러나 요금이 부과되는 것을 피하려면 사용하지 않을 때 일시 중지하는 것이 좋습니다. 이는 [`pause_space`]를 사용하여 가능합니다. 일시 중지된 Space는 Space 소유자가 UI를 통해 또는 [`restart_space`]를 사용하여 API를 통해 다시 시작할 때까지 비활성화됩니다. 일시 중지된 모드에 대한 자세한 내용은 [이 섹션](https://huggingface.co/docs/hub/spaces-gpus#pause)을 참조하세요. ```py # 과금을 피하기 위해 Space를 일시 중지하세요 >>> api.pause_space(repo_id=repo_id) # (...) # 필요할 때 다시 시작하세요 >>> api.restart_space(repo_id=repo_id) ``` 다른 방법은 Space에 대한 제한 시간을 설정하는 것입니다. Space가 제한 시간을 초과하여 비활성화되면 Space가 sleep 상태로 전환됩니다. Space를 방문한 방문자가 다시 시작시킬 수 있습니다. [`set_space_sleep_time`]를 사용하여 제한 시간을 설정할 수 있습니다. Sleeping 모드에 대한 자세한 내용은 [이 섹션](https://huggingface.co/docs/hub/spaces-gpus#sleep-time)을 참조하세요. ```py # 동작이 멈춘 후 1시간 후에 Space를 sleep 상태로 설정하세요 >>> api.set_space_sleep_time(repo_id=repo_id, sleep_time=3600) ``` 참고: 'cpu-basic' 하드웨어를 사용하는 경우 사용자 정의 sleep 시간을 구성할 수 없습니다. Space가 48시간 동안 동작을 멈추면 자동으로 일시 중지됩니다. **보너스: 하드웨어를 요청하는 동안 sleep 시간 설정하기** 업그레이드된 하드웨어가 Space에 자동으로 할당됩니다. ```py >>> api.request_space_hardware(repo_id=repo_id, hardware=SpaceHardware.T4_MEDIUM, sleep_time=3600) ``` **보너스: Space를 생성하거나 복제할 때 sleep 시간 설정하기!** ```py >>> api.create_repo( ... repo_id=repo_id, ... repo_type="space", ... space_sdk="gradio" ... space_hardware="t4-medium", ... space_sleep_time="3600", ... ) ``` ```py >>> api.duplicate_space( ... from_id=repo_id, ... hardware="t4-medium", ... sleep_time="3600", ... ) ``` **6. Space에 지속적으로 저장소 추가하기** Space를 다시 시작할 때 지속적으로 디스크 공간에 접근할 수 있는 원하는 저장소 계층을 선택할 수 있습니다. 이는 기존의 하드 드라이브와 같이 디스크에서 읽고 쓸 수 있음을 의미합니다. 자세한 내용은 [문서](https://huggingface.co/docs/hub/spaces-storage#persistent-storage)를 참조하세요. ```py >>> from huggingface_hub import SpaceStorage >>> api.request_space_storage(repo_id=repo_id, storage=SpaceStorage.LARGE) ``` 또한 모든 데이터를 영구적으로 삭제하여 저장소를 삭제할 수 있습니다. ```py >>> api.delete_space_storage(repo_id=repo_id) ``` 참고: 한 번 승인된 저장소의 저장소 계층을 낮출 수 없습니다. 그렇게 하려면, 먼저 저장소를 삭제한 다음 새로운 원하는 계층을 요청해야 합니다. **보너스: Space를 생성하거나 복제할 때 저장소 요청하기!** ```py >>> api.create_repo( ... repo_id=repo_id, ... repo_type="space", ... space_sdk="gradio" ... space_storage="large", ... ) ``` ```py >>> api.duplicate_space( ... from_id=repo_id, ... storage="large", ... ) ``` ## 고급 기능: Space를 일시적으로 업그레이드하기![[more-advanced-temporarily-upgrade-your-space-]] Space는 다양한 사용 사례를 허용합니다. 때로는 특정 하드웨어에서 Space를 일시적으로 실행한 다음 무언가를 수행한 후 종료하고 싶을 수 있습니다. 이 섹션에서는 Space를 활용하여 필요할 때 모델을 세밀하게 조정하는 방법에 대해 탐색할 것입니다. 이는 특정 문제를 해결하는 한 가지 방법에 불과합니다. 이를 바탕으로 사용 사례에 맞게 조정해서 사용해야 합니다. 모델을 세밀하게 조정하기 위한 Space가 있다고 가정해 봅시다. 입력으로 모델 ID와 데이터 세트 ID를 받는 Gradio 앱입니다. 작업 흐름은 다음과 같습니다: 0. (사용자에게 모델과 데이터 세트를 요청) 1. Hub에서 모델을 로드합니다. 2. Hub에서 데이터 세트를 로드합니다. 3. 데이터 세트로 모델을 미세 조정합니다. 4. 새 모델을 Hub에 업로드합니다. 단계 3.에서는 사용자 정의 하드웨어가 필요하지만 유료 GPU에서 Space를 항상 실행하고 싶지는 않을 것입니다. 이 때는 학습을 위해 하드웨어를 동적으로 요청한 다음 종료해야 합니다. 하드웨어를 요청하면 Space가 다시 시작되므로 앱은 현재 수행 중인 작업을 어떻게든 "기억"해야 합니다. 이를 수행하는 여러 가지 방법이 있습니다. 이 가이드에서는 "작업 스케줄러"로서 Dataset을 사용하는 하나의 해결책을 살펴보겠습니다. ### 앱 구조[[app-skeleton]] 다음은 구현된 앱의 모습입니다. 시작할 때 예약된 작업이 있는지 확인하고 있다면 적절한 하드웨어에서 실행합니다. 작업이 완료되면 하드웨어를 무료 요금제 CPU로 다시 설정하고 사용자에게 새 작업을 요청합니다. 이 예시는 일반적인 데모처럼 병렬 액세스를 지원하지 않습니다. 특히 학습이 진행되는 동안 인터페이스가 비활성화됩니다. 저장소를 개인으로 설정하여 단일 사용자임을 보장하는 것이 좋습니다. ```py # Space는 하드웨어를 요청하기 위해 토큰이 필요합니다: Secret으로 설정하세요! 
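# (참고) 이 예시는 HF_TOKEN 시크릿이 Space 설정에 미리 등록되어 있어 환경 변수로 노출된다고 가정합니다.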
HF_TOKEN = os.environ.get("HF_TOKEN") # Space를 가진 repo_id TRAINING_SPACE_ID = "Wauplin/dreambooth-training" from huggingface_hub import HfApi, SpaceHardware api = HfApi(token=HF_TOKEN) # Space 시작 시 예약된 작업을 확인합니다. 예약된 작업이 있는 경우 모델을 미세 조정합니다. 그렇지 않은 경우, # 새 작업을 요청할 수 있는 인터페이스를 표시합니다. task = get_task() if task is None: # Gradio 앱 시작 def gradio_fn(task): # 사용자 요청 시 작업 추가 및 하드웨어 요청 add_task(task) api.request_space_hardware(repo_id=TRAINING_SPACE_ID, hardware=SpaceHardware.T4_MEDIUM) gr.Interface(fn=gradio_fn, ...).launch() else: runtime = api.get_space_runtime(repo_id=TRAINING_SPACE_ID) # GPU를 사용 중인지 확인합니다. if runtime.hardware == SpaceHardware.T4_MEDIUM: # 그렇다면, 기본 모델을 데이터 세트로 미세 조정합니다! train_and_upload(task) # 그런 다음, 작업을 "DONE"으로 표시합니다. mark_as_done(task) # 잊지 말아야 할 것: CPU 하드웨어로 다시 설정 api.request_space_hardware(repo_id=TRAINING_SPACE_ID, hardware=SpaceHardware.CPU_BASIC) else: api.request_space_hardware(repo_id=TRAINING_SPACE_ID, hardware=SpaceHardware.T4_MEDIUM) ``` ### 작업 스케줄러[[task-scheduler]] 작업 스케줄링은 여러 가지 방법으로 수행할 수 있습니다. 여기에는 간단한 CSV 파일을 데이터 세트로 사용하여 작업 스케줄링을 하는 예시입니다. ```py # 'tasks.csv' 파일을 포함하는 데이터 세트의 Dataset ID. # 여기서는 입력(기본 모델 및 데이터 세트)과 상태(PENDING 또는 DONE)가 포함된 'tasks.csv' 기본 예제가 주어집니다. # multimodalart/sd-fine-tunable,Wauplin/concept-1,DONE # multimodalart/sd-fine-tunable,Wauplin/concept-2,PENDING TASK_DATASET_ID = "Wauplin/dreambooth-task-scheduler" def _get_csv_file(): return hf_hub_download(repo_id=TASK_DATASET_ID, filename="tasks.csv", repo_type="dataset", token=HF_TOKEN) def get_task(): with open(_get_csv_file()) as csv_file: csv_reader = csv.reader(csv_file, delimiter=',') for row in csv_reader: if row[2] == "PENDING": return row[0], row[1] # model_id, dataset_id def add_task(task): model_id, dataset_id = task with open(_get_csv_file()) as csv_file: with open(csv_file, "r") as f: tasks = f.read() api.upload_file( repo_id=repo_id, repo_type=repo_type, path_in_repo="tasks.csv", # 작업을 추가하기 위한 빠르고 더러운 방법 path_or_fileobj=(tasks + f"\n{model_id},{dataset_id},PENDING").encode() ) def mark_as_done(task): model_id, dataset_id = task with open(_get_csv_file()) as csv_file: with open(csv_file, "r") as f: tasks = f.read() api.upload_file( repo_id=repo_id, repo_type=repo_type, path_in_repo="tasks.csv", # 작업을 DONE으로 설정하는 빠르고 더러운 방법 path_or_fileobj=tasks.replace( f"{model_id},{dataset_id},PENDING", f"{model_id},{dataset_id},DONE" ).encode() ) ``` huggingface_hub-0.31.1/docs/source/ko/guides/model-cards.md000066400000000000000000000271031500667546600235600ustar00rootroot00000000000000 # 모델 카드 생성 및 공유[[create-and-share-model-cards]] `huggingface_hub` 라이브러리는 모델 카드를 생성, 공유, 업데이트할 수 있는 파이썬 인터페이스를 제공합니다. Hub의 모델 카드가 무엇인지, 그리고 실제로 어떻게 작동하는지에 대한 자세한 내용을 확인하려면 [전용 설명 페이지](https://huggingface.co/docs/hub/models-cards)를 방문하세요. [신규 (베타)! 우리의 실험적인 모델 카드 크리에이터 앱을 사용해 보세요](https://huggingface.co/spaces/huggingface/Model_Cards_Writing_Tool) ## Hub에서 모델 카드 불러오기[[load-a-model-card-from-the-hub]] Hub에서 기존 카드를 불러오려면 [`ModelCard.load`] 기능을 사용하면 됩니다. 이 문서에서는 [`nateraw/vit-base-beans`](https://huggingface.co/nateraw/vit-base-beans)에서 카드를 불러오겠습니다. ```python from huggingface_hub import ModelCard card = ModelCard.load('nateraw/vit-base-beans') ``` 이 카드에는 접근하거나 활용할 수 있는 몇 가지 유용한 속성이 있습니다: - `card.data`: 모델 카드의 메타데이터와 함께 [`ModelCardData`] 인스턴스를 반환합니다. 이 인스턴스에 `.to_dict()`를 호출하여 표현을 사전으로 가져옵니다. - `card.text`: *메타데이터 헤더를 제외*한 카드의 텍스트를 반환합니다. - `card.content`: *메타데이터 헤더를 포함*한 카드의 텍스트 콘텐츠를 반환합니다. 
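다음은 이러한 속성들을 어떻게 활용할 수 있는지 보여주는 간단한 예시입니다. 아래 코드는 가정한 스케치이며, 출력 값은 해당 레포지토리의 실제 모델 카드 내용에 따라 달라질 수 있습니다.

```python
from huggingface_hub import ModelCard

card = ModelCard.load('nateraw/vit-base-beans')

# 메타데이터를 사전(dict) 형태로 확인합니다.
metadata = card.data.to_dict()
print(metadata.get("license"))  # 예시: 'apache-2.0' (실제 값은 레포지토리에 따라 다릅니다)

# 메타데이터 헤더를 제외한 본문과, 헤더를 포함한 전체 콘텐츠의 길이를 비교해 봅니다.
print(len(card.text), len(card.content))
```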
## 모델 카드 만들기[[create-model-cards]] ### 텍스트에서 생성[[from-text]] 텍스트로 모델 카드의 초기 내용을 설정하려면, 카드의 텍스트 내용을 초기화 시 `ModelCard`에 전달하면 됩니다. ```python content = """ --- language: en license: mit --- # 내 모델 카드 """ card = ModelCard(content) card.data.to_dict() == {'language': 'en', 'license': 'mit'} # True ``` 이 작업을 수행하는 또 다른 방법은 f-strings를 사용하는 것입니다. 다음 예에서 우리는: - 모델 카드에 YAML 블록을 삽입할 수 있도록 [`ModelCardData.to_yaml`]을 사용해서 우리가 정의한 메타데이터를 YAML로 변환합니다. - Python f-strings를 통해 템플릿 변수를 사용할 방법을 보여줍니다. ```python card_data = ModelCardData(language='en', license='mit', library='timm') example_template_var = 'nateraw' content = f""" --- { card_data.to_yaml() } --- # 내 모델 카드 이 모델은 [@{example_template_var}](https://github.com/ {example_template_var})에 의해 생성되었습니다 """ card = ModelCard(content) print(card) ``` 위 예시는 다음과 같은 모습의 카드를 남깁니다: ``` --- language: en license: mit library: timm --- # 내 모델 카드 This model created by [@nateraw](https://github.com/nateraw) ``` ### Jinja 템플릿으로부터[[from-a-jinja-template]] `Jinja2`가 설치되어 있으면, jinja 템플릿 파일에서 모델 카드를 만들 수 있습니다. 기본적인 예를 살펴보겠습니다: ```python from pathlib import Path from huggingface_hub import ModelCard, ModelCardData # jinja 템플릿 정의 template_text = """ --- {{ card_data }} --- # MyCoolModel 모델용 모델 카드 이 모델은 이런 저런 것들을 합니다. 이 모델은 [[@{{ author }}](https://hf.co/{{author}})에 의해 생성되었습니다. """.strip() # 템플릿을 파일에 쓰기 Path('custom_template.md').write_text(template_text) # 카드 메타데이터 정의 card_data = ModelCardData(language='en', license='mit', library_name='keras') # 템플릿에서 카드를 만들고 원하는 Jinja 템플릿 변수를 전달합니다. # 우리의 경우에는 작성자를 전달하겠습니다. card = ModelCard.from_template(card_data, template_path='custom_template.md', author='nateraw') card.save('my_model_card_1.md') print(card) ``` 결과 카드의 마크다운은 다음과 같습니다: ``` --- language: en license: mit library_name: keras --- # MyCoolModel 모델용 모델 카드 이 모델은 이런 저런 것들을 합니다. 이 모델은 [@nateraw](https://hf.co/nateraw)에 의해 생성되었습니다. ``` 카드 데이터를 업데이트하면 카드 자체에 반영됩니다. ``` card.data.library_name = 'timm' card.data.language = 'fr' card.data.license = 'apache-2.0' print(card) ``` 이제 보시다시피 메타데이터 헤더가 업데이트되었습니다: ``` --- language: fr license: apache-2.0 library_name: timm --- # MyCoolModel 모델용 모델 카드 이 모델은 이런 저런 것들을 합니다. 이 모델은 [@nateraw](https://hf.co/nateraw)에 의해 생성되었습니다. ``` 카드 데이터를 업데이트할 때 [`ModelCard.validate`]를 불러와 Hub에 대해 카드가 여전히 유효한지 확인할 수 있습니다. 이렇게 하면 Hugging Face Hub에 설정된 모든 유효성 검사 규칙을 통과할 수 있습니다. ### 기본 템플릿으로부터[[from-the-default-template]] 자체 템플릿을 사용하는 대신에, 많은 섹션으로 구성된 기능이 풍부한 [기본 템플릿](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md)을 사용할 수도 있습니다. 내부적으론 [Jinja2](https://jinja.palletsprojects.com/en/3.1.x/) 를 사용하여 템플릿 파일을 작성합니다. `from_template`를 사용하려면 jinja2를 설치해야 합니다. `pip install Jinja2`를 사용하면 됩니다. ```python card_data = ModelCardData(language='en', license='mit', library_name='keras') card = ModelCard.from_template( card_data, model_id='my-cool-model', model_description="this model does this and that", developers="Nate Raw", repo="https://github.com/huggingface/huggingface_hub", ) card.save('my_model_card_2.md') print(card) ``` ## 모델 카드 공유하기[[share-model-cards]] Hugging Face Hub로 인증받은 경우(`huggingface-cli login` 또는 [`login`] 사용) 간단히 [`ModelCard.push_to_hub`]를 호출하여 카드를 Hub에 푸시할 수 있습니다. 이를 수행하는 방법을 살펴보겠습니다. 
먼저 인증된 사용자의 네임스페이스 아래에 'hf-hub-modelcards-pr-test'라는 새로운 레포지토리를 만듭니다: ```python from huggingface_hub import whoami, create_repo user = whoami()['name'] repo_id = f'{user}/hf-hub-modelcards-pr-test' url = create_repo(repo_id, exist_ok=True) ``` 그런 다음 기본 템플릿에서 카드를 만듭니다(위 섹션에서 정의한 것과 동일): ```python card_data = ModelCardData(language='en', license='mit', library_name='keras') card = ModelCard.from_template( card_data, model_id='my-cool-model', model_description="this model does this and that", developers="Nate Raw", repo="https://github.com/huggingface/huggingface_hub", ) ``` 마지막으로 이를 Hub로 푸시하겠습니다. ```python card.push_to_hub(repo_id) ``` 결과 카드는 [여기](https://huggingface.co/nateraw/hf-hub-modelcards-pr-test/blob/main/README.md)에서 확인할 수 있습니다. PR로 카드를 푸시하고 싶다면 `push_to_hub`를 호출할 때 `create_pr=True`라고 지정하면 됩니다. ```python card.push_to_hub(repo_id, create_pr=True) ``` 이 명령으로 생성된 결과 PR은 [여기](https://huggingface.co/nateraw/hf-hub-modelcards-pr-test/discussions/3)에서 볼 수 있습니다. ## 메타데이터 업데이트[[update-metadata]] 이 섹션에서는 레포 카드에 있는 메타데이터와 업데이트 방법을 확인합니다. `메타데이터`는 모델, 데이터 세트, Spaces에 대한 높은 수준의 정보를 제공하는 해시맵(또는 키-값) 컨텍스트를 말합니다. 모델의 `pipeline type`, `model_id` 또는 `model_desc` 설명 등의 정보가 포함될 수 있습니다. 자세한 내용은 [모델 카드](https://huggingface.co/docs/hub/model-cards#model-card-metadata), [데이터 세트 카드](https://huggingface.co/docs/hub/datasets-cards#dataset-card-metadata) 및 [Spaces 설정](https://huggingface.co/docs/hub/spaces-settings#spaces-settings)을 참조하세요. 이제 메타데이터를 업데이트하는 방법에 대한 몇 가지 예를 살펴보겠습니다. 첫 번째 예부터 살펴보겠습니다: ```python >>> from huggingface_hub import metadata_update >>> metadata_update("username/my-cool-model", {"pipeline_tag": "image-classification"}) ``` 두 줄의 코드를 사용하면 메타데이터를 업데이트하여 새로운 `pipeline_tag`를 설정할 수 있습니다. 기본적으로 카드에 이미 존재하는 키는 업데이트할 수 없습니다. 그렇게 하려면 `overwrite=True`를 명시적으로 전달해야 합니다. ```python >>> from huggingface_hub import metadata_update >>> metadata_update("username/my-cool-model", {"pipeline_tag": "text-generation"}, overwrite=True) ``` 쓰기 권한이 없는 저장소에 일부 변경 사항을 제안하려는 경우가 종종 있습니다. 소유자가 귀하의 제안을 검토하고 병합할 수 있도록 해당 저장소에 PR을 생성하면 됩니다. ```python >>> from huggingface_hub import metadata_update >>> metadata_update("someone/model", {"pipeline_tag": "text-classification"}, create_pr=True) ``` ## 평가 결과 포함하기[[include-evaluation-results]] 메타데이터 `model-index`에 평가 결과를 포함하려면 관련 평가 결과와 함께 [`EvalResult`] 또는 [`EvalResult`] 목록을 전달하면 됩니다. 내부적으론 `card.data.to_dict()`를 호출하면 `model-index`가 생성됩니다. 자세한 내용은 [Hub 문서의 이 섹션](https://huggingface.co/docs/hub/models-cards#evaluation-results)을 참조하십시오. 이 기능을 사용하려면 [`ModelCardData`]에 `model_name` 속성을 포함해야 합니다.
```python card_data = ModelCardData( language='en', license='mit', model_name='my-cool-model', eval_results = EvalResult( task_type='image-classification', dataset_type='beans', dataset_name='Beans', metric_type='accuracy', metric_value=0.7 ) ) card = ModelCard.from_template(card_data) print(card.data) ``` 결과 `card.data`는 다음과 같이 보여야 합니다: ``` language: en license: mit model-index: - name: my-cool-model results: - task: type: image-classification dataset: name: Beans type: beans metrics: - type: accuracy value: 0.7 ``` `EvalResult`: 공유하고 싶은 평가 결과가 둘 이상 있는 경우 `EvalResults` 목록을 전달하기만 하면 됩니다: ```python card_data = ModelCardData( language='en', license='mit', model_name='my-cool-model', eval_results = [ EvalResult( task_type='image-classification', dataset_type='beans', dataset_name='Beans', metric_type='accuracy', metric_value=0.7 ), EvalResult( task_type='image-classification', dataset_type='beans', dataset_name='Beans', metric_type='f1', metric_value=0.65 ) ] ) card = ModelCard.from_template(card_data) card.data ``` 그러면 다음 `card.data`가 남게 됩니다: ``` language: en license: mit model-index: - name: my-cool-model results: - task: type: image-classification dataset: name: Beans type: beans metrics: - type: accuracy value: 0.7 - type: f1 value: 0.65 ``` huggingface_hub-0.31.1/docs/source/ko/guides/overview.md000066400000000000000000000142511500667546600232340ustar00rootroot00000000000000 # How-to 가이드 [[howto-guides]] 특정 목표를 달성하는 데 도움이 되는 실용적인 가이드들입니다. huggingface_hub로 실제 문제를 해결하는 방법을 배우려면 다음 문서들을 살펴보세요. huggingface_hub-0.31.1/docs/source/ko/guides/repository.md000066400000000000000000000306271500667546600236120ustar00rootroot00000000000000 # 리포지토리 생성과 관리[[create-and-manage-a-repository]] Hugging Face Hub는 Git 리포지토리 모음입니다. [Git](https://git-scm.com/)은 협업을 할 때 여러 프로젝트 버전을 쉽게 관리하기 위해 널리 사용되는 소프트웨어 개발 도구입니다. 이 가이드에서는 Hub의 리포지토리 사용법인 다음 내용을 다룹니다: - 리포지토리 생성과 삭제. - 태그 및 브랜치 관리. - 리포지토리 이름 변경. - 리포지토리 공개 여부. - 리포지토리 복사본 관리. GitLab/GitHub/Bitbucket과 같은 플랫폼을 사용해 본 경험이 있다면, 모델 리포지토리를 관리하기 위해 `git` CLI를 사용해 git 리포지토리를 클론(`git clone`)하고 변경 사항을 커밋(`git add, git commit`)하고 커밋한 내용을 푸시(`git push`) 하는것이 가장 먼저 떠오를 것입니다. 이 명령어들은 Hugging Face Hub에서도 사용할 수 있습니다. 하지만 소프트웨어 엔지니어링과 머신러닝은 동일한 요구 사항과 워크플로우를 공유하지 않습니다. 모델 리포지토리는 다양한 프레임워크와 도구를 위한 대규모 모델 가중치 파일을 유지관리 할 수 있으므로, 리포지토리를 복제하면 대규모 로컬 폴더를 유지관리하고 막대한 크기의 파일을 다루게 될 수 있습니다. 결과적으로 Hugging Face의 커스텀 HTTP 방법을 사용하는 것이 더욱 효율적일 수 있습니다. 더 자세한 내용은 [Git vs HTTP paradigm](../concepts/git_vs_http) 문서를 참조하세요. Hub에 리포지토리를 생성하고 관리하려면, 로그인이 되어 있어야 합니다. 로그인이 안 되어있다면 [이 문서](../quick-start#authentication)를 참고해 주세요. 이 가이드에서는 로그인이 되어있다는 가정하에 진행됩니다. ## 리포지토리 생성 및 삭제[[repo-creation-and-deletion]] 첫 번째 단계는 어떻게 리포지토리를 생성하고 삭제하는지를 알아야 합니다. 사용자 이름 네임스페이스 아래에 소유한 리포지토리 또는 쓰기 권한이 있는 조직의 리포지토리만 관리할 수 있습니다. ### 리포지토리 생성[[create-a-repository]] [`create_repo`] 함수로 함께 빈 리포지토리를 만들고 `repo_id` 매개변수를 사용하여 이름을 정하세요. `repo_id`는 사용자 이름 또는 조직 이름 뒤에 리포지토리 이름이 따라옵니다: `username_or_org/repo_name`. ```py >>> from huggingface_hub import create_repo >>> create_repo("lysandre/test-model") 'https://huggingface.co/lysandre/test-model' ``` 기본적으로 [`create_repo`]는 모델 리포지토리를 만듭니다. 하지만 `repo_type` 매개변수를 사용하여 다른 유형의 리포지토리를 지정할 수 있습니다. 예를 들어 데이터셋 리포지토리를 만들고 싶다면: ```py >>> from huggingface_hub import create_repo >>> create_repo("lysandre/test-dataset", repo_type="dataset") 'https://huggingface.co/datasets/lysandre/test-dataset' ``` 리포지토리를 만들 때, `private` 매개변수를 사용하여 가시성을 설정할 수 있습니다. 
```py >>> from huggingface_hub import create_repo >>> create_repo("lysandre/test-private", private=True) ``` 추후 리포지토리 가시성을 변경하고 싶다면, [`update_repo_settings`] 함수를 이용해 바꿀 수 있습니다. ### 리포지토리 삭제[[delete-a-repository]] [`delete_repo`]를 사용하여 리포지토리를 삭제할 수 있습니다. 리포지토리를 삭제하기 전에 신중히 결정하세요. 왜냐하면, 삭제하고 나서 다시 되돌릴 수 없는 프로세스이기 때문입니다! 삭제하려는 리포지토리의 `repo_id`를 지정하세요: ```py >>> delete_repo(repo_id="lysandre/my-corrupted-dataset", repo_type="dataset") ``` ### 리포지토리 복제(Spaces 전용)[[duplicate-a-repository-only-for-spaces]] 가끔 다른 누군가의 리포지토리를 복사하여, 상황에 맞게 수정하고 싶을 때가 있습니다. 이는 [`duplicate_space`]를 사용하여 Space에 복사할 수 있습니다. 이 함수를 사용하면 리포지토리 전체를 복제할 수 있습니다. 그러나 여전히 하드웨어, 절전 시간, 리포지토리, 변수 및 비밀번호와 같은 자체 설정을 구성해야 합니다. 자세한 내용은 [Manage your Space](./manage-spaces) 문서를 참조하십시오. ```py >>> from huggingface_hub import duplicate_space >>> duplicate_space("multimodalart/dreambooth-training", private=False) RepoUrl('https://huggingface.co/spaces/nateraw/dreambooth-training',...) ``` ## 파일 다운로드와 업로드[[upload-and-download-files]] 이제 리포지토리를 생성했으므로, 변경 사항을 푸시하고 파일을 다운로드하는 것에 관심이 있을 것입니다. 이 두 가지 주제는 각각 자체 가이드가 필요합니다. 리포지토리 사용하는 방법에 대해 알아보려면 [업로드](./upload) 및 [다운로드](./download) 문서를 참조하세요. ## 브랜치와 태그[[branches-and-tags]] Git 리포지토리는 동일한 리포지토리의 다른 버전을 저장하기 위해 브랜치들을 사용합니다. 태그는 버전을 출시할 때와 같이 리포지토리의 특정 상태를 표시하는 데 사용될 수도 있습니다. 일반적으로 브랜치와 태그는 [git 참조](https://git-scm.com/book/en/v2/Git-Internals-Git-References) 로 참조됩니다. ### 브랜치 생성과 태그[[create-branches-and-tags]] [`create_branch`]와 [`create_tag`]를 이용하여 새로운 브랜치와 태그를 생성할 수 있습니다. ```py >>> from huggingface_hub import create_branch, create_tag # `main` 브랜치를 기반으로 Space 저장소에 새 브랜치를 생성합니다. >>> create_branch("Matthijs/speecht5-tts-demo", repo_type="space", branch="handle-dog-speaker") # `v0.1-release` 브랜치를 기반으로 Dataset 저장소에 태그를 생성합니다. >>> create_tag("bigcode/the-stack", repo_type="dataset", revision="v0.1-release", tag="v0.1.1", tag_message="Bump release version.") ``` 같은 방식으로 [`delete_branch`]와 [`delete_tag`] 함수를 사용하여 브랜치 또는 태그를 삭제할 수 있습니다. ### 모든 브랜치와 태그 나열[[list-all-branches-and-tags]] [`list_repo_refs`]를 사용하여 리포지토리로부터 현재 존재하는 git 참조를 나열할 수 있습니다: ```py >>> from huggingface_hub import list_repo_refs >>> list_repo_refs("bigcode/the-stack", repo_type="dataset") GitRefs( branches=[ GitRefInfo(name='main', ref='refs/heads/main', target_commit='18edc1591d9ce72aa82f56c4431b3c969b210ae3'), GitRefInfo(name='v1.1.a1', ref='refs/heads/v1.1.a1', target_commit='f9826b862d1567f3822d3d25649b0d6d22ace714') ], converts=[], tags=[ GitRefInfo(name='v1.0', ref='refs/tags/v1.0', target_commit='c37a8cd1e382064d8aced5e05543c5f7753834da') ] ) ``` ## 리포지토리 설정 변경[[change-repository-settings]] 리포지토리는 구성할 수 있는 몇 가지 설정이 있습니다. 대부분의 경우 브라우저의 리포지토리 설정 페이지에서 직접 설정할 것입니다. 설정을 바꾸려면 리포지토리에 대한 쓰기 액세스 권한이 있어야 합니다(사용자 리포지토리거나, 조직의 구성원이어야 함). 이 주제에서는 `huggingface_hub`를 사용하여 프로그래밍 방식으로 구성할 수 있는 설정을 알아보겠습니다. Spaces를 위한 특정 설정들(하드웨어, 환경변수 등)을 구성하기 위해서는 [Manage your Spaces](../guides/manage-spaces) 문서를 참조하세요. ### 가시성 업데이트[[update-visibility]] 리포지토리는 공개 또는 비공개로 설정할 수 있습니다. 비공개 리포지토리는 해당 저장소의 사용자 혹은 소속된 조직의 구성원만 볼 수 있습니다. 다음과 같이 리포지토리를 비공개로 변경할 수 있습니다. ```py >>> from huggingface_hub import update_repo_settings >>> update_repo_settings(repo_id=repo_id, private=True) ``` ### 리포지토리 이름 변경[[rename-your-repository]] [`move_repo`]를 사용하여 Hub에 있는 리포지토리 이름을 변경할 수 있습니다. 이 함수를 사용하여 개인에서 조직 리포지토리로 이동할 수도 있습니다. 이렇게 하면 [일부 제한 사항](https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo)이 있으므로 주의해야 합니다. 예를 들어, 다른 사용자에게 리포지토리를 이전할 수는 없습니다. 
```py >>> from huggingface_hub import move_repo >>> move_repo(from_id="Wauplin/cool-model", to_id="huggingface/cool-model") ``` ## 리포지토리의 로컬 복사본 관리[[manage-a-local-copy-of-your-repository]] 위에 설명한 모든 작업은 HTTP 요청을 사용하여 작업할 수 있습니다. 그러나 경우에 따라 로컬 복사본을 가지고 익숙한 Git 명령어를 사용하여 상호 작용하는 것이 편리할 수 있습니다. [`Repository`] 클래스는 Git 명령어와 유사한 기능을 제공하는 함수를 사용하여 Hub의 파일 및 리포지토리와 상호 작용할 수 있습니다. 이는 이미 알고 있고 좋아하는 Git 및 Git-LFS 방법을 사용하는 래퍼(wrapper)입니다. 시작하기 전에 Git-LFS가 설치되어 있는지 확인하세요([여기서](https://git-lfs.github.com/) 설치 지침을 확인할 수 있습니다). [`Repository`]는 [`HfApi`]에 구현된 HTTP 기반 대안을 선호하여 중단되었습니다. 아직 많은 레거시 코드에서 사용되고 있기 때문에 [`Repository`]가 완전히 제거되는 건 `v1.0` 릴리스에서만 이루어집니다. 자세한 내용은 [해당 설명 페이지](./concepts/git_vs_http)를 참조하세요. ### 로컬 리포지토리 사용[[use-a-local-repository]] 로컬 리포지토리 경로를 사용하여 [`Repository`] 객체를 생성하세요: ```py >>> from huggingface_hub import Repository >>> repo = Repository(local_dir="//") ``` ### 복제[[clone]] `clone_from` 매개변수는 Hugging Face 리포지토리 ID에서 로컬 디렉터리로 리포지토리를 복제합니다. 이때 `local_dir` 매개변수를 사용하여 로컬 디렉터리에 저장합니다: ```py >>> from huggingface_hub import Repository >>> repo = Repository(local_dir="w2v2", clone_from="facebook/wav2vec2-large-960h-lv60") ``` `clone_from`은 URL을 사용해 리포지토리를 복제할 수 있습니다. ```py >>> repo = Repository(local_dir="huggingface-hub", clone_from="https://huggingface.co/facebook/wav2vec2-large-960h-lv60") ``` `clone_from` 매개변수를 [`create_repo`]와 결합하여 리포지토리를 만들고 복제할 수 있습니다. ```py >>> repo_url = create_repo(repo_id="repo_name") >>> repo = Repository(local_dir="repo_local_path", clone_from=repo_url) ``` 리포지토리를 복제할 때 `git_user` 및 `git_email` 매개변수를 지정함으로써 복제한 리포지토리에 Git 사용자 이름과 이메일을 설정할 수 있습니다. 사용자가 해당 리포지토리에 커밋하면 Git은 커밋 작성자를 인식합니다. ```py >>> repo = Repository( ... "my-dataset", ... clone_from="/", ... token=True, ... repo_type="dataset", ... git_user="MyName", ... git_email="me@cool.mail" ... ) ``` ### 브랜치[[branch]] 브랜치는 현재 코드와 파일에 영향을 미치지 않으면서 협업과 실험에 중요합니다.[`~Repository.git_checkout`]을 사용하여 브랜치 간에 전환할 수 있습니다. 예를 들어, `branch1`에서 `branch2`로 전환하려면: ```py >>> from huggingface_hub import Repository >>> repo = Repository(local_dir="huggingface-hub", clone_from="/", revision='branch1') >>> repo.git_checkout("branch2") ``` ### 끌어오기[[pull]] [`~Repository.git_pull`]은 원격 리포지토리로부터의 변경사항을 현재 로컬 브랜치에 업데이트하게 합니다. ```py >>> from huggingface_hub import Repository >>> repo.git_pull() ``` 브랜치가 원격에서의 새 커밋으로 업데이트 된 후에 로컬 커밋을 수행하고자 한다면 `rebase=True`를 설정하세요: ```py >>> repo.git_pull(rebase=True) ``` huggingface_hub-0.31.1/docs/source/ko/guides/search.md000066400000000000000000000057121500667546600226350ustar00rootroot00000000000000 # Hub에서 검색하기[[search-the-hub]] 이 튜토리얼에서는 `huggingface_hub`를 사용하여 Hub에서 모델, 데이터 세트 및 Spaces를 검색하는 방법을 배웁니다. ## 리포지토리를 어떻게 나열하나요?[[how-to-list-repositories-]] `huggingface_hub` 라이브러리에는 Hub와 상호작용하기 위한 HTTP 클라이언트[`HfApi`]가 포함되어 있습니다. 이를 통해, Hub에 저장된 모델, 데이터셋, 그리고 Spaces를 나열할 수 있습니다. ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> models = api.list_models() ``` [`list_models`]의 출력은 Hub에 저장되어 있는 모델들을 나열한 결과입니다. 마찬가지로, [`list_datasets`]를 사용하여 데이터 세트를 나열하고 [`list_spaces`]를 사용하여 Spaces를 나열할 수 있습니다. ## 리포지토리를 어떻게 필터링하나요?[[how-to-filter-repositories-]] 리포지토리를 나열하는 것도 유용하지만, 검색을 필터링하고 싶을 수도 있습니다. 리스트에는 다음과 같은 여러 속성이 있습니다. - `filter` - `author` - `search` - ... 이 매개변수 중 두 개는 직관적입니다(`author` 및 `search`). 그렇다면 `filter`는 어떤 것을 나타낼까요? `filter`는 [`ModelFilter`] 객체(또는 [`DatasetFilter`])를 입력으로 받습니다. 이를 이용해 필터링 하고 싶은 모델을 지정하여 인스턴스를 생성할 수 있습니다. PyTorch로 작동되고 imagenet 데이터 세트로 훈련된, 이미지 분류를 위한 Hub의 모든 모델을 찾는 방법으로 예를 들어보겠습니다. 이 과정은 단일 [ModelFilter]를 사용하여 수행할 수 있습니다. 
이때, 필터링 속성들은 '논리적 AND'로 결합되어, 지정한 모든 조건을 만족하는 모델만 선택됩니다. ```py models = hf_api.list_models( filter=ModelFilter( task="image-classification", library="pytorch", trained_dataset="imagenet" ) ) ``` 필터링하는 과정에서 모델을 정렬하고 상위 결과만 선택할 수도 있습니다. 다음 예제는 Hub에서 가장 많이 다운로드된 상위 5개 데이터 세트를 가져옵니다. ```py >>> list(list_datasets(sort="downloads", direction=-1, limit=5)) [DatasetInfo( id='argilla/databricks-dolly-15k-curated-en', author='argilla', sha='4dcd1dedbe148307a833c931b21ca456a1fc4281', last_modified=datetime.datetime(2023, 10, 2, 12, 32, 53, tzinfo=datetime.timezone.utc), private=False, downloads=8889377, (...) ``` Hub에서 사용 가능한 필터에 대해 살펴보려면 웹브라우저에서 [모델](https://huggingface.co/models) 및 [데이터 세트](https://huggingface.co/datasets) 페이지를 방문하여 일부 매개변수를 검색한 다음, URL에서 값들을 확인해보세요. huggingface_hub-0.31.1/docs/source/ko/guides/upload.md000066400000000000000000000777671500667546600226770ustar00rootroot00000000000000 # Hub에 파일 업로드하기[[upload-files-to-the-hub]] 파일과 작업물을 공유하는 것은 Hub의 주요 특성 중 하나입니다. `huggingface_hub`는 Hub에 파일을 업로드하기 위한 몇 가지 옵션을 제공합니다. 이러한 기능을 단독으로 사용하거나 라이브러리에 통합하여 해당 라이브러리의 사용자가 Hub와 더 편리하게 상호작용할 수 있도록 도울 수 있습니다. 이 가이드에서는 파일을 푸시하는 다양한 방법에 대해 설명합니다: - Git을 사용하지 않고 푸시하기. - [Git LFS](https://git-lfs.github.com/)를 사용하여 매우 큰 파일을 푸시하기. - `commit` 컨텍스트 매니저를 사용하여 푸시하기. - [`~Repository.push_to_hub`] 함수를 사용하여 푸시하기. Hub에 파일을 업로드 하려면 허깅페이스 계정으로 로그인해야 합니다. 인증에 대한 자세한 내용은 [이 페이지](../quick-start#authentication)를 참조해 주세요. ## 파일 업로드하기[[upload-a-file]] [`create_repo`]로 리포지토리를 생성했다면, [`upload_file`]을 사용하여 해당 리포지토리에 파일을 업로드할 수 있습니다. 업로드할 파일의 본 경로, 리포지토리에서 파일을 업로드할 위치, 대상 리포지토리의 이름을 지정합니다. 리포지토리의 유형을 `dataset`, `model`, `space`로 선택적으로 설정할 수 있습니다. ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.upload_file( ... path_or_fileobj="/path/to/local/folder/README.md", ... path_in_repo="README.md", ... repo_id="username/test-dataset", ... repo_type="dataset", ... ) ``` ## 폴더 업로드[[upload-a-folder]] 로컬 폴더를 리포지토리에 업로드하려면 [`upload_folder`] 함수를 사용합니다. 업로드할 로컬 폴더의 본 경로, 리포지토리에서 폴더를 업로드할 위치, 대상 리포지토리의 이름을 지정합니다. 리포지토리의 유형을 `dataset`, `model`, `space`로 선택적으로 설정할 수 있습니다. ```py >>> from huggingface_hub import HfApi >>> api = HfApi() # 로컬 폴더에 있는 모든 콘텐츠를 원격 Space에 업로드 합니다. # 파일은 기본적으로 리포지토리의 루트 디렉토리에 업로드 됩니다. >>> api.upload_folder( ... folder_path="/path/to/local/space", ... repo_id="username/my-cool-space", ... repo_type="space", ... ) ``` 기본적으로 어떤 파일을 커밋할지 여부를 알기 위해 `.gitignore` 파일을 참조하게 됩니다. 기본적으로 커밋에 `.gitignore` 파일이 있는지 확인하고, 없는 경우 Hub에 파일이 있는지 확인합니다. 디렉터리의 루트 경로에 있는 `.gitignore` 파일만 사용된다는 점을 주의하세요. 하위 경로에는 `.gitignore` 파일이 있는지 확인하지 않습니다. 하드코딩된 `.gitignore` 파일을 사용하지 않으려면 `allow_patterns` 와 `ignore_patterns` 인수를 사용하여 업로드할 파일을 필터링할 수 있습니다. 이 매개변수들은 단일 패턴 또는 패턴 리스트를 허용합니다. 패턴의 형식은 [이 문서](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm)에 설명된 대로 표준 와일드카드(글로빙 패턴)입니다. `allow_patterns`과 `ignore_patterns`을 모두 사용하면 두 가지 설정이 모두 적용됩니다. `.gitignore` 파일과 allow/ignore 패턴 외에 하위 경로 있는 모든 `.git/` 폴더는 무시됩니다. ```py >>> api.upload_folder( ... folder_path="/path/to/local/folder", ... path_in_repo="my-dataset/train", # 특정 폴더에 업로드 ... repo_id="username/test-dataset", ... repo_type="dataset", ... ignore_patterns="**/logs/*.txt", # 모든 로그 텍스트 파일을 무시 ... ) ``` `delete_patterns` 인수를 사용하여 동일한 커밋에서 리포지토리에서 삭제할 파일을 지정할 수도 있습니다. 이 방법은 파일을 푸시하기 전에 원격 폴더를 정리하고 싶은데 어떤 파일이 이미 존재하는지 모르는 경우에 유용합니다. 다음은 로컬 `./logs` 폴더를 원격 `/experiment/logs/` 폴더에 업로드하는 예시입니다. 폴더 내의 txt 파일만을 업로드 하게 되며 그 전에 리포지토리에 있던 모든 이전 txt 파일이 삭제됩니다. 이 모든 과정이 단 한 번의 커밋으로 이루어집니다. ```py >>> api.upload_folder( ... folder_path="/path/to/local/folder/logs", ... 
repo_id="username/trained-model", ... path_in_repo="experiment/logs/", ... allow_patterns="*.txt", # 모든 로컬 텍스트 파일을 업로드 ... delete_patterns="*.txt", # 모든 이전 텍스트 파일을 삭제 ... ) ``` ## CLI에서 업로드[[upload-from-the-cli]] 터미널에서 `huggingface-cli upload` 명령어를 사용하여 Hub에 파일을 직접 업로드할 수 있습니다. 내부적으로는 위에서 설명한 것과 동일한 [`upload_file`] 와 [`upload_folder`] 함수를 사용합니다. 다음과 같이 단일 파일 또는 전체 폴더를 업로드할 수 있습니다: ```bash # 사용례: huggingface-cli upload [repo_id] [local_path] [path_in_repo] >>> huggingface-cli upload Wauplin/my-cool-model ./models/model.safetensors model.safetensors https://huggingface.co/Wauplin/my-cool-model/blob/main/model.safetensors >>> huggingface-cli upload Wauplin/my-cool-model ./models . https://huggingface.co/Wauplin/my-cool-model/tree/main ``` `local_path` 와 `path_in_repo`는 선택 사항이며 주어지지 않을 시 임의로 추정됩니다. `local_path`가 설정되지 않은 경우, 이 툴은 로컬 폴더나 파일에 `repo_id`와 같은 이름이 있는지 확인하며, 발견된 경우 해당 폴더나 파일이 업로드됩니다. 같은 이름의 폴더나 파일을 찾지 못한다면 사용자에게 `local_path`를 명시하도록 요청하는 예외 처리가 발생합니다. 어떤 경우든 `path_in_repo`가 설정되어 있지 않으면 파일이 리포지토리의 루트 디렉터리에 업로드됩니다. CLI 업로드 명령어에 대한 자세한 내용은 [CLI 가이드](./cli#huggingface-cli-upload)를 참조하세요. ## 고급 기능[[advanced-features]] 대부분의 경우, Hub에 파일을 업로드하는 데 [`upload_file`]과 [`upload_folder`] 이상이 필요하지 않습니다. 하지만 `huggingface_hub`에는 작업을 더 쉽게 할 수 있는 고급 기능이 있습니다. 그 기능들을 살펴봅시다! ### 논블로킹 업로드[[non-blocking-uploads]] 메인 스레드를 멈추지 않고 데이터를 푸시하고 싶은 경우가 있습니다. 이는 모델 학습을 계속 진행하면서 로그와 아티팩트를 업로드할 때 특히 유용합니다. 이렇게 하려면 [`upload_file`]과 [[`upload_folder`] 에 `run_as_future` 인수를 사용하고 [`concurrent.futures.Future`](https://docs.python.org/3/library/concurrent.futures.html#future-objects)객체를 반환받아 업로드 상태를 확인하는 데 사용할 수 있습니다. ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> future = api.upload_folder( # 백그라운드에서 업로드 작업 수행 (논블로킹) ... repo_id="username/my-model", ... folder_path="checkpoints-001", ... run_as_future=True, ... ) >>> future Future(...) >>> future.done() False >>> future.result() # 업로드가 완료될 때까지 대기 (블로킹) ... ``` `run_as_future=True`를 사용하면 백그라운드 작업이 큐에 대기됩니다. 이는 작업이 올바른 순서로 실행된다는 것을 의미합니다. 백그라운드 작업은 주로 데이터를 업로드하거나 커밋을 생성하는 데 유용하지만, 이 외에도 [`run_as_future`]를 사용하여 원하는 메소드를 대기열에 넣을 수 있습니다. 예를 들어, 해당 기능을 사용하여 백그라운드에서 리포지토리를 만든 다음 그대로 데이터를 업로드할 수 있습니다. 업로드 메소드에 내장된 `run_as_future` 인수는 본 기능의 별칭입니다. ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.run_as_future(api.create_repo, "username/my-model", exists_ok=True) Future(...) >>> api.upload_file( ... repo_id="username/my-model", ... path_in_repo="file.txt", ... path_or_fileobj=b"file content", ... run_as_future=True, ... ) Future(...) ``` ### 청크 단위로 폴더 업로드하기[[upload-a-folder-by-chunks]] [`upload_folder`]를 사용하면 전체 폴더를 Hub에 쉽게 업로드할 수 있지만, 대용량 폴더(수천 개의 파일 또는 수백 GB의 용량)의 경우 문제가 될 수 있습니다. 파일이 많은 폴더가 있는 경우 여러 커밋에 걸쳐 업로드하는 것이 좋습니다. 업로드 중에 오류나 연결 문제가 발생해도 처음부터 다시 시작할 필요는 없습니다. 여러 커밋으로 폴더를 업로드하려면 `multi_commits=True`를 인수로 전달하면 됩니다. 내부적으로 `huggingface_hub`는 업로드/삭제할 파일을 나열하고 여러 커밋으로 분할합니다. 커밋을 분할하는 전략은 업로드할 파일의 수와 크기에 따라 결정됩니다. 모든 커밋을 푸시하기 위해 Hub에 PR이 열리게 되며, PR이 준비되면 여러 커밋이 단일 커밋으로 뭉쳐집니다. 완료하기 전에 프로세스가 중단된 경우 스크립트를 다시 실행하여 업로드를 재개할 수 있습니다. 생성된 PR이 자동으로 감지되고 업로드가 중단된 지점부터 업로드가 재개됩니다. 업로드 진행 상황을 더 잘 이해하고 싶다면 `multi_commits_verbose=True`를 인수로 전달하면 됩니다. 다음은 여러 커밋으로 체크포인트 폴더를 데이터셋에 업로드하는 예제입니다. Hub에 PR이 생성되고 업로드가 완료되면 자동으로 병합됩니다. PR을 계속 열어두고 수동으로 검토하려면 `create_pr=True`를 인수로 전달하면 됩니다. ```py >>> upload_folder( ... folder_path="local/checkpoints", ... repo_id="username/my-dataset", ... repo_type="dataset", ... multi_commits=True, ... multi_commits_verbose=True, ... 
) ``` 업로드 전략(즉, 생성되는 커밋)을 더 잘 제어하고 싶으면 로우 레벨의 [`plan_multi_commits`] 와 [`create_commits_on_pr`] 메서드를 살펴보세요. `multi_commits`은 아직 실험적인 기능입니다. 해당 API와 동작은 향후 사전 고지 없이 변경될 수 있습니다. ### 예약된 업로드[[scheduled-uploads]] 허깅 페이스 Hub를 사용하면 데이터를 쉽게 저장하고 버전업할 수 있지만, 동일한 파일을 수천 번 업데이트할 때는 몇 가지 제한이 있습니다. 예를 들어, 배포된 Space에 대한 교육 프로세스 또는 사용자 로그를 저장하고 싶을 때 Hub에 데이터 집합으로 데이터를 업로드하는 것이 좋아 보이지만, 이를 제대로 하기 어려울 수 있습니다. 데이터의 모든 업데이트를 버전으로 만들게 되면 git 리포지토리를 사용할 수 없는 상태로 만들어 버리기 때문입니다. [`CommitScheduler`] 클래스는 이 문제에 대한 해결책을 제공합니다. 이 클래스는 로컬 폴더를 Hub에 정기적으로 푸시하는 백그라운드 작업을 실행합니다. 일부 텍스트를 입력으로 받아 두 개의 번역을 생성한 다음, 사용자가 선호하는 번역을 선택할 수 있는 라디오 스페이스가 있다고 가정해 보겠습니다. 이 스페이스의 각 실행에 대해 입력, 출력 및 사용자 기본 설정을 저장하여 결과를 분석하려고 하는데, 이것은 [`CommitScheduler`]의 완벽한 사용 사례가 될 수 있습니다. Hub에 데이터(잠재적으로 수백만 개의 사용자 피드백)를 저장하고 싶지만, 굳이 각 사용자의 입력을 _실시간_ 으로 저장할 필요는 없으니 데이터를 로컬 JSON 파일에 저장한 다음 10분마다 업로드하면 됩니다. 예제 코드는 다음과 같습니다: ```py >>> import json >>> import uuid >>> from pathlib import Path >>> import gradio as gr >>> from huggingface_hub import CommitScheduler # 데이터를 저장할 파일을 선언합니다. UUID를 이용하여 중복을 방지합니다. >>> feedback_file = Path("user_feedback/") / f"data_{uuid.uuid4()}.json" >>> feedback_folder = feedback_file.parent # 정기 업로드를 예약합니다. 원격 리포지토리와 로컬 폴더가 없을시 생성합니다. >>> scheduler = CommitScheduler( ... repo_id="report-translation-feedback", ... repo_type="dataset", ... folder_path=feedback_folder, ... path_in_repo="data", ... every=10, ... ) # 사용자가 피드백을 제출할 때 호출받을 함수를 정의합니다. (Gradio 안에서 호출받게 됩니다) >>> def save_feedback(input_text:str, output_1: str, output_2:str, user_choice: int) -> None: ... """ ... JSON Lines 파일에 입/출력과 사용자 피드백을 추가합니다. 타 사용자에 의한 동시적인 쓰기를 지양하기 위해 스레드 락을 사용합니다. ... """ ... with scheduler.lock: ... with feedback_file.open("a") as f: ... f.write(json.dumps({"input": input_text, "output_1": output_1, "output_2": output_2, "user_choice": user_choice})) ... f.write("\n") # Gradio를 시작합니다. >>> with gr.Blocks() as demo: >>> ... # Gradio 데모를 정의하고 `save_feedback`을 사용합니다 >>> demo.launch() ``` 사용자 입력/출력 및 피드백은 Hub에서 데이터 세트의 형태로 사용할 수 있습니다. 고유한 JSON 파일 이름을 사용하면 이전 실행이나 다른 스페이스/복제본이 동일한 리포지토리에 동시에 푸시하는 경우의 데이터를 덮어쓰지 않도록 보장할 수 있습니다. [`CommitScheduler`]에 대한 상세 사항은 다음과 같습니다: - **추가 전용:** 스케줄러는 폴더에 콘텐츠를 추가만 한다고 가정합니다. 기존 파일에 데이터를 추가하거나 새 파일을 만들 때만 사용하여야 합니다. 파일을 삭제하거나 덮어쓰면 리포지토리가 손상될 수 있습니다. - **git 히스토리**: 기본적으로 스케줄러는 `매 분마다` 폴더를 커밋합니다. git 리포지토리를 너무 많이 오염시키지 않으려면 최소값을 5분으로 설정하는 것이 좋습니다. 또한 스케줄러는 빈 커밋을 피하도록 설계되었는데, 만약 폴더에서 새 콘텐츠가 감지되지 않으면 예약된 커밋을 삭제합니다. - **에러:** 스케줄러는 백그라운드 스레드로 실행되고, 이는 클래스를 인스턴스화할 때 시작되며 절대 멈추지 않습니다. 만약 업로드 중에 오류가 발생하면(예: 연결 문제), 스케줄러는 이를 아무 말 없이 무시하고 다음 예약된 커밋에서 재시도 합니다. - **스레드 안전:** 대부분의 경우 파일 락에 대해 걱정할 필요 없이 파일에 쓰기 작업을 수행 할 수 있습니다. 스케줄러는 업로드하는 동안 대상 폴더에 콘텐츠를 쓰더라도 충돌하거나 손상되지 않습니다. 그러나, 부하가 많은 앱의 경우 이런 작업에서 _동시성 문제_ 가 발생할 수 있습니다. 이 경우, `scheduler.lock`을 사용하여 스레드 안전을 보장하는 것이 좋습니다. 이 락은 스케줄러가 폴더에서 변경 사항을 검색할 때만 차단되며, 데이터를 업로드할 때는 차단되지 않습니다. 따라서 Space의 사용자 환경에는 영향을 미치지 않습니다. #### 스페이스 지속성 데모[[space-persistence-demo]] 스페이스에서 Hub의 데이터셋으로 데이터를 영속하는 것이 [`CommitScheduler`]의 주요 사용 사례입니다. 각 사용 사례에 따라 데이터 구조를 다르게 설정해야 할 수도 있습니다. 데이터 구조는 동시 사용자와 재시작에 대해 견고해야 하며, 이는 대개 UUID를 생성 해야 함을 의미합니다. 견고함 뿐만 아니라, 재사용성을 위해 🤗 데이터 세트 라이브러리에서 읽을 수 있는 형식으로 데이터를 업로드해야 합니다. 이 [스페이스](https://huggingface.co/spaces/Wauplin/space_to_dataset_saver) 예제에서 여러 가지 데이터 형식을 저장하는 방법을 보여줍니다(각자의 필요에 맞게 조정해야 할 수도 있습니다). #### 사용자 지정 업로드[[custom-uploads]] [`CommitScheduler`]는 데이터가 추가 전용이며 "있는 그대로" 업로드해야 한다고 가정합니다. 그러나 데이터 업로드 방식을 사용자 스스로 정의하고 싶을 때도 있는데, [`CommitScheduler`]를 상속받는 클래스를 생성하고 `push_to_hub` 메서드를 덮어쓰면 됩니다(원하는 방식으로 자유롭게 덮어쓰세요). 이렇게 하면 해당 클래스가 백그라운드 스레드에서 `매 분마다` 호출됩니다. 
동시성 및 오류에 대해 걱정할 필요는 없지만 빈 커밋이나 중복된 데이터를 푸시하는 것과 같은 케이스들에 주의해야 합니다. 아래의 (단순화된) 예제에서는 `push_to_hub`를 덮어써서 모든 PNG 파일을 단일 아카이브에 압축하여 Hub의 리포지토리에 과부하가 걸리는 것을 방지합니다:. ```py class ZipScheduler(CommitScheduler): def push_to_hub(self): # 1. PNG 파일들을 나열합니다. png_files = list(self.folder_path.glob("*.png")) if len(png_files) == 0: return None # 커밋할 것이 없다면 일찍 리턴합니다. # 2. png 파일들을 단일 Zip 파일로 압축합니다. with tempfile.TemporaryDirectory() as tmpdir: archive_path = Path(tmpdir) / "train.zip" with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zip: for png_file in png_files: zip.write(filename=png_file, arcname=png_file.name) # 3. 압축된 파일을 업로드 합니다. self.api.upload_file(..., path_or_fileobj=archive_path) # 4. 로컬 png 파일을 삭제하여 다음에 다시 업로드 되는 일을 방지합니다. for png_file in png_files: png_file.unlink() ``` `push_to_hub`를 덮어쓰면 다음과 같은 [`CommitScheduler`]의 속성에 접근할 수 있습니다: - [`HfApi`] 클라이언트: `api` - 폴더 매개변수: `folder_path` 및 `path_in_repo` - 리포지토리 매개변수: `repo_id`, `repo_type`, `revision` - 스레드 락: `lock` 사용자 정의 스케줄러의 더 많은 예제는 사용 사례에 따른 다양한 구현이 포함된 [데모 스페이스](https://huggingface.co/spaces/Wauplin/space_to_dataset_saver)를 참조하세요. ### create_commit[[createcommit]] [`upload_file`] 및 [`upload_folder`] 함수는 일반적으로 사용하기 편리한 하이 레벨 API입니다. 로우 레벨에서 작업할 필요가 없다면 이 함수들을 먼저 사용해 볼 것을 권장합니다. 만약 커밋 레벨에서 작업하고 싶다면 [`create_commit`] 함수를 직접 사용할 수 있습니다. [`create_commit`]이 지원하는 작업 유형은 세 가지입니다: - [`CommitOperationAdd`] 는 파일을 Hub에 업로드합니다. 파일이 이미 있는 경우 파일 내용을 덮어씁니다. 이 작업은 두 개의 인수를 받습니다: - `path_in_repo`: 파일을 업로드할 리포지토리 경로. - `path_or_fileobj`: Hub에 업로드할 파일의 파일 시스템상 파일 경로 또는 파일 스타일 객체. - [`CommitOperationDelete`]는 리포지토리에서 파일 또는 폴더를 제거합니다. 이 작업은 `path_in_repo`를 인수로 받습니다. - [`CommitOperationCopy`]는 리포지토리 내의 파일을 복사합니다. 이 작업은 세 가지 인수를 받습니다: - `src_path_in_repo`: 복사할 파일의 리포지토리 경로. - `path_in_repo`: 파일 붙여넣기를 수행할 리포지토리 경로. - `src_revision`: 선택 사항 - 다른 브랜치/리비전에서 파일을 복사하려는 경우 필요한 복사할 파일의 리비전. 예를 들어, Hub 리포지토리에서 두 개의 파일을 업로드하고 한 개의 파일을 삭제하려는 경우: 1. 파일을 추가하거나 삭제하고 폴더를 삭제하기 위해 적절한 `CommitOperation`을 사용합니다: ```py >>> from huggingface_hub import HfApi, CommitOperationAdd, CommitOperationDelete >>> api = HfApi() >>> operations = [ ... CommitOperationAdd(path_in_repo="LICENSE.md", path_or_fileobj="~/repo/LICENSE.md"), ... CommitOperationAdd(path_in_repo="weights.h5", path_or_fileobj="~/repo/weights-final.h5"), ... CommitOperationDelete(path_in_repo="old-weights.h5"), ... CommitOperationDelete(path_in_repo="logs/"), ... CommitOperationCopy(src_path_in_repo="image.png", path_in_repo="duplicate_image.png"), ... ] ``` 2. 작업을 [`create_commit`]에 전달합니다: ```py >>> api.create_commit( ... repo_id="lysandre/test-model", ... operations=operations, ... commit_message="Upload my model weights and license", ... ) ``` 다음 함수들은 [`upload_file`] 및 [`upload_folder`] 외에도 내부적으로 [`create_commit`]을 사용합니다: - [`delete_file`]은 Hub의 리포지토리에서 단일 파일을 삭제합니다. - [`delete_folder`]는 Hub의 리포지토리에서 전체 폴더를 삭제합니다. - [`metadata_update`]는 리포지토리의 메타데이터를 업데이트합니다. 자세한 내용은 [`HfApi`] 의 레퍼런스를 참조하세요. ### 커밋하기 전에 LFS 파일 미리 업로드하기[[preupload-lfs-files-before-commit]] 경우에 따라 커밋 호출을 **하기 전에** 대용량 파일을 S3에 업로드해야 할 수도 있습니다. 예를 들어 인메모리에 생성된 여러 개의 샤드에 있는 데이터 세트를 커밋하는 경우, 샤드를 하나씩 업로드해야 메모리 부족 문제를 피할 수 있을 것입니다. 이 문제에 대한 해결책은 각 샤드를 리포지토리에 별도의 커밋으로 업로드하는 것입니다. 이 방법은 완벽하게 유효하지만, 수십 개의 커밋을 생성하여 잠재적으로 git 히스토리를 엉망으로 만들 수 있다는 단점이 있습니다. 이 문제를 극복하기 위해 파일을 하나씩 S3에 업로드한 다음 마지막에 하나의 커밋을 생성할 수 있습니다. 이는 [`preupload_lfs_files`]와 [`create_commit`]을 함께 사용하면 됩니다. 이 방법은 고급 사용자를 위한 방식입니다. 사전에 파일을 미리 업로드하는 로우 레벨 로직을 처리하는 대신 [`upload_file`], [`upload_folder`] 또는 [`create_commit`]을 직접 사용하는 것이 대부분의 경우에 적합합니다. 
[`preupload_lfs_files`]의 주요 주의 사항은 커밋이 실제로 이루어질 때까지는 Hub의 리포지토리에서 업로드 파일에 액세스할 수 없다는 것입니다. 궁금한 점이 있으면 언제든지 Discord나 GitHub 이슈로 문의해 주세요. 다음은 파일을 미리 업로드하는 방법을 보여주는 간단한 예시입니다: ```py >>> from huggingface_hub import CommitOperationAdd, preupload_lfs_files, create_commit, create_repo >>> repo_id = create_repo("test_preupload").repo_id >>> operations = [] # 생성될 모든 `CommitOperationAdd` 객체를 나열합니다. >>> for i in range(5): ... content = ... # bin 자료를 생성합니다. ... addition = CommitOperationAdd(path_in_repo=f"shard_{i}_of_5.bin", path_or_fileobj=content) ... preupload_lfs_files(repo_id, additions=[addition]) ... operations.append(addition) >>> # 커밋을 생성합니다. >>> create_commit(repo_id, operations=operations, commit_message="Commit all shards") ``` 먼저, [`CommitOperationAdd`] 오브젝트를 하나씩 생성합니다. 실전을 상정한 예제에서는, 여기에 생성된 샤드를 포함합니다. 각 파일은 다음 파일을 생성하기 전에 업로드됩니다. [`preupload_lfs_files`] 단계에서는 **`CommitOperationAdd` 오브젝트가 변경됩니다.** 따라서[`create_commit`]에 직접 전달할 때만 사용해야 합니다. 오브젝트의 주요 업데이트는 **바이너리 콘텐츠가 제거**된다는 것인데, 이는 따로 레퍼런스를 저장하지 않으면 가비지 콜렉팅 됨을 의미합니다. 이미 업로드된 콘텐츠를 메모리상에 남기고 싶지 않기 때문입니다. 마지막으로 모든 작업을 [`create_commit`]에 전달하여 커밋을 생성합니다. 아직 처리되지 않은 추가 작업(추가, 삭제 또는 복사)도 전달하면 올바르게 처리됩니다. ## 대용량 업로드를 위한 팁과 요령[[tips-and-tricks-for-large-uploads]] 리포지토리에 있는 대량의 데이터를 처리할 때 주의해야 할 몇 가지 제한 사항이 있습니다. 데이터를 스트리밍하는 데 걸리는 시간을 고려하면, 프로세스 마지막에 업로드/푸시가 실패하거나 hf.co에서 또는 로컬에서 작업할 때 성능 저하가 발생하는 것은 매우 성가신 일이 될 수 있습니다. Hub에서 리포지토리를 구성하는 방법에 대한 모범 사례는 [리포지토리 제한 사항 및 권장 사항](https://huggingface.co/docs/hub/repositories-recommendations) 가이드를 참조하세요. 다음으로 업로드 프로세스를 최대한 원활하게 진행할 수 있는 몇 가지 실용적인 팁을 살펴보겠습니다. - **작게 시작하세요**: 업로드 스크립트를 테스트할 때는 소량의 데이터로 시작하는 것이 좋습니다. 소량의 데이터를 처리하는데 적은 시간이 들기 때문에 스크립트를 반복하는 것이 더 쉽습니다. - **실패를 예상하세요**: 대량의 데이터를 스트리밍하는 것은 어려운 일입니다. 어떤 일이 일어날지 알 수 없지만, 항상 컴퓨터, 연결, 서버 등 어떤 이유로든 한 번쯤은 실패할 수 있다는 점을 고려하는 것이 가장 좋습니다. 예를 들어, 많은 양의 파일을 업로드할 계획이라면 다음 파일을 업로드하기 전에 이미 업로드한 파일을 로컬에서 추적하는 것이 가장 좋습니다. 이미 커밋된 LFS 파일은 절대 두 번 다시 업로드되지 않지만 클라이언트 측에서 이를 확인하면 시간을 절약할 수 있습니다. - **`hf_transfer`를 사용하세요**: [`hf_transfer`](https://github.com/huggingface/hf_transfer)는 대역폭이 매우 높은 컴퓨터에서 업로드 속도를 높이기 위한 Rust 기반 라이브러리입니다. `hf_transfer`를 사용하려면: 1. `huggingface_hub`를 설치할 때 `hf_transfer`를 추가로 지정합니다. (예: `pip install huggingface_hub[hf_transfer]`). 2. 환경 변수로 `HF_HUB_ENABLE_HF_TRANSFER=1`을 설정합니다. `hf_transfer`는 고급 사용자 도구입니다! 테스트 및 프로덕션 준비가 완료되었지만, 고급 오류 처리나 프록시와 같은 사용자 친화적인 기능이 부족합니다. 자세한 내용은 [이 섹션](https://huggingface.co/docs/huggingface_hub/hf_transfer)을 참조하세요. ## (레거시) Git LFS로 파일 업로드하기[[legacy-upload-files-with-git-lfs]] 위에서 설명한 모든 방법은 Hub의 API를 사용하여 파일을 업로드하며, 이는 Hub에 파일을 업로드하는 데 권장되는 방법입니다. 이뿐만 아니라 로컬 리포지토리를 관리하기 위하여 git 도구의 래퍼인 [`Repository`]또한 제공합니다. [`Repository`]는 공식적으로 지원 종료된 것은 아니지만, 가급적이면 위에서 설명한 HTTP 기반 방법들을 사용할 것을 권장합니다. 이 권장 사항에 대한 자세한 내용은 HTTP 기반 방식과 Git 기반 방식 간의 핵심적인 차이점을 설명하는 [이 가이드](../concepts/git_vs_http)를 참조하세요. Git LFS는 10MB보다 큰 파일을 자동으로 처리합니다. 하지만 매우 큰 파일(5GB 이상)의 경우, Git LFS용 사용자 지정 전송 에이전트를 설치해야 합니다: ```bash huggingface-cli lfs-enable-largefiles ``` 매우 큰 파일이 있는 각 리포지토리에 대해 이 옵션을 설치해야 합니다. 설치가 완료되면 5GB보다 큰 파일을 푸시할 수 있습니다. ### 커밋 컨텍스트 관리자[[commit-context-manager]] `commit` 컨텍스트 관리자는 가장 일반적인 네 가지 Git 명령인 pull, add, commit, push를 처리합니다. `git-lfs`는 10MB보다 큰 파일을 자동으로 추적합니다. 다음 예제에서는 `commit` 컨텍스트 관리자가 다음과 같은 작업을 수행합니다: 1. `text-files` 리포지토리에서 pull. 2. `file.txt`에 변경 내용을 add. 3. 변경 내용을 commit. 4. 변경 내용을 `text-files` 리포지토리에 push. ```python >>> from huggingface_hub import Repository >>> with Repository(local_dir="text-files", clone_from="/text-files").commit(commit_message="My first file :)"): ... with open("file.txt", "w+") as f: ... 
f.write(json.dumps({"hey": 8})) ``` 다음은 `commit` 컨텍스트 관리자를 사용하여 파일을 저장하고 리포지토리에 업로드하는 방법의 또 다른 예입니다: ```python >>> import torch >>> model = torch.nn.Transformer() >>> with Repository("torch-model", clone_from="/torch-model", token=True).commit(commit_message="My cool model :)"): ... torch.save(model.state_dict(), "model.pt") ``` 커밋을 비동기적으로 푸시하려면 `blocking=False`를 설정하세요. 커밋을 푸시하는 동안 스크립트를 계속 실행하고 싶을 때 논 블로킹 동작이 유용합니다. ```python >>> with repo.commit(commit_message="My cool model :)", blocking=False) ``` `command_queue` 메서드로 푸시 상태를 확인할 수 있습니다: ```python >>> last_command = repo.command_queue[-1] >>> last_command.status ``` 가능한 상태는 아래 표를 참조하세요: | 상태 | 설명 | | -------- | ----------------------------- | | -1 | 푸시가 진행 중입니다. | | 0 | 푸시가 성공적으로 완료되었습니다.| | Non-zero | 오류가 발생했습니다. | `blocking=False`인 경우, 명령이 추적되며 스크립트에서 다른 오류가 발생하더라도 모든 푸시가 완료된 경우에만 스크립트가 종료됩니다. 푸시 상태를 확인하는 데 유용한 몇 가지 추가 명령은 다음과 같습니다: ```python # 오류를 검사합니다. >>> last_command.stderr # 푸시 진행여부를 확인합니다. >>> last_command.is_done # 푸시 명령의 에러여부를 파악합니다. >>> last_command.failed ``` ### push_to_hub[[pushtohub]] [`Repository`] 클래스에는 파일을 추가하고 커밋한 후 리포지토리로 푸시하는 [`~Repository.push_to_hub`] 함수가 있습니다. [`~Repository.push_to_hub`]는 `commit` 컨텍스트 관리자와는 달리 호출하기 전에 먼저 리포지토리에서 업데이트(pull) 작업을 수행 해야 합니다. 예를 들어 Hub에서 리포지토리를 이미 복제했다면 로컬 디렉터리에서 `repo`를 초기화할 수 있습니다: ```python >>> from huggingface_hub import Repository >>> repo = Repository(local_dir="path/to/local/repo") ``` 로컬 클론을 [`~Repository.git_pull`]로 업데이트한 다음 파일을 Hub로 푸시합니다: ```py >>> repo.git_pull() >>> repo.push_to_hub(commit_message="Commit my-awesome-file to the Hub") ``` 그러나 아직 파일을 푸시할 준비가 되지 않았다면 [`~Repository.git_add`] 와 [`~Repository.git_commit`]을 사용하여 파일만 추가하고 커밋할 수 있습니다: ```py >>> repo.git_add("path/to/file") >>> repo.git_commit(commit_message="add my first model config file :)") ``` 준비가 완료되면 [`~Repository.git_push`]를 사용하여 파일을 리포지토리에 푸시합니다: ```py >>> repo.git_push() ``` huggingface_hub-0.31.1/docs/source/ko/guides/webhooks_server.md000066400000000000000000000260371500667546600246020ustar00rootroot00000000000000 # 웹훅 서버[[webhooks-server]] 웹훅은 MLOps 관련 기능의 기반이 됩니다. 이를 통해 특정 저장소의 새로운 변경 사항을 수신하거나, 관심 있는 특정 사용자/조직에 속한 모든 저장소의 변경 사항을 받아볼 수 있습니다. 이 가이드에서는 `huggingface_hub`를 활용하여 웹훅을 수신하는 서버를 만들고 Space에 배포하는 방법을 설명합니다. 이를 위해서는 Huggingface Hub의 웹훅 개념에 대해 익숙해야 합니다. 웹훅 자체에 대해 더 자세히 알아보려면 이 [가이드](https://huggingface.co/docs/hub/webhooks)를 먼저 읽어보세요. 이 가이드에서 사용할 기본 클래스는 [`WebhooksServer`]입니다. 이 클래스는 Huggingface Hub에서 웹훅을 받을 수 있는 서버를 쉽게 구성할 수 있습니다. 서버는 [Gradio](https://gradio.app/) 앱을 기반으로 합니다. 이 서버에는 사용자를 위한 지침을 표시하는 UI와 웹훅을 수신하는 API가 있습니다. 웹훅 서버의 실행 예시를 보려면 [Spaces CI Bot](https://huggingface.co/spaces/spaces-ci-bot/webhook)을 확인하세요. 이것은 Space의 PR이 열릴 때마다 임시 환경을 실행하는 Space입니다. 이것은 [실험적 기능](../package_reference/environment_variables#hfhubdisableexperimentalwarning)입니다. 본 API는 현재 개선 작업 중이며, 향후 사전 통지 없이 주요 변경 사항이 도입될 수 있습니다. requirements에서 `huggingface_hub`의 버전을 고정하는 것을 권장합니다. ## 엔드포인트 생성[[create-an-endpoint]] 웹훅 엔드포인트를 구현하는 것은 함수에 데코레이터를 추가하는 것만큼 간단합니다. 주요 개념을 설명하기 위해 첫 번째 예시를 살펴보겠습니다: ```python # app.py from huggingface_hub import webhook_endpoint, WebhookPayload @webhook_endpoint async def trigger_training(payload: WebhookPayload) -> None: if payload.repo.type == "dataset" and payload.event.action == "update": # 데이터 세트가 업데이트되면 학습 작업을 트리거합니다. ... ``` 이 코드 스니펫을 `'app.py'`라는 파일에 저장하고 `'python app.py'`로 실행하면 다음과 같은 메시지가 표시될 것입니다: ```text Webhook secret is not defined. This means your webhook endpoints will be open to everyone. 
To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: `app = WebhooksServer(webhook_secret='my_secret', ...)` For more details about webhook secrets, please refer to https://huggingface.co/docs/hub/webhooks#webhook-secret. Running on local URL: http://127.0.0.1:7860 Running on public URL: https://1fadb0f52d8bf825fc.gradio.live This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces Webhooks are correctly setup and ready to use: - POST https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training Go to https://huggingface.co/settings/webhooks to setup your webhooks. ``` 축하합니다! 웹훅 서버를 실행했습니다! 정확히 어떤 일이 일어났는지 살펴보겠습니다: 1. [`webhook_endpoint`]로 함수에 데코레이터를 추가하면 백그라운드에서 [`WebhooksServer`] 객체가 생성됩니다. 볼 수 있듯이 이 서버는 http://127.0.0.1:7860 에서 실행되는 Gradio 앱입니다. 이 URL을 브라우저에서 열면 등록된 웹훅에 대한 지침이 있는 랜딩 페이지를 볼 수 있습니다. 2. Gradio 앱은 내부적으로 FastAPI 서버입니다. 새로운 POST 경로 `/webhooks/trigger_training`이 추가되었습니다. 이 경로는 웹훅을 수신하고 트리거될 때 `trigger_training` 함수를 실행합니다. FastAPI는 자동으로 페이로드를 구문 분석하고 [`WebhookPayload`] 객체로 함수에 전달합니다. 이 `pydantic` 객체에는 웹훅을 트리거한 이벤트에 대한 모든 정보가 포함되어 있습니다. 3. Gradio 앱은 인터넷에서 요청을 받을 수 있는 터널도 열었습니다. 이것은 흥미로운 부분으로, https://huggingface.co/settings/webhooks 에서 로컬 머신을 가리키는 웹훅을 구성할 수 있습니다. 이를 통해 웹훅 서버를 디버깅하고 Space에 배포하기 전에 빠르게 반복할 수 있습니다. 4. 마지막으로 로그에는 서버가 현재 비밀로 보호되지 않는다고 알려줍니다. 이것은 로컬 디버깅에는 문제가 되지 않지만 나중에 고려해야 할 사항입니다. 기본적으로 서버는 스크립트 끝에서 시작됩니다. 주피터 노트북에서 실행 중이라면 `decorated_function.run()`을 호출하여 서버를 수동으로 시작할 수 있습니다. 고유한 서버를 사용하기 때문에 여러 엔드포인트가 있더라도 서버를 한 번만 시작하면 됩니다. ## 웹훅 설정하기[[configure-a-webhook]] 웹훅 서버를 실행하고 있으므로, 이제 메시지를 수신하기 위해 웹훅을 구성해야 합니다. https://huggingface.co/settings/webhooks 로 이동하여 "Add a new webhook"을 클릭하고 웹훅을 구성하세요. 모니터링할 대상 저장소와 웹훅 URL `https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training`을 설정하세요.
이걸로 끝입니다! 이제 대상 저장소를 업데이트하면 웹훅을 트리거할 수 있습니다. 예를 들면, 커밋 푸시가 그 방법이 될 수 있습니다. 웹훅의 Activity 탭에서 트리거된 이벤트를 확인할 수 있습니다. 이제 작동하는 구성이 있으므로 테스트하고 빠르게 반복할 수 있습니다. 코드를 수정하고 서버를 다시 시작하면 공개 URL이 변경될 수 있습니다. 필요한 경우 Hub에서 웹훅 구성을 업데이트하세요. ## Space에 배포하기[[deploy-to-a-space]] 이제 작동하는 웹훅 서버가 마련되었으므로, 다음 목표는 이를 Space에 배포하는 것입니다. https://huggingface.co/new-space 에 가서 Space를 생성합니다. 이름을 지정하고, Gradio SDK를 선택한 다음 "Create Space"를 클릭합니다. 코드를 `app.py` 파일로 Space에 업로드합니다. Space가 자동으로 시작됩니다! Space에 대한 자세한 내용은 이 [가이드](https://huggingface.co/docs/hub/spaces-overview)를 참조하세요. 웹훅 서버가 이제 공개 Space에서 실행 중입니다. 대부분의 경우 비밀번호로 보안을 설정하고 싶을 것입니다. Space 설정 > "Repository secrets" 섹션 > "Add a secret" 로 이동합니다. `WEBHOOK_SECRET` 환경 변수에 원하는 값을 설정합니다. [Webhooks 설정](https://huggingface.co/settings/webhooks)으로 돌아가서 웹훅 구성에 비밀번호를 설정합니다. 이제 올바른 비밀번호가 있는 요청만 서버에서 허용됩니다. 이게 전부입니다! Space가 이제 Hub의 웹훅을 수신할 준비가 되었습니다. 무료 하드웨어인 'cpu-basic'에서 Space를 실행 시, 48시간 동안 비활성화되면 종료된다는 점을 유념하세요. 영구적인 Space가 필요한 경우 [업그레이드된 하드웨어](https://huggingface.co/docs/hub/spaces-gpus#hardware-specs)를 설정해야 합니다. ## 고급 사용법[[advanced-usage]] 위의 가이드에서는 [`WebhooksServer`]를 설정하는 가장 빠른 방법에 대해 설명했습니다. 이 섹션에서는 이를 더욱 사용자 정의하는 방법을 살펴보겠습니다. ### 다중 엔드포인트[[multiple-endpoints]] 동일한 서버에 여러 엔드포인트를 등록할 수 있습니다. 예를 들어, 하나의 엔드포인트는 학습 작업을 트리거하고 다른 엔드포인트는 모델 평가를 트리거하도록 할 수 있습니다. 이를 위해 여러 개의 `@webhook_endpoint` 데코레이터를 추가하면 됩니다: ```python # app.py from huggingface_hub import webhook_endpoint, WebhookPayload @webhook_endpoint async def trigger_training(payload: WebhookPayload) -> None: if payload.repo.type == "dataset" and payload.event.action == "update": # 데이터 세트가 업데이트되면 학습 작업을 트리거합니다. ... @webhook_endpoint async def trigger_evaluation(payload: WebhookPayload) -> None: if payload.repo.type == "model" and payload.event.action == "update": # 모델이 업데이트되면 평가 작업을 트리거합니다. ... ``` 이렇게 하면 두 개의 엔드포인트가 생성됩니다: ```text (...) Webhooks are correctly setup and ready to use: - POST https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training - POST https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_evaluation ``` ### 사용자 정의 서버[[custom-server]] 더 많은 유연성을 얻기 위해 [`WebhooksServer`] 객체를 직접 생성할 수도 있습니다. 이것은 서버의 랜딩 페이지를 사용자 정의하고자 할 때 유용합니다. 기본 페이지를 덮어쓸 [Gradio UI](https://gradio.app/docs/#blocks)를 전달하여 이를 수행할 수 있습니다. 예를 들어, 사용자를 위한 지침을 추가하거나 웹훅을 수동으로 트리거하는 양식을 추가할 수 있습니다. [`WebhooksServer`]를 생성할 때, [`~WebhooksServer.add_webhook`] 데코레이터를 사용하여 새로운 웹훅을 등록할 수 있습니다. 전체 예제는 다음과 같습니다: ```python import gradio as gr from fastapi import Request from huggingface_hub import WebhooksServer, WebhookPayload # 1. UI 정의 with gr.Blocks() as ui: ... # 2. 사용자 정의 UI와 시크릿으로 WebhooksServer 생성 app = WebhooksServer(ui=ui, webhook_secret="my_secret_key") # 3. 명시적 이름으로 웹훅 등록 @app.add_webhook("/say_hello") async def hello(payload: WebhookPayload): return {"message": "hello"} # 4. 암시적 이름으로 웹훅 등록 @app.add_webhook async def goodbye(payload: WebhookPayload): return {"message": "goodbye"} # 5. 서버 시작 (선택 사항) app.run() ``` 1. Gradio 블록을 사용하여 사용자 정의 UI를 정의합니다. 이 UI는 서버의 랜딩 페이지에 표시됩니다. 2. 사용자 정의 UI와 시크릿으로 [`WebhooksServer`] 객체를 생성합니다. 시크릿은 선택 사항이며 `WEBHOOK_SECRET` 환경 변수로 설정할 수 있습니다. 3. 명시적 이름으로 웹훅을 등록합니다. 이렇게 하면 `/webhooks/say_hello` 엔드포인트가 생성됩니다. 4. 암시적 이름으로 웹훅을 등록합니다. 이렇게 하면 `/webhooks/goodbye` 엔드포인트가 생성됩니다. 5. 서버를 시작합니다. 이것은 선택 사항이며 스크립트 끝에서 자동으로 서버가 시작됩니다. huggingface_hub-0.31.1/docs/source/ko/index.md000066400000000000000000000103711500667546600212140ustar00rootroot00000000000000 # 🤗 Hub 클라이언트 라이브러리 [[hub-client-library]] `huggingface_hub` 라이브러리는 [Hugging Face Hub](https://hf.co)와 상호작용할 수 있게 해줍니다. Hugging Face Hub는 창작자와 협업자를 위한 머신러닝 플랫폼입니다. 
여러분의 프로젝트에 적합한 사전 훈련된 모델과 데이터셋을 발견하거나, Hub에 호스팅된 수백 개의 머신러닝 앱들을 사용해보세요. 또한, 여러분이 만든 모델과 데이터셋을 커뮤니티와 공유할 수도 있습니다. `huggingface_hub` 라이브러리는 파이썬으로 이 모든 것을 간단하게 할 수 있는 방법을 제공합니다. `huggingface_hub` 라이브러리를 사용하기 위한 [빠른 시작 가이드](quick-start)를 읽어보세요. Hub에서 파일을 다운로드하거나, 레포지토리를 생성하거나, 파일을 업로드하는 방법을 배울 수 있습니다. 계속 읽어보면, 🤗 Hub에서 여러분의 레포지토리를 어떻게 관리하고, 토론에 어떻게 참여하고, 추론 API에 어떻게 접근하는지 알아볼 수 있습니다. ## 기여하기 [[contribute]] `huggingface_hub`에 대한 모든 기여를 환영하며, 소중히 생각합니다! 🤗 코드에서 기존의 이슈를 추가하거나 수정하는 것 외에도, 문서를 정확하고 최신으로 유지하도록 개선하거나, 이슈에 대한 질문에 답하거나, 라이브러리를 개선할 수 있다고 생각하는 새로운 기능을 요청하는 것도 커뮤니티에 도움이 됩니다. 새로운 이슈나 기능 요청을 제출하는 방법, PR을 제출하는 방법, 기여한 내용을 테스트하여 모든 것이 예상대로 작동하는지 확인하는 방법 등에 대해 더 알아보려면 [기여 가이드](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md)를 살펴보세요. 기여자들은 또한 모든 사람들을 위해 포괄적이고 환영받는 협업 공간을 만들기 위해 우리의 [행동 강령](https://github.com/huggingface/huggingface_hub/blob/main/CODE_OF_CONDUCT.md)을 준수해야 합니다. huggingface_hub-0.31.1/docs/source/ko/installation.md000066400000000000000000000172531500667546600226140ustar00rootroot00000000000000 # 설치 방법 [[installation]] 시작하기 전에 적절한 패키지를 설치하여 환경을 설정해야 합니다. `huggingface_hub`는 **Python 3.8+**에서 테스트되었습니다. ## pip로 설치하기 [[install-with-pip]] [가상 환경](https://docs.python.org/3/library/venv.html)에서 `huggingface_hub`를 설치하는 것을 적극 권장합니다. 파이썬 가상 환경에 익숙하지 않다면 이 [가이드](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/)를 참고하세요. 가상 환경을 사용하면 여러 프로젝트를 더 쉽게 관리하고 의존성 간의 호환성 문제를 피할 수 있습니다. 프로젝트 디렉토리에 가상 환경을 생성하는 것으로 시작하세요: ```bash python -m venv .env ``` 가상환경을 활성화하려면 Linux 및 macOS의 경우: ```bash source .env/bin/activate ``` Windows의 경우: ```bash .env/Scripts/activate ``` [PyPi 레지스트리](https://pypi.org/project/huggingface-hub/)에서 `huggingface_hub`를 설치할 준비가 되었습니다: ```bash pip install --upgrade huggingface_hub ``` 완료되면 [설치 확인](#check-installation)이 올바르게 작동하는지 확인합니다. ### 선택 의존성 설치 [[install-optional-dependencies]] `huggingface_hub`의 일부 의존성은 `huggingface_hub`의 핵심 기능을 실행하는 데 필요하지 않으므로 [선택적](https://setuptools.pypa.io/en/latest/userguide/dependency_management.html#optional-dependencies)입니다. 설치가 되어있지 않다면 `huggingface_hub`의 추가적인 기능을 사용하지 못할 수 있습니다. 선택적 의존성은 `pip`을 통해 설치할 수 있습니다: ```bash # TensorFlow 관련 기능에 대한 의존성을 설치합니다. # /!\ 경고: `pip install tensorflow`와 동일하지 않습니다. pip install 'huggingface_hub[tensorflow]' # PyTorch와 CLI와 관련된 기능에 대한 의존성을 모두 설치합니다. pip install 'huggingface_hub[cli,torch]' ``` 다음은 `huggingface_hub`의 선택 의존성 목록입니다: - `cli`: 보다 편리한 `huggingface_hub`의 CLI 인터페이스입니다. - `fastai`, `torch`, `tensorflow`: 프레임워크별 기능을 실행하려면 필요합니다. - `dev`: 라이브러리에 기여하고 싶다면 필요합니다. 테스트 실행을 위한 `testing`, 타입 검사기 실행을 위한 `typing`, 린터 실행을 위한 `quality`가 포함됩니다. ### 소스에서 설치 [[install-from-source]] 경우에 따라 소스에서 직접 `huggingface_hub`를 설치하는 게 더 나을수도 있습니다. 이렇게 하면 최신 릴리스 버전이 아닌 최신 `main` 버전을 사용할 수 있습니다. `main` 버전은 마지막 공식 릴리스 이후 버그가 수정되었지만 아직 새 릴리스가 출시되지 않은 경우와 같이 최신 개발 사항을 들고오는 데 유용합니다. 동시에 `main` 버전은 항상 안정적일 수 없다는 뜻이기도 합니다. 저희는 `main` 버전을 계속 운영하기 위해 노력하고 있으며, 대부분의 문제는 보통 몇 시간 또는 하루 이내에 해결됩니다. 문제가 발생하면 이슈를 열어주시면 더 빨리 해결할 수 있어요! ```bash pip install git+https://github.com/huggingface/huggingface_hub ``` 소스에서 설치할 때 특정 브랜치를 지정할 수도 있습니다. 아직 병합되지 않은 새로운 기능이나 새로운 버그 수정을 테스트하려는 경우에 유용합니다: ```bash pip install git+https://github.com/huggingface/huggingface_hub@my-feature-branch ``` 완료되면 [설치 확인](#check-installation)을 통해 올바르게 작동하는지 확인하세요. ### 편집 가능한 설치 [[editable-install]] 소스에서 설치하면 [편집 가능한 설치](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs)를 설정할 수 있습니다. 이런 고급 설치는 `huggingface_hub`에 기여하고 코드의 변경 사항을 테스트해야 하는 경우에 쓰입니다. 컴퓨터에 `huggingface_hub`의 로컬 복사본을 클론해둬야 합니다. 
```bash # 먼저 로컬에 리포지토리를 복제하세요. git clone https://github.com/huggingface/huggingface_hub.git # 그런 다음 -e 플래그를 사용하여 설치하세요. cd huggingface_hub pip install -e . ``` 이렇게 클론한 레포지토리 폴더와 Python 경로를 연결합니다. 이제 Python은 일반적인 라이브러리 경로 외에도 복제된 폴더 내부를 찾습니다. 예를 들어 파이썬 패키지가 일반적으로 `./.venv/lib/python3.13/site-packages/`에 설치되어 있다면, Python은 복제된 폴더 `./huggingface_hub/`도 검색하게 됩니다. ## conda로 설치하기 [[install-with-conda]] 이미 익숙하다면 [conda-forge 채널](https://anaconda.org/conda-forge/huggingface_hub)를 통해 `huggingface_hub`를 설치할 수도 있습니다: ```bash conda install -c conda-forge huggingface_hub ``` 완료되면 [설치 확인](#check-installation)을 통해 올바르게 작동하는지 확인하세요. ## 설치 확인 [[check-installation]] 설치가 완료되면 다음 명령을 실행하여 `huggingface_hub`가 제대로 작동하는지 확인하세요: ```bash python -c "from huggingface_hub import model_info; print(model_info('gpt2'))" ``` 이 명령은 Hub에서 [gpt2](https://huggingface.co/gpt2) 모델에 대한 정보를 가져옵니다. 출력은 다음과 같아야 합니다: ```text Model Name: gpt2 Tags: ['pytorch', 'tf', 'jax', 'tflite', 'rust', 'safetensors', 'gpt2', 'text-generation', 'en', 'doi:10.57967/hf/0039', 'transformers', 'exbert', 'license:mit', 'has_space'] Task: text-generation ``` ## Windows 제한 사항 [[windows-limitations]] 좋은 ML을 어디서나 사용할 수 있게 하자는 목표 아래, `huggingface_hub`를 크로스 플랫폼 라이브러리로 만들었으며, 특히 유닉스 기반과 Windows 시스템 모두에서 잘 작동하도록 했습니다. 그럼에도 `huggingface_hub`를 Windows에서 실행할 때 몇 가지 제한이 있습니다. 다음은 알려진 문제의 전체 목록입니다. 문서화되지 않은 문제가 발생하면 [GitHub에 이슈](https://github.com/huggingface/huggingface_hub/issues/new/choose)를 열어서 알려주시기 바랍니다. - `huggingface_hub`의 캐시 시스템은 Hub에서 다운로드한 파일을 효율적으로 캐시하기 위해 심볼릭 링크에 의존합니다. Windows에서는 개발자 모드를 활성화하거나 관리자 권한으로 스크립트를 실행해야 심볼릭 링크를 활성화할 수 있습니다. 활성화하지 않으면 캐시 시스템이 계속 작동하지만 최적화되지 않은 방식으로 작동합니다. 자세한 내용은 [캐시 제한](./guides/manage-cache#limitations) 섹션을 참조하세요. - Hub의 파일 경로에는 특수 문자를 사용할 수 있습니다(예: `"path/to?/my/file"`). 드문 경우이길 바라지만, Windows는 [특수 문자](https://learn.microsoft.com/en-us/windows/win32/intl/character-sets-used-in-file-names)에 대한 제한이 더 엄격하기 때문에 해당 파일을 다운로드할 수 없습니다. 실수라고 생각되면 레포지토리 소유자에게 문의하시거나 해결책을 찾기 위해 저희에게 연락해 주세요. ## 다음 단계 [[next-steps]] 컴퓨터에 `huggingface_hub`가 제대로 설치되면 [환경 변수를 설정](package_reference/environment_variables)하거나 [가이드 중 하나를 골라](guides/overview) 시작할 수 있습니다. huggingface_hub-0.31.1/docs/source/ko/package_reference/000077500000000000000000000000001500667546600231725ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/ko/package_reference/cache.md000066400000000000000000000034301500667546600245570ustar00rootroot00000000000000 # 캐시 시스템 참조[[cache-system-reference]] 버전 0.8.0에서의 업데이트로, 캐시 시스템은 Hub에 의존하는 라이브러리 전체에서 공유되는 중앙 캐시 시스템으로 발전하였습니다. Hugging Face 캐싱에 대한 자세한 설명은 [캐시 시스템 가이드](../guides/manage-cache)를 참조하세요. ## 도우미 함수[[helpers]] ### try_to_load_from_cache[[huggingface_hub.try_to_load_from_cache]] [[autodoc]] huggingface_hub.try_to_load_from_cache ### cached_assets_path[[huggingface_hub.cached_assets_path]] [[autodoc]] huggingface_hub.cached_assets_path ### scan_cache_dir[[huggingface_hub.scan_cache_dir]] [[autodoc]] huggingface_hub.scan_cache_dir ## 데이터 구조[[data-structures]] 모든 구조체는 [`scan_cache_dir`]에 의해 생성되고 반환되며, 불변(immutable)입니다. 
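아래는 이러한 구조체가 실제로 어떻게 반환되는지 감을 잡기 위한 최소한의 스케치입니다. 사용된 속성은 이 페이지에 문서화된 것들이며, 출력 값은 로컬 캐시 상태에 따라 달라집니다.

```python
from huggingface_hub import scan_cache_dir

cache_info = scan_cache_dir()  # 불변 구조체인 HFCacheInfo를 반환합니다.

for repo in cache_info.repos:  # CachedRepoInfo의 집합
    print(repo.repo_id, repo.repo_type, repo.size_on_disk_str)
    for revision in repo.revisions:  # CachedRevisionInfo의 집합
        print("  ", revision.commit_hash, revision.size_on_disk_str, revision.nb_files)
```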
### HFCacheInfo[[huggingface_hub.HFCacheInfo]] [[autodoc]] huggingface_hub.HFCacheInfo ### CachedRepoInfo[[huggingface_hub.CachedRepoInfo]] [[autodoc]] huggingface_hub.CachedRepoInfo - size_on_disk_str - refs ### CachedRevisionInfo[[huggingface_hub.CachedRevisionInfo]] [[autodoc]] huggingface_hub.CachedRevisionInfo - size_on_disk_str - nb_files ### CachedFileInfo[[huggingface_hub.CachedFileInfo]] [[autodoc]] huggingface_hub.CachedFileInfo - size_on_disk_str ### DeleteCacheStrategy[[huggingface_hub.DeleteCacheStrategy]] [[autodoc]] huggingface_hub.DeleteCacheStrategy - expected_freed_size_str ## 예외[[exceptions]] ### CorruptedCacheException[[huggingface_hub.CorruptedCacheException]] [[autodoc]] huggingface_hub.CorruptedCacheException huggingface_hub-0.31.1/docs/source/ko/package_reference/cards.md000066400000000000000000000044601500667546600246140ustar00rootroot00000000000000# 리포지토리 카드[[repository-cards]] huggingface_hub 라이브러리는 모델/데이터 세트 카드를 생성, 공유 및 업데이트하기 위한 Python 인터페이스를 제공합니다. Hub의 모델 카드가 무엇이며 내부적으로 어떻게 작동하는지 더 깊이 있게 알아보려면 [전용 문서 페이지](https://huggingface.co/docs/hub/models-cards)를 방문하세요. 또한 이러한 유틸리티를 자신의 프로젝트에서 어떻게 사용할 수 있는지 감을 잡기 위해 [모델 카드 가이드](../how-to-model-cards)를 확인할 수 있습니다. ## 리포지토리 카드[[huggingface_hub.RepoCard]] `RepoCard` 객체는 [`ModelCard`], [`DatasetCard`] 및 `SpaceCard`의 상위 클래스입니다. [[autodoc]] huggingface_hub.repocard.RepoCard - __init__ - all ## 카드 데이터[[huggingface_hub.CardData]] [`CardData`] 객체는 [`ModelCardData`]와 [`DatasetCardData`]의 상위 클래스입니다. [[autodoc]] huggingface_hub.repocard_data.CardData ## 모델 카드[[model-cards]] ### ModelCard[[huggingface_hub.ModelCard]] [[autodoc]] ModelCard ### ModelCardData[[huggingface_hub.ModelCardData]] [[autodoc]] ModelCardData ## 데이터 세트 카드[[cards#dataset-cards]] ML 커뮤니티에서는 데이터 세트 카드를 데이터 카드라고도 합니다. ### DatasetCard[[huggingface_hub.DatasetCard]] [[autodoc]] DatasetCard ### DatasetCardData[[huggingface_hub.DatasetCardData]] [[autodoc]] DatasetCardData ## 공간 카드[[space-cards]] ### SpaceCard[[huggingface_hub.SpaceCardData]] [[autodoc]] SpaceCard ### SpaceCardData[[huggingface_hub.SpaceCardData]] [[autodoc]] SpaceCardData ## 유틸리티[[utilities]] ### EvalResult[[huggingface_hub.EvalResult]] [[autodoc]] EvalResult ### model_index_to_eval_results[[huggingface_hub.repocard_data.model_index_to_eval_results]] [[autodoc]] huggingface_hub.repocard_data.model_index_to_eval_results ### eval_results_to_model_index[[huggingface_hub.repocard_data.eval_results_to_model_index]] [[autodoc]] huggingface_hub.repocard_data.eval_results_to_model_index ### metadata_eval_result[[huggingface_hub.metadata_eval_result]] [[autodoc]] huggingface_hub.repocard.metadata_eval_result ### metadata_update[[huggingface_hub.metadata_update]] [[autodoc]] huggingface_hub.repocard.metadata_updatehuggingface_hub-0.31.1/docs/source/ko/package_reference/collections.md000066400000000000000000000015421500667546600260340ustar00rootroot00000000000000 # 컬렉션 관리[[managing-collections]] Hub에서 Space를 관리하는 메소드에 대한 자세한 설명은 [`HfApi`] 페이지를 확인하세요. 
- 컬렉션 내용 가져오기: [`get_collection`] - 새로운 컬렉션 생성: [`create_collection`] - 컬렉션 업데이트: [`update_collection_metadata`] - 컬렉션 삭제: [`delete_collection`] - 컬렉션에 항목 추가: [`add_collection_item`] - 컬렉션의 항목 업데이트: [`update_collection_item`] - 컬렉션에서 항목 제거: [`delete_collection_item`] ### Collection[[huggingface_hub.Collection]] [[autodoc]] Collection ### CollectionItem[[huggingface_hub.CollectionItem]] [[autodoc]] CollectionItem huggingface_hub-0.31.1/docs/source/ko/package_reference/community.md000066400000000000000000000017251500667546600255450ustar00rootroot00000000000000 # Discussions 및 Pull Requests를 이용하여 상호작용하기[[interacting-with-discussions-and-pull-requests]] Hub에서 Discussions 및 Pull Requests를 이용하여 상호 작용할 수 있는 방법에 대해 참조하고자 한다면 [`HfApi`] 문서 페이지를 확인하세요. - [`get_repo_discussions`] - [`get_discussion_details`] - [`create_discussion`] - [`create_pull_request`] - [`rename_discussion`] - [`comment_discussion`] - [`edit_discussion_comment`] - [`change_discussion_status`] - [`merge_pull_request`] ## 데이터 구조[[huggingface_hub.Discussion]] [[autodoc]] Discussion [[autodoc]] DiscussionWithDetails [[autodoc]] DiscussionEvent [[autodoc]] DiscussionComment [[autodoc]] DiscussionStatusChange [[autodoc]] DiscussionCommit [[autodoc]] DiscussionTitleChange huggingface_hub-0.31.1/docs/source/ko/package_reference/environment_variables.md000066400000000000000000000303441500667546600301140ustar00rootroot00000000000000 # 환경 변수[[environment-variables]] `huggingface_hub`는 환경 변수를 사용해 설정할 수 있습니다. 환경 변수에 대해 잘 알지 못하다면 그에 대한 문서인 [macOS and Linux](https://linuxize.com/post/how-to-set-and-list-environment-variables-in-linux/)와 [Windows](https://phoenixnap.com/kb/windows-set-environment-variable)를 참고하세요. 이 문서에서는 `huggingface_hub`와 관련된 모든 환경 변수와 그 의미에 대해 안내합니다. ## 일반적인 변수[[generic]] ### HF_INFERENCE_ENDPOINT[[hfinferenceendpoint]] 추론 API 기본 URL을 구성합니다. 조직에서 추론 API를 직접 가리키는 것이 아니라 API 게이트웨이를 가리키는 경우 이 변수를 설정할 수 있습니다. 기본값은 `"https://api-inference.huggingface.co"`입니다. ### HF_HOME[[hfhome]] `huggingface_hub`가 어디에 데이터를 로컬로 저장할 지 위치를 구성합니다. 특히 토큰과 캐시가 이 폴더에 저장됩니다. [XDG_CACHE_HOME](#xdgcachehome)이 설정되어 있지 않다면, 기본값은 `"~/.cache/huggingface"`입니다. ### HF_HUB_CACHE[[hfhubcache]] Hub의 리포지토리가 로컬로 캐시될 위치(모델, 데이터세트 및 스페이스)를 구성합니다. 기본값은 `"$HF_HOME/hub"` (예로 들면, 기본 설정은 `"~/.cache/huggingface/hub"`)입니다. ### HF_ASSETS_CACHE[[hfassetscache]] 다운스트림 라이브러리에서 생성된 [assets](../guides/manage-cache#caching-assets)가 로컬로 캐시되는 위치를 구성합니다. 이 assets은 전처리된 데이터, GitHub에서 다운로드한 파일, 로그, ... 등이 될 수 있습니다. 기본값은 `"$HF_HOME/assets"` (예로 들면, 기본 설정은 `"~/.cache/huggingface/assets"`)입니다. ### HF_TOKEN[[hftoken]] Hub에 인증하기 위한 사용자 액세스 토큰을 구성합니다. 이 값을 설정하면 머신에 저장된 토큰을 덮어씁니다(`$HF_TOKEN_PATH`, 또는 `$HF_TOKEN_PATH`가 설정되지 않은 경우 `"$HF_HOME/token"`에 저장됨). 인증에 대한 자세한 내용은 [이 섹션](../quick-start#인증)을 참조하세요. ### HF_TOKEN_PATH[[hftokenpath]] `huggingface_hub`가 사용자 액세스 토큰(User Access Token)을 저장할 위치를 구성합니다. 기본값은 `"$HF_HOME/token"`(예로 들면, 기본 설정은 `~/.cache/huggingface/token`)입니다. ### HF_HUB_VERBOSITY[[hfhubverbosity]] `huggingface_hub`의 로거(logger)의 상세도 수준(verbosity level)을 설정합니다. 다음 중 하나여야 합니다. `{"debug", "info", "warning", "error", "critical"}` 중 하나여야 합니다. 기본값은 `"warning"`입니다. 더 자세한 정보를 알아보고 싶다면, [logging reference](../package_reference/utilities#huggingface_hub.utils.logging.get_verbosity)를 살펴보세요. ### HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD[[hfhublocaldirautosymlinkthreshold]] 이 환경 변수는 더 이상 사용되지 않으며 이제 `huggingface_hub`에서 무시됩니다. 로컬 디렉터리로 파일을 다운로드할 때 더 이상 심볼릭 링크에 의존하지 않습니다. ### HF_HUB_ETAG_TIMEOUT[[hfhubetagtimeout]] 파일을 다운로드하기 전에 리포지토리에서 최신 메타데이터를 가져올 때 서버 응답을 기다리는 시간(초)을 정의하는 정수 값입니다. 
요청 시간이 초과되면 `huggingface_hub`는 기본적으로 로컬에 캐시된 파일을 사용합니다. 값을 낮게 설정하면 이미 파일을 캐시한 연결 속도가 느린 컴퓨터의 워크플로 속도가 빨라집니다. 값이 클수록 더 많은 경우에서 메타데이터 호출이 성공할 수 있습니다. 기본값은 10초입니다. ### HF_HUB_DOWNLOAD_TIMEOUT[[hfhubdownloadtimeout]] 파일을 다운로드할 때 서버 응답을 기다리는 시간(초)을 정의하는 정수 값입니다. 요청 시간이 초과되면 TimeoutError가 발생합니다. 연결 속도가 느린 컴퓨터에서는 값을 높게 설정하는 것이 좋습니다. 값이 작을수록 네트워크가 완전히 중단된 경우에 프로세스가 더 빨리 실패합니다. 기본값은 10초입니다. ## 불리언 값[[boolean-values]] 다음 환경 변수는 불리언 값을 요구합니다. 변수는 값이 `{"1", "ON", "YES", "TRUE"}`(대소문자 구분 없음) 중 하나이면 `True`로 간주합니다. 다른 값(또는 정의되지 않음)은 `False`로 간주됩니다. ### HF_HUB_OFFLINE[[hfhuboffline]] 이 옵션을 설정하면 Hugging Face Hub에 HTTP 호출이 이루어지지 않습니다. 파일을 다운로드하려고 하면 캐시된 파일만 액세스됩니다. 캐시 파일이 감지되지 않으면 오류를 발생합니다. 네트워크 속도가 느리고 파일의 최신 버전이 중요하지 않은 경우에 유용합니다. 환경 변수로 `HF_HUB_OFFLINE=1`이 설정되어 있고 [`HfApi`]의 메소드를 호출하면 [`~huggingface_hub.utils.OfflineModeIsEnabled`] 예외가 발생합니다. **참고:** 최신 버전의 파일이 캐시되어 있더라도 `hf_hub_download`를 호출하면 새 버전을 사용할 수 없는지 확인하기 위해 HTTP 요청이 발생합니다. `HF_HUB_OFFLINE=1`을 설정하면 이 호출을 건너뛰어 로딩 시간이 빨라집니다. ### HF_HUB_DISABLE_IMPLICIT_TOKEN[[hfhubdisableimplicittoken]] Hub에 대한 모든 요청이 반드시 인증을 필요로 하는 것은 아닙니다. 예를 들어 `"gpt2"` 모델에 대한 세부 정보를 요청하는 경우에는 인증이 필요하지 않습니다. 그러나 사용자가 [로그인](../package_reference/login) 상태인 경우, 기본 동작은 사용자 경험을 편하게 하기 위해 비공개 또는 게이트 리포지토리에 액세스할 때 항상 토큰을 전송하는 것(HTTP 401 권한 없음이 표시되지 않음)입니다. 개인 정보 보호를 위해 `HF_HUB_DISABLE_IMPLICIT_TOKEN=1`로 설정하여 이 동작을 비활성화할 수 있습니다. 이 경우 토큰은 "쓰기 권한" 호출(예: 커밋 생성)에만 전송됩니다. **참고:** 토큰을 항상 전송하는 것을 비활성화하면 이상한 부작용이 발생할 수 있습니다. 예를 들어 Hub에 모든 모델을 나열하려는 경우 당신의 비공개 모델은 나열되지 않습니다. 사용자 스크립트에 명시적으로 `token=True` 인수를 전달해야 합니다. ### HF_HUB_DISABLE_PROGRESS_BARS[[hfhubdisableprogressbars]] 시간이 오래 걸리는 작업의 경우 `huggingface_hub`는 기본적으로 진행률 표시줄을 표시합니다(tqdm 사용). 모든 진행률 표시줄을 한 번에 비활성화하려면 `HF_HUB_DISABLE_PROGRESS_BARS=1`으로 설정하면 됩니다. ### HF_HUB_DISABLE_SYMLINKS_WARNING[[hfhubdisablesymlinkswarning]] Windows 머신을 사용하는 경우 개발자 모드를 활성화하거나 관리자 모드에서 `huggingface_hub`를 관리자 모드로 실행하는 것이 좋습니다. 그렇지 않은 경우 `huggingface_hub`가 캐시 시스템에 심볼릭 링크를 생성할 수 없습니다. 모든 스크립트를 실행할 수 있지만 일부 대용량 파일이 하드 드라이브에 중복될 수 있으므로 사용자 경힘이 저하될 수 있습니다. 이 동작을 경고하기 위해 경고 메시지가 나타납니다. `HF_HUB_DISABLE_SYMLINKS_WARNING=1`로 설정하면 이 경고를 비활성화할 수 있습니다. 자세한 내용은 [캐시 제한](../guides/manage-cache#limitations)을 참조하세요. ### HF_HUB_DISABLE_EXPERIMENTAL_WARNING[[hfhubdisableexperimentalwarning]] `huggingface_hub`의 일부 기능은 실험 단계입니다. 즉, 사용은 가능하지만 향후 유지될 것이라고 보장할 수는 없습니다. 특히 이러한 기능의 API나 동작은 지원 중단 없이 업데이트될 수 있습니다. 실험적 기능을 사용할 때는 이에 대한 경고를 위해 경고 메시지가 나타납니다. 실험적 기능을 사용하여 잠재적인 문제를 디버깅하는 것이 편하다면 `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1`으로 설정하여 경고를 비활성화할 수 있습니다. 실험적인 기능을 사용 중이라면 알려주세요! 여러분의 피드백은 기능을 설계하고 개선하는 데 도움이 됩니다. ### HF_HUB_DISABLE_TELEMETRY[[hfhubdisabletelemetry]] 기본적으로 일부 데이터는 사용량을 모니터링하고 문제를 디버그하며 기능의 우선순위를 정하는 데 도움을 주기 위해 HF 라이브러리(`transformers`, `datasets`, `gradio`,...)에서 수집합니다. 각 라이브러리는 자체 정책(즉, 모니터링할 사용량)을 정의하지만 핵심 구현은 `huggingface_hub`에서 이루어집니다([`send_telemetry`] 참조). 환경 변수로 `HF_HUB_DISABLE_TELEMETRY=1`을 설정하여 원격 측정을 전역적으로 비활성화할 수 있습니다. ### HF_HUB_ENABLE_HF_TRANSFER[[hfhubenablehftransfer]] Hub에서 `hf_transfer`를 사용하여 더 빠르게 업로드 및 다운로드하려면 `True`로 설정하세요. 기본적으로 `huggingface_hub`는 파이썬 기반 `requests.get` 및 `requests.post` 함수를 사용합니다. 이 함수들은 안정적이고 다용도로 사용할 수 있지만 대역폭이 높은 머신에서는 가장 효율적인 선택이 아닐 수 있습니다. [`hf_transfer`](https://github.com/huggingface/hf_transfer)는 대용량 파일을 작은 부분으로 분할하여 사용 대역폭을 최대화하고 여러 스레드를 사용하여 동시에 전송함으로써 대역폭을 최대화하기 위해 개발된 Rust 기반 패키지입니다. 이 접근 방식은 전송 속도를 거의 두 배로 높일 수 있습니다. `hf_transfer`를 사용하려면: 1. `huggingface_hub`를 설치할 때 `hf_transfer`를 추가로 지정합니다. (예시: `pip install huggingface_hub[hf_transfer]`). 2. 환경 변수로 `HF_HUB_ENABLE_HF_TRANSFER=1`을 설정합니다. 
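예를 들어 Linux/macOS 셸에서는 위의 두 단계를 다음과 같이 적용할 수 있습니다. 마지막 줄의 리포지토리와 파일 이름은 단순한 예시입니다:

```bash
pip install "huggingface_hub[hf_transfer]"
export HF_HUB_ENABLE_HF_TRANSFER=1
# 이후의 다운로드/업로드는 hf_transfer를 통해 수행됩니다. 예:
huggingface-cli download gpt2 config.json
```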
`hf_transfer`를 사용하면 특정 제한 사항이 있다는 점에 유의하세요. 순수 파이썬 기반이 아니므로 오류 디버깅이 어려울 수 있습니다. 또한 `hf_transfer`에는 다운로드 재개 및 프록시와 같은 몇 가지 사용자 친화적인 기능이 없습니다. 이런 부족한 부분은 Rust 로직의 단순성과 속도를 유지하기 위해 의도한 것입니다. 이런 이유들로, `hf_transfer`는 `huggingface_hub`에서 기본적으로 활성화되지 않습니다. ## 사용되지 않는 환경 변수[[deprecated-environment-variables]] Hugging Face 생태계의 모든 환경 변수를 표준화하기 위해 일부 변수는 사용되지 않는 것으로 표시되었습니다. 해당 변수는 여전히 작동하지만 더 이상 대체한 변수보다 우선하지 않습니다. 다음 표에는 사용되지 않는 변수와 해당 대체 변수가 간략하게 설명되어 있습니다: | 사용되지 않는 변수 | 대체 변수 | | --- | --- | | `HUGGINGFACE_HUB_CACHE` | `HF_HUB_CACHE` | | `HUGGINGFACE_ASSETS_CACHE` | `HF_ASSETS_CACHE` | | `HUGGING_FACE_HUB_TOKEN` | `HF_TOKEN` | | `HUGGINGFACE_HUB_VERBOSITY` | `HF_HUB_VERBOSITY` | ## 외부 도구[[from-external-tools]] 일부 환경 변수는 `huggingface_hub`에만 특정되지는 않지만 설정 시 함께 고려됩니다. ### DO_NOT_TRACK[[donottrack]] 불리언 값입니다. `hf_hub_disable_telemetry`에 해당합니다. True로 설정하면 Hugging Face Python 생태계(`transformers`, `diffusers`, `gradio` 등)에서 원격 측정이 전역적으로 비활성화됩니다. 자세한 내용은 https://consoledonottrack.com/ 을 참조하세요. ### NO_COLOR[[nocolor]] 불리언 값입니다. 이 값을 설정하면 `huggingface-cli` 도구는 ANSI 색상을 출력하지 않습니다. [no-color.org](https://no-color.org/)를 참조하세요. ### XDG_CACHE_HOME[[xdgcachehome]] `HF_HOME`이 설정되지 않은 경우에만 사용합니다! 이것은 Linux 시스템에서 [사용자별 비필수(캐시된) 데이터](https://wiki.archlinux.org/title/XDG_Base_Directory)가 쓰여져야 하는 위치를 구성하는 기본 방법입니다. `HF_HOME`이 설정되지 않은 경우 기본 홈은 `"~/.cache/huggingface"`대신 `"$XDG_CACHE_HOME/huggingface"`가 됩니다. huggingface_hub-0.31.1/docs/source/ko/package_reference/file_download.md000066400000000000000000000025351500667546600263270ustar00rootroot00000000000000 # 파일 다운로드 하기[[downloading-files]] ## 단일 파일 다운로드하기[[download-a-single-file]] ### hf_hub_download[[huggingface_hub.hf_hub_download]] [[autodoc]]huggingface_hub.hf_hub_download ### hf_hub_url[[huggingface_hub.hf_hub_url]] [[autodoc]]huggingface_hub.hf_hub_url ## 리포지토리의 스냅샷 다운로드하기[[huggingface_hub.snapshot_download]] [[autodoc]]huggingface_hub.snapshot_download ## 파일에 대한 메타데이터 가져오기[[get-metadata-about-a-file]] ### get_hf_file_metadata[[huggingface_hub.get_hf_file_metadata]] [[autodoc]]huggingface_hub.get_hf_file_metadata ### HfFileMetadata[[huggingface_hub.HfFileMetadata]] [[autodoc]]huggingface_hub.HfFileMetadata ## 캐싱[[caching]] 위에 나열된 메소드들은 파일을 재다운로드하지 않도록 하는 캐싱 시스템과 함께 작동하도록 설계되었습니다. v0.8.0에서의 업데이트로, 캐싱 시스템은 Hub를 기반으로 하는 다양한 라이브러리 간의 공유 중앙 캐시 시스템으로 발전했습니다. Hugging Face에서의 캐싱에 대한 자세한 설명은[캐시 시스템 가이드](../guides/manage-cache)를 참조하세요. huggingface_hub-0.31.1/docs/source/ko/package_reference/hf_api.md000066400000000000000000000064121500667546600247450ustar00rootroot00000000000000 # HfApi Client[[hfapi-client]] 아래는 허깅 페이스 Hub의 API를 위한 파이썬 래퍼인 `HfApi` 클래스에 대한 문서입니다. `HfApi`의 모든 메서드는 패키지의 루트에서 직접 접근할 수 있습니다. 두 접근 방식은 아래에서 자세히 설명합니다. 루트 메서드를 사용하는 것이 더 간단하지만 [`HfApi`] 클래스를 사용하면 더 유연하게 사용할 수 있습니다. 특히 모든 HTTP 호출에서 재사용할 토큰을 전달할 수 있습니다. 이 방식은 토큰이 머신에 유지되지 않기 때문에 `huggingface-cli login` 또는 [`login`]를 사용하는 방식과는 다르며, 다른 엔드포인트를 제공하거나 사용자정의 에이전트를 구성할 수도 있습니다. ```python from huggingface_hub import HfApi, list_models # 루트 메서드를 사용하세요. models = list_models() # 또는 HfApi client를 구성하세요. hf_api = HfApi( endpoint="https://huggingface.co", # 비공개 Hub 엔드포인트를 지정할 수 있습니다. token="hf_xxx", # 토큰은 머신에 유지되지 않습니다. 
) models = hf_api.list_models() ``` ## HfApi[[huggingface_hub.HfApi]] [[autodoc]] HfApi ## API Dataclasses[[api-dataclasses]] ### AccessRequest[[huggingface_hub.hf_api.AccessRequest]] [[autodoc]] huggingface_hub.hf_api.AccessRequest ### CommitInfo[[huggingface_hub.CommitInfo]] [[autodoc]] huggingface_hub.hf_api.CommitInfo ### DatasetInfo[[huggingface_hub.hf_api.DatasetInfo]] [[autodoc]] huggingface_hub.hf_api.DatasetInfo ### GitRefInfo[[huggingface_hub.GitRefInfo]] [[autodoc]] huggingface_hub.hf_api.GitRefInfo ### GitCommitInfo[[huggingface_hub.GitCommitInfo]] [[autodoc]] huggingface_hub.hf_api.GitCommitInfo ### GitRefs[[huggingface_hub.GitRefs]] [[autodoc]] huggingface_hub.hf_api.GitRefs ### ModelInfo[[huggingface_hub.hf_api.ModelInfo]] [[autodoc]] huggingface_hub.hf_api.ModelInfo ### RepoSibling[[huggingface_hub.hf_api.RepoSibling]] [[autodoc]] huggingface_hub.hf_api.RepoSibling ### RepoFile[[huggingface_hub.hf_api.RepoFile]] [[autodoc]] huggingface_hub.hf_api.RepoFile ### RepoUrl[[huggingface_hub.RepoUrl]] [[autodoc]] huggingface_hub.hf_api.RepoUrl ### SafetensorsRepoMetadata[[huggingface_hub.utils.SafetensorsRepoMetadata]] [[autodoc]] huggingface_hub.utils.SafetensorsRepoMetadata ### SafetensorsFileMetadata[[huggingface_hub.utils.SafetensorsFileMetadata]] [[autodoc]] huggingface_hub.utils.SafetensorsFileMetadata ### SpaceInfo[[huggingface_hub.hf_api.SpaceInfo]] [[autodoc]] huggingface_hub.hf_api.SpaceInfo ### TensorInfo[[huggingface_hub.utils.TensorInfo]] [[autodoc]] huggingface_hub.utils.TensorInfo ### User[[huggingface_hub.User]] [[autodoc]] huggingface_hub.hf_api.User ### UserLikes[[huggingface_hub.UserLikes]] [[autodoc]] huggingface_hub.hf_api.UserLikes ## CommitOperation[[huggingface_hub.CommitOperationAdd]] [`CommitOperation`]에 지원되는 값은 다음과 같습니다: [[autodoc]] CommitOperationAdd [[autodoc]] CommitOperationDelete [[autodoc]] CommitOperationCopy ## CommitScheduler[[huggingface_hub.CommitScheduler]] [[autodoc]] CommitScheduler huggingface_hub-0.31.1/docs/source/ko/package_reference/hf_file_system.md000066400000000000000000000015371500667546600265220ustar00rootroot00000000000000 # 파일 시스템 API[[filesystem-api]] [`HfFileSystem`] 클래스는 [`fsspec`](https://filesystem-spec.readthedocs.io/en/latest/)을 기반으로 Hugging Face Hub에 Python 파일 인터페이스를 제공합니다. ## [HfFileSystem](Hf파일시스템) [`HfFileSystem`]은 [`fsspec`](https://filesystem-spec.readthedocs.io/en/latest/)을 기반으로 하므로 제공되는 대부분의 API와 호환됩니다. 자세한 내용은 [가이드](../guides/hf_file_system) 및 fsspec의 [API 레퍼런스](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem)를 확인하세요. [[autodoc]] HfFileSystem - __init__ - resolve_path - ls huggingface_hub-0.31.1/docs/source/ko/package_reference/inference_client.md000066400000000000000000000061361500667546600270160ustar00rootroot00000000000000 # 추론[[inference]] 추론은 학습된 모델을 사용하여 새로운 데이터를 예측하는 과정입니다. 이 과정은 계산량이 많을 수 있기 때문에, 전용 서버에서 실행하는 것이 흥미로운 옵션이 될 수 있습니다. `huggingface_hub` 라이브러리는 호스팅된 모델에 대한 추론을 실행하는 간단한 방법을 제공합니다. 연결할 수 있는 서비스는 여러가지가 있습니다: - [추론 API](https://huggingface.co/docs/api-inference/index): Hugging Face의 인프라에서 가속화된 추론을 무료로 실행할 수 있는 서비스입니다. 이 서비스는 시작하기 위한 빠른 방법이며, 다양한 모델을 테스트하고 AI 제품을 프로토타입화하는 데에도 유용합니다. - [추론 엔드포인트](https://huggingface.co/inference-endpoints): 모델을 쉽게 운영 환경으로 배포할 수 있는 제품입니다. 추론은 여러분이 선택한 클라우드 제공업체의 전용 및 완전히 관리되는 인프라에서 Hugging Face에 의해 실행됩니다. 이러한 서비스는 [`InferenceClient`] 객체를 사용하여 호출할 수 있습니다. 자세한 사용 방법에 대해서는 [이 가이드](../guides/inference)를 참조해주세요. 
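시작에 참고할 수 있도록, 아래는 [`InferenceClient`]로 호스팅된 모델을 호출하는 최소한의 예시입니다. 사용된 모델 ID는 예시일 뿐이며, 실제 사용 가능 여부는 시점과 환경에 따라 달라질 수 있습니다.

```python
from huggingface_hub import InferenceClient

client = InferenceClient()  # 기본적으로 무료 추론 API를 사용합니다.

response = client.chat_completion(
    messages=[{"role": "user", "content": "Hugging Face Hub란 무엇인가요?"}],
    model="HuggingFaceH4/zephyr-7b-beta",  # 예시 모델
    max_tokens=100,
)
print(response.choices[0].message.content)
```

추론 엔드포인트에 배포된 모델을 호출하려면 `model` 매개변수에 모델 ID 대신 해당 엔드포인트의 URL을 전달하면 됩니다.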
## 추론 클라이언트[[huggingface_hub.InferenceClient]] [[autodoc]] InferenceClient ## 비동기 추론 클라이언트[[huggingface_hub.AsyncInferenceClient]] 비동기 버전의 클라이언트도 제공되며, 이는 `asyncio`와 `aiohttp`를 기반으로 작동합니다. 이를 사용하려면 `aiohttp`를 직접 설치하거나 `[inference]` 추가 기능을 사용할 수 있습니다: ```sh pip install --upgrade huggingface_hub[inference] # 또는 # pip install aiohttp ``` [[autodoc]] AsyncInferenceClient ## 추론 시간 초과 오류[[huggingface_hub.InferenceTimeoutError]] [[autodoc]] InferenceTimeoutError ## 반환 유형[[return-types]] 대부분의 작업에 대해, 반환 값은 내장된 유형(string, list, image...)을 갖습니다. 보다 복잡한 유형을 위한 목록은 다음과 같습니다. ### 모델 상태[[huggingface_hub.inference._common.ModelStatus]] [[autodoc]] huggingface_hub.inference._common.ModelStatus ## 추론 API[[huggingface_hub.InferenceApi]] [`InferenceAPI`]는 추론 API를 호출하는 레거시 방식입니다. 이 인터페이스는 더 간단하며 각 작업의 입력 매개변수와 출력 형식을 알아야 합니다. 또한 추론 엔드포인트나 AWS SageMaker와 같은 다른 서비스에 연결할 수 있는 기능이 없습니다. [`InferenceAPI`]는 곧 폐지될 예정이므로 가능한 경우 [`InferenceClient`]를 사용하는 것을 권장합니다. 스크립트에서 [`InferenceAPI`]를 [`InferenceClient`]로 전환하는 방법에 대해 알아보려면 [이 가이드](../guides/inference#legacy-inferenceapi-client)를 참조하세요. [[autodoc]] InferenceApi - __init__ - __call__ - all huggingface_hub-0.31.1/docs/source/ko/package_reference/inference_endpoints.md000066400000000000000000000050461500667546600275420ustar00rootroot00000000000000# 추론 엔드포인트 [[inference-endpoints]] Hugging Face가 관리하는 추론 엔드포인트는 우리가 모델을 쉽고 안전하게 배포할 수 있게 해주는 도구입니다. 이러한 추론 엔드포인트는 [Hub](https://huggingface.co/models)에 있는 모델을 기반으로 설계되었습니다. 이 문서는 `huggingface_hub`와 추론 엔드포인트 통합에 관한 참조 페이지이며, 더욱 자세한 정보는 [공식 문서](https://huggingface.co/docs/inference-endpoints/index)를 통해 확인할 수 있습니다. 'huggingface_hub'를 사용하여 추론 엔드포인트를 프로그래밍 방식으로 관리하는 방법을 알고 싶다면, [관련 가이드](../guides/inference_endpoints)를 확인해 보세요. 추론 엔드포인트는 API로 쉽게 접근할 수 있습니다. 이 엔드포인트들은 [Swagger](https://api.endpoints.huggingface.cloud/)를 통해 문서화되어 있고, [`InferenceEndpoint`] 클래스는 이 API를 사용해 만든 간단한 래퍼입니다. ## 매소드 [[methods]] 다음과 같은 추론 엔드포인트의 기능이 [`HfApi`]안에 구현되어 있습니다: - [`get_inference_endpoint`]와 [`list_inference_endpoints`]를 사용해 엔드포인트 정보를 조회할 수 있습니다. - [`create_inference_endpoint`], [`update_inference_endpoint`], [`delete_inference_endpoint`]로 엔드포인트를 배포하고 관리할 수 있습니다. - [`pause_inference_endpoint`]와 [`resume_inference_endpoint`]로 엔드포인트를 잠시 멈추거나 다시 시작할 수 있습니다. - [`scale_to_zero_inference_endpoint`]로 엔드포인트의 복제본을 0개로 설정할 수 있습니다. ## InferenceEndpoint [[huggingface_hub.InferenceEndpoint]] 기본 데이터 클래스는 [`InferenceEndpoint`]입니다. 여기에는 구성 및 현재 상태를 가지고 있는 배포된 `InferenceEndpoint`에 대한 정보가 포함되어 있습니다. 배포 후에는 [`InferenceEndpoint.client`]와 [`InferenceEndpoint.async_client`]를 사용해 엔드포인트에서 추론 작업을 할 수 있고, 이때 [`InferenceClient`]와 [`AsyncInferenceClient`] 객체를 반환합니다. [[autodoc]] InferenceEndpoint - from_raw - client - async_client - all ## InferenceEndpointStatus [[huggingface_hub.InferenceEndpointStatus]] [[autodoc]] InferenceEndpointStatus ## InferenceEndpointType [[huggingface_hub.InferenceEndpointType]] [[autodoc]] InferenceEndpointType ## InferenceEndpointError [[huggingface_hub.InferenceEndpointError]] [[autodoc]] InferenceEndpointError huggingface_hub-0.31.1/docs/source/ko/package_reference/inference_types.md000066400000000000000000000250041500667546600266770ustar00rootroot00000000000000 # 추론 타입[[inference-types]] 이 페이지에는 Hugging Face Hub에서 지원하는 타입(예: 데이터 클래스)이 나열되어 있습니다. 각 작업은 JSON 스키마를 사용하여 지정되며, 이러한 스키마에 의해서 타입이 생성됩니다. 이때 Python 요구 사항으로 인해 일부 사용자 정의가 있을 수 있습니다. 각 작업의 JSON 스키마를 확인하려면 [@huggingface.js/tasks](https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks)를 확인하세요. 라이브러리에서 이 부분은 아직 개발 중이며, 향후 릴리즈에서 개선될 예정입니다. 
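이 타입들은 주로 [`InferenceClient`]의 반환 값으로 만나게 됩니다. 예를 들어 텍스트 분류 작업은 아래와 같이 `TextClassificationOutputElement`의 리스트를 반환합니다. 사용된 모델 ID는 예시입니다:

```python
from huggingface_hub import InferenceClient

client = InferenceClient()
results = client.text_classification(
    "huggingface_hub 라이브러리가 정말 좋아요!",
    model="distilbert-base-uncased-finetuned-sst-2-english",  # 예시 모델
)
for item in results:  # 각 item은 TextClassificationOutputElement입니다.
    print(item.label, item.score)
```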
## audio_classification[[huggingface_hub.AudioClassificationInput]] [[autodoc]] huggingface_hub.AudioClassificationInput [[autodoc]] huggingface_hub.AudioClassificationOutputElement [[autodoc]] huggingface_hub.AudioClassificationParameters ## audio_to_audio[[huggingface_hub.AudioToAudioInput]] [[autodoc]] huggingface_hub.AudioToAudioInput [[autodoc]] huggingface_hub.AudioToAudioOutputElement ## automatic_speech_recognition[[huggingface_hub.AutomaticSpeechRecognitionGenerationParameters]] [[autodoc]] huggingface_hub.AutomaticSpeechRecognitionGenerationParameters [[autodoc]] huggingface_hub.AutomaticSpeechRecognitionInput [[autodoc]] huggingface_hub.AutomaticSpeechRecognitionOutput [[autodoc]] huggingface_hub.AutomaticSpeechRecognitionOutputChunk [[autodoc]] huggingface_hub.AutomaticSpeechRecognitionParameters ## chat_completion[[huggingface_hub.ChatCompletionInput]] [[autodoc]] huggingface_hub.ChatCompletionInput [[autodoc]] huggingface_hub.ChatCompletionInputFunctionDefinition [[autodoc]] huggingface_hub.ChatCompletionInputFunctionName [[autodoc]] huggingface_hub.ChatCompletionInputGrammarType [[autodoc]] huggingface_hub.ChatCompletionInputMessage [[autodoc]] huggingface_hub.ChatCompletionInputMessageChunk [[autodoc]] huggingface_hub.ChatCompletionInputStreamOptions [[autodoc]] huggingface_hub.ChatCompletionInputTool [[autodoc]] huggingface_hub.ChatCompletionInputToolCall [[autodoc]] huggingface_hub.ChatCompletionInputToolChoiceClass [[autodoc]] huggingface_hub.ChatCompletionInputURL [[autodoc]] huggingface_hub.ChatCompletionOutput [[autodoc]] huggingface_hub.ChatCompletionOutputComplete [[autodoc]] huggingface_hub.ChatCompletionOutputFunctionDefinition [[autodoc]] huggingface_hub.ChatCompletionOutputLogprob [[autodoc]] huggingface_hub.ChatCompletionOutputLogprobs [[autodoc]] huggingface_hub.ChatCompletionOutputMessage [[autodoc]] huggingface_hub.ChatCompletionOutputToolCall [[autodoc]] huggingface_hub.ChatCompletionOutputTopLogprob [[autodoc]] huggingface_hub.ChatCompletionOutputUsage [[autodoc]] huggingface_hub.ChatCompletionStreamOutput [[autodoc]] huggingface_hub.ChatCompletionStreamOutputChoice [[autodoc]] huggingface_hub.ChatCompletionStreamOutputDelta [[autodoc]] huggingface_hub.ChatCompletionStreamOutputDeltaToolCall [[autodoc]] huggingface_hub.ChatCompletionStreamOutputFunction [[autodoc]] huggingface_hub.ChatCompletionStreamOutputLogprob [[autodoc]] huggingface_hub.ChatCompletionStreamOutputLogprobs [[autodoc]] huggingface_hub.ChatCompletionStreamOutputTopLogprob [[autodoc]] huggingface_hub.ChatCompletionStreamOutputUsage ## depth_estimation[[huggingface_hub.DepthEstimationInput]] [[autodoc]] huggingface_hub.DepthEstimationInput [[autodoc]] huggingface_hub.DepthEstimationOutput ## document_question_answering[[huggingface_hub.DocumentQuestionAnsweringInput]] [[autodoc]] huggingface_hub.DocumentQuestionAnsweringInput [[autodoc]] huggingface_hub.DocumentQuestionAnsweringInputData [[autodoc]] huggingface_hub.DocumentQuestionAnsweringOutputElement [[autodoc]] huggingface_hub.DocumentQuestionAnsweringParameters ## feature_extraction[[huggingface_hub.FeatureExtractionInput]] [[autodoc]] huggingface_hub.FeatureExtractionInput ## fill_mask[[huggingface_hub.FillMaskInput]] [[autodoc]] huggingface_hub.FillMaskInput [[autodoc]] huggingface_hub.FillMaskOutputElement [[autodoc]] huggingface_hub.FillMaskParameters ## image_classification[[huggingface_hub.ImageClassificationInput]] [[autodoc]] huggingface_hub.ImageClassificationInput [[autodoc]] huggingface_hub.ImageClassificationOutputElement 
[[autodoc]] huggingface_hub.ImageClassificationParameters ## image_segmentation[[huggingface_hub.ImageSegmentationInput]] [[autodoc]] huggingface_hub.ImageSegmentationInput [[autodoc]] huggingface_hub.ImageSegmentationOutputElement [[autodoc]] huggingface_hub.ImageSegmentationParameters ## image_to_image[[huggingface_hub.ImageToImageInput]] [[autodoc]] huggingface_hub.ImageToImageInput [[autodoc]] huggingface_hub.ImageToImageOutput [[autodoc]] huggingface_hub.ImageToImageParameters [[autodoc]] huggingface_hub.ImageToImageTargetSize ## image_to_text[[huggingface_hub.ImageToTextGenerationParameters]] [[autodoc]] huggingface_hub.ImageToTextGenerationParameters [[autodoc]] huggingface_hub.ImageToTextInput [[autodoc]] huggingface_hub.ImageToTextOutput [[autodoc]] huggingface_hub.ImageToTextParameters ## object_detection[[huggingface_hub.ObjectDetectionBoundingBox]] [[autodoc]] huggingface_hub.ObjectDetectionBoundingBox [[autodoc]] huggingface_hub.ObjectDetectionInput [[autodoc]] huggingface_hub.ObjectDetectionOutputElement [[autodoc]] huggingface_hub.ObjectDetectionParameters ## question_answering[[huggingface_hub.QuestionAnsweringInput]] [[autodoc]] huggingface_hub.QuestionAnsweringInput [[autodoc]] huggingface_hub.QuestionAnsweringInputData [[autodoc]] huggingface_hub.QuestionAnsweringOutputElement [[autodoc]] huggingface_hub.QuestionAnsweringParameters ## sentence_similarity[[huggingface_hub.SentenceSimilarityInput]] [[autodoc]] huggingface_hub.SentenceSimilarityInput [[autodoc]] huggingface_hub.SentenceSimilarityInputData ## summarization[[huggingface_hub.SummarizationInput]] [[autodoc]] huggingface_hub.SummarizationInput [[autodoc]] huggingface_hub.SummarizationOutput [[autodoc]] huggingface_hub.SummarizationParameters ## table_question_answering[[huggingface_hub.TableQuestionAnsweringInput]] [[autodoc]] huggingface_hub.TableQuestionAnsweringInput [[autodoc]] huggingface_hub.TableQuestionAnsweringInputData [[autodoc]] huggingface_hub.TableQuestionAnsweringOutputElement [[autodoc]] huggingface_hub.TableQuestionAnsweringParameters ## text2text_generation[[huggingface_hub.Text2TextGenerationInput]] [[autodoc]] huggingface_hub.Text2TextGenerationInput [[autodoc]] huggingface_hub.Text2TextGenerationOutput [[autodoc]] huggingface_hub.Text2TextGenerationParameters ## text_classification[[huggingface_hub.TextClassificationInput]] [[autodoc]] huggingface_hub.TextClassificationInput [[autodoc]] huggingface_hub.TextClassificationOutputElement [[autodoc]] huggingface_hub.TextClassificationParameters ## text_generation[[huggingface_hub.TextGenerationInput]] [[autodoc]] huggingface_hub.TextGenerationInput [[autodoc]] huggingface_hub.TextGenerationInputGenerateParameters [[autodoc]] huggingface_hub.TextGenerationInputGrammarType [[autodoc]] huggingface_hub.TextGenerationOutput [[autodoc]] huggingface_hub.TextGenerationOutputBestOfSequence [[autodoc]] huggingface_hub.TextGenerationOutputDetails [[autodoc]] huggingface_hub.TextGenerationOutputPrefillToken [[autodoc]] huggingface_hub.TextGenerationOutputToken [[autodoc]] huggingface_hub.TextGenerationStreamOutput [[autodoc]] huggingface_hub.TextGenerationStreamOutputStreamDetails [[autodoc]] huggingface_hub.TextGenerationStreamOutputToken ## text_to_audio[[huggingface_hub.TextToAudioGenerationParameters]] [[autodoc]] huggingface_hub.TextToAudioGenerationParameters [[autodoc]] huggingface_hub.TextToAudioInput [[autodoc]] huggingface_hub.TextToAudioOutput [[autodoc]] huggingface_hub.TextToAudioParameters ## text_to_image[[huggingface_hub.TextToImageInput]] 
[[autodoc]] huggingface_hub.TextToImageInput [[autodoc]] huggingface_hub.TextToImageOutput [[autodoc]] huggingface_hub.TextToImageParameters ## text_to_speech[[huggingface_hub.TextToSpeechGenerationParameters]] [[autodoc]] huggingface_hub.TextToSpeechGenerationParameters [[autodoc]] huggingface_hub.TextToSpeechInput [[autodoc]] huggingface_hub.TextToSpeechOutput [[autodoc]] huggingface_hub.TextToSpeechParameters ## text_to_video[[huggingface_hub.TextToVideoInput]] [[autodoc]] huggingface_hub.TextToVideoInput [[autodoc]] huggingface_hub.TextToVideoOutput [[autodoc]] huggingface_hub.TextToVideoParameters ## token_classification[[huggingface_hub.TokenClassificationInput]] [[autodoc]] huggingface_hub.TokenClassificationInput [[autodoc]] huggingface_hub.TokenClassificationOutputElement [[autodoc]] huggingface_hub.TokenClassificationParameters ## translation[[huggingface_hub.TranslationInput]] [[autodoc]] huggingface_hub.TranslationInput [[autodoc]] huggingface_hub.TranslationOutput [[autodoc]] huggingface_hub.TranslationParameters ## video_classification[[huggingface_hub.VideoClassificationInput]] [[autodoc]] huggingface_hub.VideoClassificationInput [[autodoc]] huggingface_hub.VideoClassificationOutputElement [[autodoc]] huggingface_hub.VideoClassificationParameters ## visual_question_answering[[huggingface_hub.VisualQuestionAnsweringInput]] [[autodoc]] huggingface_hub.VisualQuestionAnsweringInput [[autodoc]] huggingface_hub.VisualQuestionAnsweringInputData [[autodoc]] huggingface_hub.VisualQuestionAnsweringOutputElement [[autodoc]] huggingface_hub.VisualQuestionAnsweringParameters ## zero_shot_classification[[huggingface_hub.ZeroShotClassificationInput]] [[autodoc]] huggingface_hub.ZeroShotClassificationInput [[autodoc]] huggingface_hub.ZeroShotClassificationOutputElement [[autodoc]] huggingface_hub.ZeroShotClassificationParameters ## zero_shot_image_classification[[huggingface_hub.ZeroShotImageClassificationInput]] [[autodoc]] huggingface_hub.ZeroShotImageClassificationInput [[autodoc]] huggingface_hub.ZeroShotImageClassificationOutputElement [[autodoc]] huggingface_hub.ZeroShotImageClassificationParameters ## zero_shot_object_detection[[huggingface_hub.ZeroShotObjectDetectionBoundingBox]] [[autodoc]] huggingface_hub.ZeroShotObjectDetectionBoundingBox [[autodoc]] huggingface_hub.ZeroShotObjectDetectionInput [[autodoc]] huggingface_hub.ZeroShotObjectDetectionOutputElement [[autodoc]] huggingface_hub.ZeroShotObjectDetectionParameters huggingface_hub-0.31.1/docs/source/ko/package_reference/login.md000066400000000000000000000013161500667546600246250ustar00rootroot00000000000000 # 로그인 및 로그아웃[[login-and-logout]] `huggingface_hub` 라이브러리를 사용하면 사용자의 기기를 Hub에 프로그래밍적으로 로그인/로그아웃할 수 있습니다. 인증에 대한 자세한 내용은 [이 섹션](../quick-start#authentication)을 확인하세요. ## 로그인[[login]] [[autodoc]] login ## 인터프리터_로그인[[interpreter_login]] [[autodoc]] interpreter_login ## 노트북_로그인[[notebook_login]] [[autodoc]] notebook_login ## 로그아웃[[logout]] [[autodoc]] logout huggingface_hub-0.31.1/docs/source/ko/package_reference/mixins.md000066400000000000000000000021541500667546600250250ustar00rootroot00000000000000 # 믹스인 & 직렬화 메소드[[mixins--serialization-methods]] ## 믹스인[[mixins]] `huggingface_hub` 라이브러리는 객체에 함수들의 업로드 및 다운로드 기능을 손쉽게 제공하기 위해서, 부모 클래스로 사용될 수 있는 다양한 믹스인을 제공합니다. ML 프레임워크를 Hub와 통합하는 방법은 [통합 가이드](../guides/integrations)를 통해 배울 수 있습니다. 
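예를 들어, 아래는 `PyTorchModelHubMixin`을 사용하는 최소한의 스케치입니다. 클래스 이름과 리포지토리 ID는 설명을 위한 가정일 뿐이며, 실제 API 세부 사항은 아래의 각 믹스인 문서를 참고하세요.

```py
import torch.nn as nn

from huggingface_hub import PyTorchModelHubMixin


class MyModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x):
        return self.linear(x)


model = MyModel()

# 로컬에 저장하거나 Hub에 업로드합니다 (리포지토리 ID는 예시입니다)
model.save_pretrained("./my-model")
model.push_to_hub("username/my-awesome-model")

# Hub에서 다시 불러옵니다
reloaded = MyModel.from_pretrained("username/my-awesome-model")
```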
### 제네릭[[huggingface_hub.ModelHubMixin]] [[autodoc]] ModelHubMixin - all - _save_pretrained - _from_pretrained ### PyTorch[[huggingface_hub.PyTorchModelHubMixin]] [[autodoc]] PyTorchModelHubMixin ### Keras[[huggingface_hub.KerasModelHubMixin]] [[autodoc]] KerasModelHubMixin [[autodoc]] from_pretrained_keras [[autodoc]] push_to_hub_keras [[autodoc]] save_pretrained_keras ### Fastai[[huggingface_hub.from_pretrained_fastai]] [[autodoc]] from_pretrained_fastai [[autodoc]] push_to_hub_fastai huggingface_hub-0.31.1/docs/source/ko/package_reference/overview.md000066400000000000000000000005101500667546600253560ustar00rootroot00000000000000 # Overview[[overview]] 이 섹션은 `huggingface_hub` 클래스와 메서드에 대한 상세하고 기술적인 설명을 포함하고 있습니다. huggingface_hub-0.31.1/docs/source/ko/package_reference/repository.md000066400000000000000000000031111500667546600257270ustar00rootroot00000000000000 # 로컬 및 온라인 리포지토리 관리[[managing-local-and-online-repositories]] `Repository` 클래스는 `git` 및 `git-lfs` 명령을 감싸는 도우미 클래스로, 매우 큰 리포지토리를 관리하는 데 적합한 툴링을 제공합니다. `git` 작업이 포함되거나 리포지토리에서의 협업이 중점이 될 때 권장되는 도구입니다. ## 리포지토리 클래스[[the-repository-class]] [[autodoc]] Repository - __init__ - current_branch - all ## 도우미 메소드[[helper-methods]] [[autodoc]] huggingface_hub.repository.is_git_repo [[autodoc]] huggingface_hub.repository.is_local_clone [[autodoc]] huggingface_hub.repository.is_tracked_with_lfs [[autodoc]] huggingface_hub.repository.is_git_ignored [[autodoc]] huggingface_hub.repository.files_to_be_staged [[autodoc]] huggingface_hub.repository.is_tracked_upstream [[autodoc]] huggingface_hub.repository.commits_to_push ## 후속 비동기 명령[[following-asynchronous-commands]] `Repository` 유틸리티는 비동기적으로 시작할 수 있는 여러 메소드를 제공합니다. - `git_push` - `git_pull` - `push_to_hub` - `commit` 컨텍스트 관리자 이러한 비동기 메소드를 관리하는 유틸리티는 아래를 참조하세요. [[autodoc]] Repository - commands_failed - commands_in_progress - wait_for_commands [[autodoc]] huggingface_hub.repository.CommandInProgress huggingface_hub-0.31.1/docs/source/ko/package_reference/serialization.md000066400000000000000000000041361500667546600263750ustar00rootroot00000000000000 # 직렬화[[serialization]] `huggingface_hub`에는 ML 라이브러리가 모델 가중치를 표준화된 방식으로 직렬화 할 수 있도록 돕는 헬퍼를 포함하고 있습니다. 라이브러리의 이 부분은 아직 개발 중이며 향후 버전에서 개선될 예정입니다. 개선 목표는 Hub에서 가중치의 직렬화 방식을 통일하고, 라이브러리 간 코드 중복을 줄이며, Hub에서의 규약을 촉진하는 것입니다. ## 상태 사전을 샤드로 나누기[[split-state-dict-into-shards]] 현재 이 모듈은 상태 딕셔너리(예: 레이어 이름과 관련 텐서 간의 매핑)를 받아 여러 샤드로 나누고, 이 과정에서 적절한 인덱스를 생성하는 단일 헬퍼를 포함하고 있습니다. 이 헬퍼는 `torch`, `tensorflow`, `numpy` 텐서에 사용 가능하며, 다른 ML 프레임워크로 쉽게 확장될 수 있도록 설계되었습니다. ### split_tf_state_dict_into_shards[[huggingface_hub.split_tf_state_dict_into_shards]] [[autodoc]] huggingface_hub.split_tf_state_dict_into_shards ### split_torch_state_dict_into_shards[[huggingface_hub.split_torch_state_dict_into_shards]] [[autodoc]] huggingface_hub.split_torch_state_dict_into_shards ### split_state_dict_into_shards_factory[[huggingface_hub.split_state_dict_into_shards_factory]] 이것은 각 프레임워크별 헬퍼가 파생되는 기본 틀입니다. 실제로는 아직 지원되지 않는 프레임워크에 맞게 조정할 필요가 있는 경우가 아니면 이 틀을 직접 사용할 것으로 예상되지 않습니다. 그런 경우가 있다면, `huggingface_hub` 리포지토리에 [새로운 이슈를 개설](https://github.com/huggingface/huggingface_hub/issues/new) 하여 알려주세요. 
[[autodoc]] huggingface_hub.split_state_dict_into_shards_factory ## 도우미 ### get_torch_storage_id[[huggingface_hub.get_torch_storage_id]] [[autodoc]] huggingface_hub.get_torch_storage_idhuggingface_hub-0.31.1/docs/source/ko/package_reference/space_runtime.md000066400000000000000000000020261500667546600263520ustar00rootroot00000000000000 # Space 런타임 관리[[managing-your-space-runtime]] Hub의 Space를 관리하는 메소드에 대한 자세한 설명은 [`HfApi`]페이지를 확인하세요. - Space 복제: [`duplicate_space`] - 현재 런타임 가져오기: [`get_space_runtime`] - 보안 관리: [`add_space_secret`] 및 [`delete_space_secret`] - 하드웨어 관리: [`request_space_hardware`] - 상태 관리: [`pause_space`], [`restart_space`], [`set_space_sleep_time`] ## 데이터 구조[[data-structures]] ### SpaceRuntime[[huggingface_hub.SpaceRuntime]] [[autodoc]] SpaceRuntime ### SpaceHardware[[huggingface_hub.SpaceHardware]] [[autodoc]] SpaceHardware ### SpaceStage[[huggingface_hub.SpaceStage]] [[autodoc]] SpaceStage ### SpaceStorage[[huggingface_hub.SpaceStorage]] [[autodoc]] SpaceStorage ### SpaceVariable[[huggingface_hub.SpaceVariable]] [[autodoc]] SpaceVariablehuggingface_hub-0.31.1/docs/source/ko/package_reference/tensorboard.md000066400000000000000000000025201500667546600260350ustar00rootroot00000000000000 # TensorBoard 로거[[tensorboard-logger]] TensorBoard는 기계학습 실험을 위한 시각화 도구입니다. 주로 손실 및 정확도와 같은 지표를 추적 및 시각화하고, 모델 그래프와 히스토그램을 보여주고, 이미지를 표시하는 등 다양한 기능을 제공합니다. 또한 TensorBoard는 Hugging Face Hub와 잘 통합되어 있습니다. `tfevents` 같은 TensorBoard 추적을 Hub에 푸시하면 Hub는 이를 자동으로 감지하여 시각화 인스턴스를 시작합니다. TensorBoard와 Hub의 통합에 대한 자세한 정보는 [가이드](https://huggingface.co/docs/hub/tensorboard)를 확인하세요. 이 통합을 위해, `huggingface_hub`는 로그를 Hub로 푸시하기 위한 사용자 정의 로거를 제공합니다. 이 로거는 추가적인 코드 없이 [SummaryWriter](https://tensorboardx.readthedocs.io/en/latest/tensorboard.html)의 대체제로 사용될 수 있습니다. 추적은 계속해서 로컬에 저장되며 백그라운드 작업이 일정한 시간마다 Hub에 푸시하는 형태로 동작합니다. ## HFSummaryWriter[[huggingface_hub.HFSummaryWriter]] [[autodoc]] HFSummaryWriter huggingface_hub-0.31.1/docs/source/ko/package_reference/utilities.md000066400000000000000000000257331500667546600255410ustar00rootroot00000000000000 # 유틸리티[[utilities]] ## 로깅 구성[[huggingface_hub.utils.logging.get_verbosity]] `huggingface_hub` 패키지는 패키지 로그 레벨을 제어하기 위한 `logging` 유틸리티를 제공합니다. 다음과 같이 가져올 수 있습니다: ```py from huggingface_hub import logging ``` 그런 다음, 로그의 출력 수를 업데이트하기 위해 로그 레벨을 정의할 수 있습니다: ```python from huggingface_hub import logging logging.set_verbosity_error() logging.set_verbosity_warning() logging.set_verbosity_info() logging.set_verbosity_debug() logging.set_verbosity(...) ``` 로그 레벨은 다음과 같이 이해하면 됩니다: - `error`: 오류 또는 예기치 않은 동작으로 이어질 수 있는 결정적인 로그만 표시합니다. - `warning`: 결정적이진 않지만 의도치 않은 동작을 초래할 수 있는 로그를 표시합니다. 또한 중요한 정보를 포함한 로그도 표시될 수 있습니다. - `info`: 하부에서 무슨 일이 일어나고 있는지에 대한 자세한 로그를 포함하여 대부분의 로그를 표시합니다. 무언가 예상치 못한 방식으로 동작하는 경우, 더 많은 정보를 얻기 위해 verbosity 단계로 전환하는 것이 좋습니다. - `debug`: 하부에서 정확히 무슨 일이 일어나고 있는지를 추적하는 데 사용될 수 있는 일부 내부 로그를 포함하여 모든 로그를 표시합니다. [[autodoc]] logging.get_verbosity [[autodoc]] logging.set_verbosity [[autodoc]] logging.set_verbosity_info [[autodoc]] logging.set_verbosity_debug [[autodoc]] logging.set_verbosity_warning [[autodoc]] logging.set_verbosity_error [[autodoc]] logging.disable_propagation [[autodoc]] logging.enable_propagation ### 리포지토리별 도우미 메소드[[huggingface_hub.utils.logging.get_logger]] 아래 제공된 메소드들은 `huggingface_hub` 라이브러리 모듈을 수정할 때 관련이 있습니다. `huggingface_hub`를 사용하고 해당 모듈을 수정하지 않는 경우에는 사용할 필요가 없습니다. [[autodoc]] logging.get_logger ## 프로그레스 바 구성하기[[configure-progress-bars]] 프로그레스 바는 긴 시간이 걸리는 작업을 실행하는 동안 정보를 표시하는 유용한 도구입니다(예시로 파일을 다운로드하거나 업로드하는 등). 
`huggingface_hub`는 라이브러리 전체에서 일관된 방식으로 프로그레스 바를 표시하기 위한 [`~utils.tqdm`] 래퍼를 제공합니다. 기본적으로 프로그레스 바가 활성화되어 있습니다. `HF_HUB_DISABLE_PROGRESS_BARS` 환경 변수를 설정하여 전역적으로 비활성화할 수 있습니다. 또한 [`~utils.enable_progress_bars`]와 [`~utils.disable_progress_bars`]를 사용하여 프로그레스 바를 개별적으로 활성화 또는 비활성화할 수도 있습니다. 만약 환경 변수가 설정되어 있다면, 환경 변수가 도우미에서 우선 순위를 가집니다. ```py >>> from huggingface_hub import snapshot_download >>> from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars >>> # 전역적으로 프로그레스 바를 비활성화합니다. >>> disable_progress_bars() >>> # 프로그레스 바가 표시되지 않습니다! >>> snapshot_download("gpt2") >>> are_progress_bars_disabled() True >>> # 다시 프로그레스 바가 활성화됩니다 >>> enable_progress_bars() ``` ### are_progress_bars_disabled[[huggingface_hub.utils.are_progress_bars_disabled]] [[autodoc]] huggingface_hub.utils.are_progress_bars_disabled ### disable_progress_bars[[huggingface_hub.utils.disable_progress_bars]] [[autodoc]] huggingface_hub.utils.disable_progress_bars ### enable_progress_bars[huggingface_hub.utils.enable_progress_bars]] [[autodoc]] huggingface_hub.utils.enable_progress_bars ## HTTP 백엔드 구성[[huggingface_hub.configure_http_backend]] 일부 환경에서는 HTTP 호출이 이루어지는 방식을 구성할 수 있습니다. 예를 들어, 프록시를 사용하는 경우가 그렇습니다. `huggingface_hub`는 [`configure_http_backend`]를 사용하여 전역적으로 이를 구성할 수 있게 합니다. 그러면 Hub로의 모든 요청이 사용자가 설정한 설정을 사용합니다. 내부적으로 `huggingface_hub`는 `requests.Session`을 사용하므로 사용 가능한 매개변수에 대해 자세히 알아보려면 [requests 문서](https://requests.readthedocs.io/en/latest/user/advanced)를 참조하는 것이 좋습니다. `requests.Session`이 스레드 안전을 보장하지 않기 때문에 `huggingface_hub`는 스레드당 하나의 세션 인스턴스를 생성합니다. 세션을 사용하면 HTTP 호출 사이에 연결을 유지하고 최종적으로 시간을 절약할 수 있습니다. `huggingface_hub`를 서드 파티 라이브러리에 통합하고 사용자 지정 호출을 Hub로 만들려는 경우, [`get_session`]을 사용하여 사용자가 구성한 세션을 가져옵니다 (즉, 모든 `requests.get(...)` 호출을 `get_session().get(...)`으로 대체합니다). [[autodoc]] configure_http_backend [[autodoc]] get_session ## HTTP 오류 다루기[[handle-http-errors]] `huggingface_hub`는 서버에서 반환된 추가 정보로 `requests`에서 발생한 `HTTPError`를 세분화하기 위해 자체 HTTP 오류를 정의합니다. ### 예외 발생[[huggingface_hub.utils.hf_raise_for_status]] [`~utils.hf_raise_for_status`]는 Hub에 대한 모든 요청에서 "상태를 확인하고 예외를 발생시키는" 중앙 메소드로 사용됩니다. 이 메서드는 기본 `requests.raise_for_status`를 감싸서 추가 정보를 제공합니다. 발생된 모든 `HTTPError`는 `HfHubHTTPError`로 변환됩니다. ```py import requests from huggingface_hub.utils import hf_raise_for_status, HfHubHTTPError response = requests.post(...) try: hf_raise_for_status(response) except HfHubHTTPError as e: print(str(e)) # 형식화된 메시지 e.request_id, e.server_message # 서버에서 반환된 세부 정보 # 오류 메시지를 발생시킬 때 추가 정보를 포함하여 완성합니다 e.append_to_message("\n`create_commit` expects the repository to exist.") raise ``` [[autodoc]] huggingface_hub.utils.hf_raise_for_status ### HTTP 오류[[http-errors]] 여기에는 `huggingface_hub`에서 발생하는 HTTP 오류 목록이 있습니다. #### HfHubHTTPError[[huggingface_hub.utils.HfHubHTTPError]] `HfHubHTTPError`는 HF Hub HTTP 오류에 대한 부모 클래스입니다. 이 클래스는 서버 응답을 구문 분석하고 오류 메시지를 형식화하여 사용자에게 가능한 많은 정보를 제공합니다. 
[[autodoc]] huggingface_hub.utils.HfHubHTTPError #### RepositoryNotFoundError[[huggingface_hub.utils.RepositoryNotFoundError]] [[autodoc]] huggingface_hub.utils.RepositoryNotFoundError #### GatedRepoError[[huggingface_hub.utils.GatedRepoError]] [[autodoc]] huggingface_hub.utils.GatedRepoError #### RevisionNotFoundError[[huggingface_hub.utils.RevisionNotFoundError]] [[autodoc]] huggingface_hub.utils.RevisionNotFoundError #### EntryNotFoundError[[huggingface_hub.utils.EntryNotFoundError]] [[autodoc]] huggingface_hub.utils.EntryNotFoundError #### BadRequestError[[huggingface_hub.utils.BadRequestError]] [[autodoc]] huggingface_hub.utils.BadRequestError #### LocalEntryNotFoundError[[huggingface_hub.utils.LocalEntryNotFoundError]] [[autodoc]] huggingface_hub.utils.LocalEntryNotFoundError #### OfflineModeIsEnabledd[[huggingface_hub.utils.OfflineModeIsEnabled]] [[autodoc]] huggingface_hub.utils.OfflineModeIsEnabled ## 원격 측정[[huggingface_hub.utils.send_telemetry]] `huggingface_hub`는 원격 측정 데이터를 보내는 도우미가 포함되어 있습니다. 이 정보는 문제를 디버깅하고 새로운 기능을 우선적으로 처리하는 데 도움이 됩니다. 사용자는 `HF_HUB_DISABLE_TELEMETRY=1` 환경 변수를 설정하여 언제든지 원격 측정 수집을 비활성화할 수 있습니다. 또한 오프라인 모드에서도 (즉, HF_HUB_OFFLINE=1로 설정된 경우) 원격 측정이 비활성화됩니다. 서드 파티 라이브러리의 유지 관리자인 경우, 원격 측정 데이터를 보내는 것은 [`send_telemetry`]를 호출하는 것만큼 간단합니다. 사용자에게 가능한 영향을 최소화하기 위해 데이터는 별도의 스레드에서 전송됩니다. [[autodoc]] utils.send_telemetry ## 검증기[[validators]] `huggingface_hub`에는 메소드 인수를 자동으로 유효성 검사하는 사용자 정의 검증기가 포함되어 있습니다. 이 유효성 검사는 타입 힌트를 검증하는 데 [Pydantic](https://pydantic-docs.helpmanual.io/)의 작업을 참고하여 구현되었지만, 기능은 더 제한적입니다. ### 일반 데코레이터[[generic-decorator]] [`~utils.validate_hf_hub_args`]는 `huggingface_hub`의 네이밍을 따르는 인수를 갖는 메소드를 캡슐화하는 일반적인 데코레이터입니다. 기본적으로 구현된 검증기가 있는 모든 인수가 유효성 검사됩니다. 입력이 유효하지 않은 경우 [`~utils.HFValidationError`]이 발생합니다. 첫 번째 유효하지 않은 값만 오류를 발생시키고 유효성 검사 프로세스를 중지합니다. 사용법: ```py >>> from huggingface_hub.utils import validate_hf_hub_args >>> @validate_hf_hub_args ... def my_cool_method(repo_id: str): ... print(repo_id) >>> my_cool_method(repo_id="valid_repo_id") valid_repo_id >>> my_cool_method("other..repo..id") huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'. >>> my_cool_method(repo_id="other..repo..id") huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'. >>> @validate_hf_hub_args ... def my_cool_auth_method(token: str): ... print(token) >>> my_cool_auth_method(token="a token") "a token" >>> my_cool_auth_method(use_auth_token="a use_auth_token") "a use_auth_token" >>> my_cool_auth_method(token="a token", use_auth_token="a use_auth_token") UserWarning: Both `token` and `use_auth_token` are passed (...). `use_auth_token` value will be ignored. "a token" ``` #### validate_hf_hub_args[[huggingface_hub.utils.validate_hf_hub_args]] [[autodoc]] utils.validate_hf_hub_args #### HFValidationError[[huggingface_hub.utils.HFValidationError]] [[autodoc]] utils.HFValidationError ### Argument validators[[argument-validators]] 검증기는 개별적으로도 사용할 수 있습니다. 다음은 검증할 수 있는 모든 인수 목록입니다. #### repo_id[[huggingface_hub.utils.validate_repo_id]] [[autodoc]] utils.validate_repo_id #### smoothly_deprecate_use_auth_token[[huggingface_hub.utils.smoothly_deprecate_use_auth_token]] 정확히 검증기는 아니지만, 잘 실행됩니다. [[autodoc]] utils.smoothly_deprecate_use_auth_token huggingface_hub-0.31.1/docs/source/ko/package_reference/webhooks_server.md000066400000000000000000000070161500667546600267270ustar00rootroot00000000000000 # 웹훅 서버[[webhooks-server]] 웹훅은 MLOps 관련 기능의 기반이 됩니다. 
이를 통해 특정 저장소의 새로운 변경 사항을 수신하거나, 관심 있는 특정 사용자/조직에 속한 모든 저장소의 변경 사항을 받아볼 수 있습니다. Huggingface Hub의 웹훅에 대해 더 자세히 알아보려면 이 [가이드](https://huggingface.co/docs/hub/webhooks)를 읽어보세요. 웹훅 서버를 설정하고 Space로 배포하는 방법은 이 단계별 [가이드](../guides/webhooks_server)를 확인하세요. 이 기능은 실험적인 기능입니다. 본 API는 현재 개선 작업 중이며, 향후 사전 통지 없이 주요 변경 사항이 도입될 수 있음을 의미합니다. `requirements`에서 `huggingface_hub`의 버전을 고정하는 것을 권장합니다. 참고로 실험적 기능을 사용하면 경고가 트리거 됩니다. 이 경고 트리거를 비활성화 시키길 원한다면 환경변수 `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1`를 설정하세요. ## 서버[[server]] 여기서 서버는 하나의 [Gradio](https://gradio.app/) 앱을 의미합니다. Gradio에는 사용자 또는 사용자에게 지침을 표시하는 UI와 웹훅을 수신하기 위한 API가 있습니다. 웹훅 엔드포인트를 구현하는 것은 함수에 데코레이터를 추가하는 것만큼 간단합니다. 서버를 Space에 배포하기 전에 Gradio 터널을 사용하여 웹훅을 머신으로 리디렉션하여 디버깅할 수 있습니다. ### WebhooksServer[[huggingface_hub.WebhooksServer]] [[autodoc]] huggingface_hub.WebhooksServer ### @webhook_endpoint[[huggingface_hub.webhook_endpoint]] [[autodoc]] huggingface_hub.webhook_endpoint ## 페이로드[[huggingface_hub.WebhookPayload]] [`WebhookPayload`]는 웹훅의 페이로드를 포함하는 기본 데이터 구조입니다. 이것은 `pydantic` 클래스로서 FastAPI에서 매우 쉽게 사용할 수 있습니다. 즉 WebhookPayload를 웹후크 엔드포인트에 매개변수로 전달하면 자동으로 유효성이 검사되고 파이썬 객체로 파싱됩니다. 웹훅 페이로드에 대한 자세한 사항은 이 [가이드](https://huggingface.co/docs/hub/webhooks#webhook-payloads)를 참고하세요. [[autodoc]] huggingface_hub.WebhookPayload ### WebhookPayload[[huggingface_hub.WebhookPayload]] [[autodoc]] huggingface_hub.WebhookPayload ### WebhookPayloadComment[[huggingface_hub.WebhookPayloadComment]] [[autodoc]] huggingface_hub.WebhookPayloadComment ### WebhookPayloadDiscussion[[huggingface_hub.WebhookPayloadDiscussion]] [[autodoc]] huggingface_hub.WebhookPayloadDiscussion ### WebhookPayloadDiscussionChanges[[huggingface_hub.WebhookPayloadDiscussionChanges]] [[autodoc]] huggingface_hub.WebhookPayloadDiscussionChanges ### WebhookPayloadEvent[[huggingface_hub.WebhookPayloadEvent]] [[autodoc]] huggingface_hub.WebhookPayloadEvent ### WebhookPayloadMovedTo[[huggingface_hub.WebhookPayloadMovedTo]] [[autodoc]] huggingface_hub.WebhookPayloadMovedTo ### WebhookPayloadRepo[[huggingface_hub.WebhookPayloadRepo]] [[autodoc]] huggingface_hub.WebhookPayloadRepo ### WebhookPayloadUrl[[huggingface_hub.WebhookPayloadUrl]] [[autodoc]] huggingface_hub.WebhookPayloadUrl ### WebhookPayloadWebhook[[huggingface_hub.WebhookPayloadWebhook]] [[autodoc]] huggingface_hub.WebhookPayloadWebhook huggingface_hub-0.31.1/docs/source/ko/quick-start.md000066400000000000000000000152431500667546600223570ustar00rootroot00000000000000 # 둘러보기 [[quickstart]] [Hugging Face Hub](https://huggingface.co/)는 머신러닝 모델, 데모, 데이터 세트 및 메트릭을 공유할 수 있는 곳입니다. `huggingface_hub` 라이브러리는 개발 환경을 벗어나지 않고도 Hub와 상호작용할 수 있도록 도와줍니다. 리포지토리를 쉽게 만들고 관리하거나, 파일을 다운로드 및 업로드하고, 유용한 모델과 데이터 세트의 메타데이터도 구할 수 있습니다. ## 설치 [[installation]] 시작하려면 `huggingface_hub` 라이브러리를 설치하세요: ```bash pip install --upgrade huggingface_hub ``` 자세한 내용은 [설치](./installation) 가이드를 참조하세요. ## 파일 다운로드 [[download-files]] Hub의 리포지토리는 git으로 버전 관리되며, 사용자는 단일 파일 또는 전체 리포지토리를 다운로드할 수 있습니다. 파일을 다운로드하려면 [`hf_hub_download`] 함수를 사용하면 됩니다. 사용하면 파일을 다운로드하여 로컬 디스크에 캐시하기 때문에, 다음에 해당 파일이 필요하면 캐시에서 가져오므로 다시 다운로드할 필요가 없습니다. 다운로드하려면 리포지토리 ID와 파일명이 필요합니다. 예를 들어, [Pegasus](https://huggingface.co/google/pegasus-xsum) 모델 구성 파일을 다운로드하려면: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download(repo_id="google/pegasus-xsum", filename="config.json") ``` 특정 버전의 파일을 다운로드하려면 `revision` 매개변수를 사용하여 브랜치 이름, 태그 또는 커밋 해시를 지정하세요. 커밋 해시를 사용하기로 선택한 경우, 7자로 된 짧은 커밋 해시 대신 전체 길이의 해시여야 합니다: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download( ... repo_id="google/pegasus-xsum", ... 
filename="config.json", ... revision="4d33b01d79672f27f001f6abade33f22d993b151" ... ) ``` 자세한 내용과 옵션은 [`hf_hub_download`]에 대한 API 레퍼런스를 참조하세요. ## 로그인 [[login]] 비공개 리포지토리 다운로드, 파일 업로드, PR 생성 등 Hub와 상호 작용하려면 Hugging Face 계정으로 로그인해야 하는 경우가 많습니다. 아직 계정이 없다면 [계정 만들기](https://huggingface.co/join)를 클릭한 다음, 로그인하여 [설정 페이지](https://huggingface.co/settings/tokens)에서 [사용자 액세스 토큰](https://huggingface.co/docs/hub/security-tokens)을 받으세요. 사용자 액세스 토큰은 Hub에 인증하는 데 사용됩니다. 사용자 액세스 토큰을 받으면 터미널에서 다음 명령을 실행하세요: ```bash huggingface-cli login # or using an environment variable huggingface-cli login --token $HUGGINGFACE_TOKEN ``` 또는 주피터 노트북이나 스크립트에서 [`login`]로 프로그래밍 방식으로 로그인할 수도 있습니다: ```py >>> from huggingface_hub import login >>> login() ``` `login(token="hf_xxx")`과 같이 토큰을 [`login`]에 직접 전달하여 토큰을 입력하라는 메시지를 표시하지 않고 프로그래밍 방식으로 로그인할 수도 있습니다. 이렇게 한다면 소스 코드를 공유할 때 주의하세요. 토큰을 소스코드에 명시적으로 저장하는 대신에 보안 저장소에서 토큰을 가져오는 것이 가장 좋습니다. 한 번에 하나의 계정에만 로그인할 수 있습니다. 새 계정으로 로그인하면 이전 계정에서 로그아웃됩니다. 항상 `huggingface-cli whoami` 명령으로 어떤 계정을 사용 중인지 확인하세요. 동일한 스크립트에서 여러 계정을 처리하려면 각 메서드를 호출할 때 토큰을 제공하면 됩니다. 이 방법은 머신에 토큰을 저장하지 않으려는 경우에도 유용합니다. 로그인하면 Hub에 대한 모든 요청(반드시 인증이 필요하지 않은 메소드 포함)은 기본적으로 액세스 토큰을 사용합니다. 토큰의 암시적 사용을 비활성화하려면 `HF_HUB_DISABLE_IMPLICIT_TOKEN` 환경 변수를 설정해야 합니다. ## 리포지토리 만들기 [[create-a-repository]] 등록 및 로그인이 완료되면 [`create_repo`] 함수를 사용하여 리포지토리를 생성하세요: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.create_repo(repo_id="super-cool-model") ``` 리포지토리를 비공개로 설정하려면 다음과 같이 하세요: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.create_repo(repo_id="super-cool-model", private=True) ``` 비공개 리포지토리는 본인 외에는 누구에게도 공개되지 않습니다. 리포지토리를 생성하거나 Hub에 콘텐츠를 푸시하려면 `write` (쓰기) 권한이 있는 사용자 액세스 토큰을 제공해야 합니다. 토큰을 생성할 때 [설정 페이지](https://huggingface.co/settings/tokens)에서 권한을 선택할 수 있습니다. ## 파일 업로드 [[upload-files]] 새로 만든 리포지토리에 파일을 추가하려면 [`upload_file`] 함수를 사용하세요. 다음을 지정해야 합니다: 1. 업로드할 파일의 경로 2. 리포지토리에 있는 파일의 경로 3. 파일을 추가할 위치의 리포지토리 ID ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.upload_file( ... path_or_fileobj="/home/lysandre/dummy-test/README.md", ... path_in_repo="README.md", ... repo_id="lysandre/test-model", ... ) ``` 한 번에 두 개 이상의 파일을 업로드하려면 [업로드](./guides/upload) 가이드에서 (git을 포함하거나 제외한) 여러 가지 파일 업로드 방법을 소개하는 가이드를 참조하세요. ## 다음 단계 [[next-steps]] `huggingface_hub` 라이브러리는 사용자가 파이썬으로 Hub와 상호작용할 수 있는 쉬운 방법을 제공합니다. Hub에서 파일과 리포지토리를 관리하는 방법에 대해 자세히 알아보려면 [How-to 가이드](./guides/overview)를 읽어보시기 바랍니다: - 보다 쉽게 [리포지토리를 관리](./guides/repository)해보세요. - Hub에서 [다운로드](./guides/download) 파일을 다운로드해보세요. - Hub에 [업로드](./guides/upload) 파일을 업로드해보세요. - 원하는 모델 또는 데이터 세트에 대한 [Hub에서 검색](./guides/search)해보세요. - 빠른 추론을 원하신다면 [추론 API](./guides/inference)를 사용해보세요. huggingface_hub-0.31.1/docs/source/tm/000077500000000000000000000000001500667546600175705ustar00rootroot00000000000000huggingface_hub-0.31.1/docs/source/tm/_toctree.yml000066400000000000000000000002321500667546600221140ustar00rootroot00000000000000- title: "Get started" sections: - local: index title: குறியீட்டு - local: installation title: நிறுவல்huggingface_hub-0.31.1/docs/source/tm/index.md000066400000000000000000000166601500667546600212320ustar00rootroot00000000000000 # 🤗 ஹப் கிளையன்ட் லைப்ரரி `Huggingface_hub` லைப்ரரி உங்களை [ஹக்கிங் ஃபேஸ் ஹப்](https://hf.co) உடன் தொடர்புகொள்ள அனுமதிக்கிறது, இது படைப்பாளர்கள் மற்றும் கூட்டுப்பணியாளர்களுக்கான இயந்திர கற்றல் தளமாகும். உங்கள் திட்டங்களுக்கான முன் பயிற்சி பெற்ற மாதிரிகள் மற்றும் தரவுத்தொகுப்புகளைக் கண்டறியவும் அல்லது ஹப்பில் ஹோஸ்ட் செய்யப்பட்ட நூற்றுக்கணக்கான இயந்திர கற்றல் பயன்பாடுகளுடன் விளையாடவும். 
உங்கள் சொந்த மாதிரிகள் மற்றும் தரவுத்தொகுப்புகளை உருவாக்கி சமூகத்துடன் பகிரலாம். `huggingface_hub` லைப்ரரி பைதான் மூலம் இவற்றைச் செய்வதற்கான எளிய வழியை வழங்குகிறது. [இந்த துரிதத் தொடக்கக் கையேட்டை](quick-start) வாசித்தால், `huggingface_hub` நூலகத்துடன் வேலை செய்ய எவ்வாறு ஆரம்பிக்கலாம் என்பதை நீங்கள் கற்றுக்கொள்வீர்கள். இதில், 🤗 ஹப் (Hub) இலிருந்து கோப்புகளை எவ்வாறு பதிவிறக்குவது, ஒரு `repository` உருவாக்குவது மற்றும் கோப்புகளை ஹபுக்கு எவ்வாறு பதிவேற்றுவது என்பதை நீங்கள் கற்றுக்கொள்வீர்கள்.மேலும், 🤗 ஹபில் உங்கள் repositoryகளை எவ்வாறு நிர்வகிக்க வேண்டும், விவாதங்களில் எவ்வாறு ஈடுபட வேண்டும், அல்லது `Inference API`யை எப்படி அணுகுவது என்பதையும் கற்றுக்கொள்ள இந்த வழிகாட்டியை தொடர்ந்து வாசியுங்கள். ## பங்களிப்பு `huggingface_hub`-க்கு அனைத்து பங்களிப்புகளும் வரவேற்கப்படுகின்றன மற்றும் சமமாக மதிக்கப்படுகின்றன! 🤗 கோடில் உள்ள உள்ளமைவுகளையும் அல்லது பிழைகளைச் சரிசெய்வதோடு, ஆவணங்களை சரியாகவும், தற்போதைய நிலையில் இருப்பதையும் உறுதிப்படுத்துவதன் மூலம் தங்களால் உதவலாம், மேலும் இஷ்யூக்களுக்கான கேள்விகளுக்கு பதிலளிக்கவும், நூலகத்தை மேம்படுத்துமாறு நீங்கள் நினைப்பதைத் தொடர்ந்து புதிய அம்சங்களை கோரலாம். பங்களிப்பு குறித்த [வழிகாட்டலை](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md) பார்க்கவும், புதிய இஷ்யூவோ அல்லது அம்சக் கோரிக்கையோ எப்படி சமர்ப்பிக்க வேண்டும், புல் ரிக்வெஸ்ட்களை (Pull Request) சமர்ப்பிப்பது எப்படி, மேலும் உங்கள் பங்களிப்புகள் அனைத்தும் எதிர்பார்த்தது போல வேலை செய்கிறதா என்பதைச் சோதிப்பது எப்படி என்பதையும் கற்றுக்கொள்ளலாம். பங்களிப்பாளர்கள் அனைவருக்கும் உள்ளடக்கிய மற்றும் வரவேற்கக்கூடிய ஒத்துழைப்பு நிலையை உருவாக்க, நாங்கள் உருவாக்கிய [நடத்தை விதிகளை](https://github.com/huggingface/huggingface_hub/blob/main/CODE_OF_CONDUCT.md) மதிக்க வேண்டும். huggingface_hub-0.31.1/docs/source/tm/installation.md000066400000000000000000000363551500667546600226270ustar00rootroot00000000000000# நிறுவல் நீங்கள் தொடங்குவதற்கு முன், தகுந்த தொகுப்புகளை நிறுவுவதன் மூலம் உங்கள் சூழலை அமைக்க வேண்டும். `huggingface_hub` **Python 3.8+** மின்பொருள்களில் சோதிக்கப்பட்டுள்ளது. ### பிப் மூலம் நிறுவு **pip மூலம் நிறுவல்** `huggingface_hub`-ஐ ஒரு [மெய்நிகர் சூழலில்](https://docs.python.org/3/library/venv.html) (virtual environment) நிறுவுவது மிகவும் பரிந்துரைக்கப்படுகிறது. நீங்கள் பைதான் மெய்நிகர் சூழல்களைக் குறித்து அறியாதவராக இருந்தால், இந்த [வழிகாட்டலைப்](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/)பார்க்கவும். ஒரு மெய்நிகர் சூழல் பல்வேறு திட்டங்களை எளிதில் நிர்வகிக்கவும், சார்புகளுக்கிடையிலான (dependencies) இணக்கமின்மை பிரச்சனைகளைத் தவிர்க்கவும் உதவுகிறது. முதலில், உங்கள் திட்ட அடைவரிசையில் (project directory) ஒரு மெய்நிகர் சூழலை உருவாக்கத் தொடங்குங்கள்: ```bash python -m venv .env ``` மெய்நிகர் சூழலை செயல்படுத்தவும். Linux மற்றும் macOS-இல்: ```bash source .env/bin/activate ``` விண்டோஸ்-இல் மெய்நிகர் சூழலை செயல்படுத்த: ```bash .env/Scripts/activate ``` இப்போது நீங்கள் `huggingface_hub`-ஐ [PyPi பதிவகத்திலிருந்து](https://pypi.org/project/huggingface-hub/) நிறுவ தயாராக இருக்கிறீர்கள். ```bash pip install --upgrade huggingface_hub ``` முடித்த பிறகு, [நிறுவல் சரியாக வேலை](#check-installation) செய்கிறதா என்பதைச் சோதிக்கவும். ### விருப்பத் தேவைப்படும் சார்புகளை நிறுவல்** `huggingface_hub`-இன் சில சார்புகள் விருப்பமானவை, ஏனெனில் அவை `huggingface_hub`-இன் அடிப்படை அம்சங்களை இயக்க தேவையில்லை. எனினும், விருப்பச் சார்புகள் நிறுவப்படாதால், `huggingface_hub`-இன் சில அம்சங்கள் கிடைக்காது. 
நீங்கள் விருப்பத் தேவைப்படும் சார்புகளை `pip` மூலம் நிறுவலாம்: ```bash # டென்சர்‌ஃபிளோவுக்கான குறிப்பிட்ட அம்சங்களுக்கு சார்ந்த பொறுப்பு நிறுவவும் # /!\ எச்சரிக்கை: இது `pip install tensorflow` க்கு சமமாகக் கருதப்படாது pip install 'huggingface_hub[tensorflow]' # டார்ச்-குறிப்பிட்ட மற்றும் CLI-குறிப்பிட்ட அம்சங்களுக்கு தேவையான பொறுப்புகளை நிறுவவும். pip install 'huggingface_hub[cli,torch]' ``` `huggingface_hub`-இல் உள்ள விருப்பத் தேவைப்படும் சார்புகளின் பட்டியல்: - `cli`: `huggingface_hub`-க்கு மிகவும் வசதியான CLI இடைமுகத்தை வழங்குகிறது. - `fastai`, `torch`, `tensorflow`: வடிவமைப்பு குறிப்பிட்ட அம்சங்களை இயக்க தேவையான சார்புகள். - `dev`: நூலகத்திற்கு பங்களிக்க தேவையான சார்புகள். இதில் சோதனை (சோதனைகளை இயக்க), வகை சோதனை (வகை சரிபார்ப்பு ஐ இயக்க) மற்றும் தரம் (லிண்டர்கள் ஐ இயக்க) உள்ளன. ### மூலத்திலிருந்து நிறுவல் சில சமயம், `huggingface_hub`-ஐ நேரடியாக மூலத்திலிருந்து நிறுவுவது சுவாரஸ்யமாக இருக்கலாம். இது, சமீபத்திய நிலையான பதிப்பு பதிலாக, புதியதாக இருக்கும் `முக்கிய` பதிப்பைப் பயன்படுத்த அனுமதிக்கிறது. `முக்கிய` பதிப்பு, சமீபத்திய முன்னேற்றங்களுடன் புதுப்பிக்க உதவுகிறது, உதாரணமாக, சமீபத்திய அதிகாரப்பூர்வ வெளியீட்டுக்குப் பிறகு பிழை சரிசெய்யப்பட்டிருந்தாலும் புதிய வெளியீடு வந்ததாக இல்லை. எனினும், இதன் பொருள் `முக்கிய` பதிப்பு எப்போதும் நிலையாக இருக்காது. `முக்கிய` பதிப்பை செயல்படுமாறு வைத்திருக்க நாங்கள் முயற்சிக்கிறோம், மேலும் பெரும்பாலான சிக்கல்களை சில மணி நேரங்கள் அல்லது ஒரு நாளுக்குள் தீர்க்கவேண்டியவை. நீங்கள் ஒரு பிரச்சினையை எதிர்கொண்டால், அதைக் கூட்டுங்கள், அதைக் கூட விரைவில் சரிசெய்ய நாங்கள் முயற்சிக்கிறோம்! ```bash pip install git+https://github.com/huggingface/huggingface_hub ``` மூலத்திலிருந்து நிறுவும் போது, நீங்கள் குறிப்பிட்ட கிளையை (branch) குறிப்படலாம். இது, இன்னும் இணைக்கப்படாத புதிய அம்சம் அல்லது புதிய பிழை சரிசெய்வுகளை சோதிக்க விரும்பும்போது பயனுள்ளதாக இருக்கும்: ```bash pip install git+https://github.com/huggingface/huggingface_hub@my-feature-branch ``` முடித்த பிறகு, [நிறுவல் சரியாக வேலை செய்கிறதா]((#check-installation)) என்பதைச் சோதிக்கவும். ### திருத்தக்கூடிய நிறுவல் மூலத்திலிருந்து நிறுவுதல் [எடிடேபிள் இன்ஸ்டால்](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs) அமைப்பதற்கு அனுமதிக்கிறது. இது, `huggingface_hub`-க்கு பங்களிக்க திட்டமிட்டு, கோடில் மாற்றங்களை சோதிக்க விரும்பும் போது மேலும் முற்றிலும் மேம்பட்ட நிறுவல் ஆகும். உங்கள் இயந்திரத்தில் `huggingface_hub`-இன் ஒரு உள்ளூர் நகலை கிளோன் செய்ய வேண்டும். ```bash # முதலில், கிடுகிடுக்கும் தொகுப்பை உள்ளூர் முறையில் கிளோன் செய்யவும். git clone https://github.com/huggingface/huggingface_hub.git # அதன் பிறகு, -e கொள்கையைப் பயன்படுத்தி நிறுவவும். cd huggingface_hub pip install -e . ``` இந்த கட்டளைகள், நீங்கள் தரவுகளை கிளோன் செய்த அடைவை மற்றும் உங்கள் பைதான் நூலகப் பாதைகளை இணைக்கும். பைதான், தற்போது சாதாரண நூலகப் பாதைகளுக்கு கூட, நீங்கள் கிளோன் செய்த அடைவைப் பார்வையிடும். உதாரணமாக, உங்கள் பைதான் தொகுப்புகள் பொதுவாக `./.venv/lib/python3.11/site-packages/` இல் நிறுவப்பட்டிருந்தால், பைதான்n நீங்கள் கிளோன் செய்த `./huggingface_hub/` அடைவையும் தேடுவதாக இருக்கும். ## கொண்டா மூலம் நிறுவல் **நீங்கள் அதனுடன் மேலும் பரிச்சயமாக இருந்தால்**, `huggingface_hub`-ஐ [conda-forge சேனல்](https://anaconda.org/conda-forge/huggingface_hub) பயன்படுத்தி நிறுவலாம்: ```bash conda install -c conda-forge huggingface_hub ``` முடித்த பிறகு, [நிறுவல் சரியாக வேலை செய்கிறதா என்பதைச் சோதிக்கவும்](#check-installation). 
## நிறுவலைச் சோதிக்கவும் நிறுவலுக்குப் பிறகு, `huggingface_hub` சரியாக வேலை செய்கிறதா என்பதைக் கீழ்காணும் கட்டளையை இயக்கி சோதிக்கவும்: ```bash python -c "from huggingface_hub import model_info; print(model_info('gpt2'))" ``` இந்த கட்டளை, Hub-இல் உள்ள [gpt2](https://huggingface.co/gpt2) மாடலுக்கான தகவல்களை பெறும். வெளியீடு கீழ்காணும் மாதிரியாக இருக்க வேண்டும்: ```text Model Name: gpt2 Tags: ['pytorch', 'tf', 'jax', 'tflite', 'rust', 'safetensors', 'gpt2', 'text-generation', 'en', 'doi:10.57967/hf/0039', 'transformers', 'exbert', 'license:mit', 'has_space'] Task: text-generation ``` ## Windows மரபுகள் எந்த இடத்திலும் சிறந்த ML-ஐ பொதுமக்களுக்கு வழங்கும் எங்கள் இலக்குடன், `huggingface_hub`-ஐ ஒரு குறைவில்லாத தளத்துடன் உருவாக்கினோம் மற்றும் குறிப்பாக Unix அடிப்படையிலான மற்றும் Windows அமைப்புகளில் சரியாக செயல்படவும். ஆனால், Windows-இல் இயங்கும் போது `huggingface_hub`-க்கு சில வரையறைகள் உள்ளன. இங்கே தெரிந்த சிக்கல்களின் முழு பட்டியல் உள்ளது. உங்கள் சந்தர்ப்பத்தில் ஆவணமிடாத சிக்கல் கண்டுபிடித்தால், [Github-ல் ஒரு பிரச்சனை திறக்க](https://github.com/huggingface/huggingface_hub/issues/new/choose) எங்களுக்கு தெரிவிக்கவும். - `huggingface_hub`-இன் காசே அமைப்பு, Hub-இல் இருந்து பதிவிறக்கம் செய்யப்பட்ட கோப்புகளைச் சரியாக காசே செய்ய சிம்லிங்குகளை நம்புகிறது. Windows-இல், சிம்லிங்குகளை இயக்குவதற்கு நீங்கள் டெவலப்பர் முறை அல்லது உங்கள் ஸ்கிரிப்டைப் ஆட்மின் ஆக இயக்க வேண்டும். சிம்லிங்குகள் இயக்கப்படாவிட்டால், காசே அமைப்பு இன்னும் வேலை செய்யும் ஆனால் சரியாக செயல்படாது. மேலும் விவரங்களுக்கு [காசே வரையறைகள்](./guides/manage-cache#limitations) பகுதியைப் படிக்கவும். - Hub-இல் கோப்பு பாதைகள் சிறப்பு எழுத்துக்கள் கொண்டதாக இருக்கலாம் (எ.கா. `"path/to?/my/file"`). Windows, [சிறப்பு எழுத்துக்கள்](https://learn.microsoft.com/en-us/windows/win32/intl/character-sets-used-in-file-names) மீது அதிக கட்டுப்பாடுகளை கொண்டுள்ளது, இது Windows-இல் அந்த கோப்புகளை பதிவிறக்கம் செய்ய முடியாததாக உருவாக்குகிறது. இது நிச்சயமாக ஒரு புலவியல் சந்தர்ப்பமாக இருக்க வேண்டும். இது தவறு என்று நீங்கள் நினைத்தால், அதற்கான தீர்வைத் தேட எங்களை அணுகவும். ## அடுத்த கட்டங்கள் `huggingface_hub` உங்கள் இயந்திரத்தில் முறையாக நிறுவப்பட்ட பிறகு, [சூழல் மாறிலிகளை](package_reference/environment_variables) கட்டமைக்க அல்லது [எங்கள் வழிகாட்டிகளில்](guides/overview) ஒன்றைப் பார்வையிட தேவையெனில், தொடங்குங்கள்.huggingface_hub-0.31.1/i18n/000077500000000000000000000000001500667546600154775ustar00rootroot00000000000000huggingface_hub-0.31.1/i18n/README_cn.md000066400000000000000000000156111500667546600174420ustar00rootroot00000000000000

huggingface_hub library logo

Hugging Face Hub Python 客户端

Documentation GitHub release PyPi version PyPI - Downloads Code coverage

English | Deutsch | हिंदी | 한국어 | 中文(简体)

--- **文档**: https://hf.co/docs/huggingface_hub **源代码**: https://github.com/huggingface/huggingface_hub --- ## 欢迎使用 Hugging Face Hub 库 通过`huggingface_hub` 库,您可以与面向机器学习开发者和协作者的平台 [Hugging Face Hub](https://huggingface.co/)进行交互,找到适用于您所在项目的预训练模型和数据集,体验在平台托管的数百个机器学习应用,还可以创建或分享自己的模型和数据集并于社区共享。以上所有都可以用Python在`huggingface_hub` 库中轻松实现。 ## 主要特点 - [从hugging face hub下载文件](https://huggingface.co/docs/huggingface_hub/en/guides/download) - [上传文件到 hugging face hub](https://huggingface.co/docs/huggingface_hub/en/guides/upload) - [管理您的存储库](https://huggingface.co/docs/huggingface_hub/en/guides/repository) - [在部署的模型上运行推断](https://huggingface.co/docs/huggingface_hub/en/guides/inference) - [搜索模型、数据集和空间](https://huggingface.co/docs/huggingface_hub/en/guides/search) - [分享模型卡片](https://huggingface.co/docs/huggingface_hub/en/guides/model-cards) - [社区互动](https://huggingface.co/docs/huggingface_hub/en/guides/community) ## 安装 使用pip安装 `huggingface_hub` 包: ```bash pip install huggingface_hub ``` 如果您更喜欢,也可以使用 conda 进行安装 为了默认保持包的最小化,huggingface_hub 带有一些可选的依赖项,适用于某些用例。例如,如果您想要完整的推断体验,请运行: ```bash pip install huggingface_hub[inference] ``` 要了解更多安装和可选依赖项,请查看[安装指南](https://huggingface.co/docs/huggingface_hub/cn/安装) ## 快速入门指南 ### 下载文件 下载单个文件,请运行以下代码: ```py from huggingface_hub import hf_hub_download hf_hub_download(repo_id="tiiuae/falcon-7b-instruct", filename="config.json") ``` 如果下载整个存储库,请运行以下代码: ```py from huggingface_hub import snapshot_download snapshot_download("stabilityai/stable-diffusion-2-1") ``` 文件将被下载到本地缓存文件夹。更多详细信息请参阅此 [指南](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache). ### 登录 Hugging Face Hub 使用令牌对应用进行身份验证(请参阅[文档](https://huggingface.co/docs/hub/security-tokens)). 要登录您的机器,请运行以下命令行: ```bash huggingface-cli login # or using an environment variable huggingface-cli login --token $HUGGINGFACE_TOKEN ``` ### 创建一个存储库 要创建一个新存储库,请运行以下代码: ```py from huggingface_hub import create_repo create_repo(repo_id="super-cool-model") ``` ### 上传文件 上传单个文件,请运行以下代码 ```py from huggingface_hub import upload_file upload_file( path_or_fileobj="/home/lysandre/dummy-test/README.md", path_in_repo="README.md", repo_id="lysandre/test-model", ) ``` 如果上传整个存储库,请运行以下代码: ```py from huggingface_hub import upload_folder upload_folder( folder_path="/path/to/local/space", repo_id="username/my-cool-space", repo_type="space", ) ``` 有关详细信息,请查看 [上传指南](https://huggingface.co/docs/huggingface_hub/en/guides/upload). ## 集成到 Hub 中 我们正在与一些出色的开源机器学习库合作,提供免费的模型托管和版本控制。您可以在 [这里](https://huggingface.co/docs/hub/libraries)找到现有的集成 优势包括: - 为库及其用户提供免费的模型或数据集托管 - 内置文件版本控制,即使对于非常大的文件也能实现,这得益于基于 Git 的方法 - 为所有公开可用的模型提供托管的推断 API - 在网页端可在线体验所有公开的模型 - 任何人都可以上传新模型到您的库,他们只需为模型添加相应的标签,以便让其被发现 - 快速下载!我们使用 Cloudfront(CDN)进行地理复制下载,因此无论在全球任何地方,下载速度都非常快。 - 使用统计和更多功能即将推出 如果您想要集成您的库,请随时打开一个问题来开始讨论。我们编写了一份逐步指南,以❤️的方式展示如何进行这种集成。 ## 欢迎各种贡献(功能请求、错误等) 💙💚💛💜🧡❤️ 欢迎每个人来进行贡献,我们重视每个人的贡献。编写代码并非唯一的帮助社区的方式。回答问题、帮助他人、积极互动并改善文档对社区来说都是极其有价值的。为此我们编写了一份 [贡献指南](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md) 以进行总结,即如何开始为这个存储库做贡献 huggingface_hub-0.31.1/i18n/README_de.md000066400000000000000000000175611500667546600174400ustar00rootroot00000000000000

huggingface_hub library logo

Der offizielle Python-Client für den Huggingface Hub.

Dokumentation GitHub release PyPi version PyPI - Downloads Code coverage

English | Deutsch | हिंदी | 한국어 | 中文(简体)

--- **Dokumentation**: https://hf.co/docs/huggingface_hub **Quellcode**: https://github.com/huggingface/huggingface_hub --- ## Willkommen bei der huggingface_hub Bibliothek Die `huggingface_hub` Bibliothek ermöglicht Ihnen die Interaktion mit dem [Hugging Face Hub](https://huggingface.co/), einer Plattform, die Open-Source Machine Learning für Entwickler und Mitwirkende demokratisiert. Entdecken Sie vortrainierte Modelle und Datensätze für Ihre Projekte oder spielen Sie mit den Tausenden von Machine-Learning-Apps, die auf dem Hub gehostet werden. Sie können auch Ihre eigenen Modelle, Datensätze und Demos mit der Community teilen. Die `huggingface_hub` Bibliothek bietet eine einfache Möglichkeit, all dies mit Python zu tun. ## Hauptmerkmale - Dateien vom Hub [herunterladen](https://huggingface.co/docs/huggingface_hub/de/guides/download). - Dateien auf den Hub [hochladen](https://huggingface.co/docs/huggingface_hub/de/guides/upload). - [Verwalten Ihrer Repositories](https://huggingface.co/docs/huggingface_hub/de/guides/repository). - [Ausführen von Inferenz](https://huggingface.co/docs/huggingface_hub/de/guides/inference) auf bereitgestellten Modellen. - [Suche](https://huggingface.co/docs/huggingface_hub/de/guides/search) nach Modellen, Datensätzen und Spaces. - [Model Cards teilen](https://huggingface.co/docs/huggingface_hub/de/guides/model-cards), um Ihre Modelle zu dokumentieren. - [Mit der Community interagieren](https://huggingface.co/docs/huggingface_hub/de/guides/community), durch PRs und Kommentare. ## Installation Installieren Sie das `huggingface_hub` Paket mit [pip](https://pypi.org/project/huggingface-hub/): ```bash pip install huggingface_hub ``` Wenn Sie möchten, können Sie es auch mit [conda](https://huggingface.co/docs/huggingface_hub/de/installation#installieren-mit-conda) installieren. Um das Paket standardmäßig minimal zu halten, kommt `huggingface_hub` mit optionalen Abhängigkeiten, die für einige Anwendungsfälle nützlich sind. Zum Beispiel, wenn Sie ein vollständiges Erlebnis für Inferenz möchten, führen Sie den folgenden Befehl aus: ```bash pip install huggingface_hub[inference] ``` Um mehr über die Installation und optionale Abhängigkeiten zu erfahren, sehen Sie sich bitte den [Installationsleitfaden](https://huggingface.co/docs/huggingface_hub/de/installation) an. ## Schnellstart ### Dateien herunterladen Eine einzelne Datei herunterladen ```py from huggingface_hub import hf_hub_download hf_hub_download(repo_id="tiiuae/falcon-7b-instruct", filename="config.json") ``` Oder eine gesamte Repository ```py from huggingface_hub import snapshot_download snapshot_download("stabilityai/stable-diffusion-2-1") ``` Dateien werden in einen lokalen Cache-Ordner heruntergeladen. Weitere Details finden Sie in diesem [Leitfaden](https://huggingface.co/docs/huggingface_hub/de/guides/manage-cache). ### Anmeldung Der Hugging Face Hub verwendet Tokens zur Authentifizierung von Anwendungen (siehe [Dokumentation](https://huggingface.co/docs/hub/security-tokens)). 
Um sich an Ihrem Computer anzumelden, führen Sie das folgende Kommando in der Befehlszeile aus: ```bash huggingface-cli login # oder mit einer Umgebungsvariablen huggingface-cli login --token $HUGGINGFACE_TOKEN ``` ### Eine Repository erstellen ```py from huggingface_hub import create_repo create_repo(repo_id="super-cool-model") ``` ### Dateien hochladen Eine einzelne Datei hochladen ```py from huggingface_hub import upload_file upload_file( path_or_fileobj="/home/lysandre/dummy-test/README.md", path_in_repo="README.md", repo_id="lysandre/test-model", ) ``` Oder einen gesamten Ordner ```py from huggingface_hub import upload_folder upload_folder( folder_path="/path/to/local/space", repo_id="username/my-cool-space", repo_type="space", ) ``` Weitere Informationen finden Sie im [Upload-Leitfaden](https://huggingface.co/docs/huggingface_hub/de/guides/upload). ## Integration in den Hub Wir arbeiten mit coolen Open-Source-ML-Bibliotheken zusammen, um kostenloses Model-Hosting und -Versionierung anzubieten. Die bestehenden Integrationen finden Sie [hier](https://huggingface.co/docs/hub/libraries). Die Vorteile sind: - Kostenloses Hosting von Modellen oder Datensätzen für Bibliotheken und deren Benutzer.. - Eingebaute Dateiversionierung, selbst bei sehr großen Dateien, dank eines git-basierten Ansatzes. - Bereitgestellte Inferenz-API für alle öffentlich verfügbaren Modelle. - In-Browser-Widgets zum Spielen mit den hochgeladenen Modellen. - Jeder kann ein neues Modell für Ihre Bibliothek hochladen, es muss nur das entsprechende Tag hinzugefügt werden, damit das Modell auffindbar ist. - Schnelle Downloads! Wir verwenden Cloudfront (ein CDN), um Downloads zu geo-replizieren, sodass sie von überall auf der Welt blitzschnell sind. - Nutzungsstatistiken und mehr Funktionen in Kürze. Wenn Sie Ihre Bibliothek integrieren möchten, öffnen Sie gerne ein Issue, um die Diskussion zu beginnen. Wir haben mit ❤️ einen [schrittweisen Leitfaden](https://huggingface.co/docs/hub/adding-a-library) geschrieben, der zeigt, wie diese Integration durchgeführt wird. ## Beiträge (Feature-Anfragen, Fehler usw.) sind super willkommen 💙💚💛💜🧡❤️ Jeder ist willkommen beizutragen, und wir schätzen den Beitrag jedes Einzelnen. Code zu schreiben ist nicht der einzige Weg, der Community zu helfen. Fragen zu beantworten, anderen zu helfen, sich zu vernetzen und die Dokumentationen zu verbessern, sind für die Gemeinschaft von unschätzbarem Wert. Wir haben einen [Beitrags-Leitfaden](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md) geschrieben, der zusammenfasst, wie Sie beginnen können, zu dieser Repository beizutragen. huggingface_hub-0.31.1/i18n/README_hi.md000066400000000000000000000276671500667546600174600ustar00rootroot00000000000000

huggingface_hub library logo

Huggingface Hub के लिए आधिकारिक पायथन क्लाइंट।

Documentation GitHub release PyPi version PyPI - Downloads Code coverage

English | Deutsch | हिंदी | 한국어 | 中文(简体)

--- **दस्तावेज़ीकरण**: https://hf.co/docs/huggingface_hub **सोर्स कोड**: https://github.com/huggingface/huggingface_hub --- ## huggingface_hub लाइब्रेरी में आपका स्वागत है `huggingface_hub` लाइब्रेरी आपको [हगिंग फेस हब](https://huggingface.co/) के साथ बातचीत करने की अनुमति देती है, जो रचनाकारों और सहयोगियों के लिए ओपन-सोर्स मशीन लर्निंग का लोकतंत्रीकरण करने वाला एक मंच है। अपनी परियोजनाओं के लिए पूर्व-प्रशिक्षित मॉडल और डेटासेट खोजें या हब पर होस्ट किए गए हजारों मशीन लर्निंग ऐप्स के साथ खेलें। आप समुदाय के साथ अपने स्वयं के मॉडल, डेटासेट और डेमो भी बना और साझा कर सकते हैं। `huggingface_hub` लाइब्रेरी पायथन के साथ इन सभी चीजों को करने का एक आसान तरीका प्रदान करती है। ## प्रमुख विशेषताऐं - [फ़ाइलें डाउनलोड करें](https://huggingface.co/docs/huggingface_hub/en/guides/download) हब से। - [फ़ाइलें अपलोड करें](https://huggingface.co/docs/huggingface_hub/en/guides/upload) हब पर। - [अपनी रिपॉजिटरी प्रबंधित करें](https://huggingface.co/docs/huggingface_hub/en/guides/repository)। - तैनात मॉडलों पर [अनुमान चलाएँ](https://huggingface.co/docs/huggingface_hub/en/guides/inference)। - मॉडल, डेटासेट और स्पेस के लिए [खोज](https://huggingface.co/docs/huggingface_hub/en/guides/search)। - [मॉडल कार्ड साझा करें](https://huggingface.co/docs/huggingface_hub/en/guides/model-cards) अपने मॉडलों का दस्तावेजीकरण करने के लिए। - [समुदाय के साथ जुड़ें](https://huggingface.co/docs/huggingface_hub/en/guides/community) पीआर और टिप्पणियों के माध्यम से। ## स्थापना [pip](https://pypi.org/project/huggingface-hub/) के साथ `huggingface_hub` पैकेज इंस्टॉल करें: ```bash pip install huggingface_hub ``` यदि आप चाहें, तो आप इसे [conda](https://huggingface.co/docs/huggingface_hub/en/installation#install-with-conda) से भी इंस्टॉल कर सकते हैं। पैकेज को डिफ़ॉल्ट रूप से न्यूनतम रखने के लिए, `huggingface_hub` कुछ उपयोग मामलों के लिए उपयोगी वैकल्पिक निर्भरता के साथ आता है। उदाहरण के लिए, यदि आप अनुमान के लिए संपूर्ण अनुभव चाहते हैं, तो चलाएँ: ```bash pip install huggingface_hub[inference] ``` अधिक इंस्टॉलेशन और वैकल्पिक निर्भरता जानने के लिए, [इंस्टॉलेशन गाइड](https://huggingface.co/docs/huggingface_hub/en/installation) देखें। ## जल्दी शुरू ### फ़ाइलें डाउनलोड करें एकल फ़ाइल डाउनलोड करें ```py from huggingface_hub import hf_hub_download hf_hub_download(repo_id="tiiuae/falcon-7b-instruct", filename="config.json") ``` या एक संपूर्ण भंडार ```py from huggingface_hub import snapshot_download snapshot_download("stabilityai/stable-diffusion-2-1") ``` फ़ाइलें स्थानीय कैश फ़ोल्डर में डाउनलोड की जाएंगी. 
[this_guide] में अधिक विवरण (https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache)। ### लॉग इन करें Hugging Face Hub एप्लिकेशन को प्रमाणित करने के लिए टोकन का उपयोग करता है (देखें [docs](https://huggingface.co/docs/hub/security-tokens))। अपनी मशीन में लॉगिन करने के लिए, निम्नलिखित सीएलआई चलाएँ: ```bash huggingface-cli login # या कृपया इसे एक पर्यावरण चर के रूप में निर्दिष्ट करें। huggingface-cli login --token $HUGGINGFACE_TOKEN ``` ### एक रिपॉजिटरी बनाएं ```py from huggingface_hub import create_repo create_repo(repo_id="super-cool-model") ``` ### फाइलें अपलोड करें एकल फ़ाइल अपलोड करें ```py from huggingface_hub import upload_file upload_file( path_or_fileobj="/home/lysandre/dummy-test/README.md", path_in_repo="README.md", repo_id="lysandre/test-model", ) ``` या एक संपूर्ण फ़ोल्डर ```py from huggingface_hub import upload_folder upload_folder( folder_path="/path/to/local/space", repo_id="username/my-cool-space", repo_type="space", ) ``` [अपलोड गाइड](https://huggingface.co/docs/huggingface_hub/en/guides/upload) में विवरण के लिए। ## हब से एकीकरण। हम मुफ्त मॉडल होस्टिंग और वर्जनिंग प्रदान करने के लिए शानदार ओपन सोर्स एमएल लाइब्रेरीज़ के साथ साझेदारी कर रहे हैं। आप मौजूदा एकीकरण [यहां](https://huggingface.co/docs/hub/libraries) पा सकते हैं। फायदे ये हैं: - पुस्तकालयों और उनके उपयोगकर्ताओं के लिए निःशुल्क मॉडल या डेटासेट होस्टिंग। - गिट-आधारित दृष्टिकोण के कारण, बहुत बड़ी फ़ाइलों के साथ भी अंतर्निहित फ़ाइल संस्करणिंग। - सभी मॉडलों के लिए होस्टेड अनुमान एपीआई सार्वजनिक रूप से उपलब्ध है। - अपलोड किए गए मॉडलों के साथ खेलने के लिए इन-ब्राउज़र विजेट। - कोई भी आपकी लाइब्रेरी के लिए एक नया मॉडल अपलोड कर सकता है, उन्हें मॉडल को खोजने योग्य बनाने के लिए बस संबंधित टैग जोड़ना होगा। - तेज़ डाउनलोड! हम डाउनलोड को जियो-रेप्लिकेट करने के लिए क्लाउडफ्रंट (एक सीडीएन) का उपयोग करते हैं ताकि वे दुनिया में कहीं से भी तेजी से चमक सकें। - उपयोग आँकड़े और अधिक सुविधाएँ आने वाली हैं। यदि आप अपनी लाइब्रेरी को एकीकृत करना चाहते हैं, तो चर्चा शुरू करने के लिए बेझिझक एक मुद्दा खोलें। हमने ❤️ के साथ एक [चरण-दर-चरण मार्गदर्शिका](https://huggingface.co/docs/hub/adding-a-library) लिखी, जिसमें दिखाया गया कि यह एकीकरण कैसे करना है। ## योगदान (सुविधा अनुरोध, बग, आदि) का अति स्वागत है 💙💚💛💜🧡❤️ योगदान के लिए हर किसी का स्वागत है और हम हर किसी के योगदान को महत्व देते हैं। कोड समुदाय की मदद करने का एकमात्र तरीका नहीं है। प्रश्नों का उत्तर देना, दूसरों की मदद करना, उन तक पहुंचना और दस्तावेज़ों में सुधार करना समुदाय के लिए बेहद मूल्यवान है। हमने संक्षेप में बताने के लिए एक [योगदान मार्गदर्शिका](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md) लिखी है इस भंडार में योगदान करने की शुरुआत कैसे करें। huggingface_hub-0.31.1/i18n/README_ko.md000066400000000000000000000177371500667546600174660ustar00rootroot00000000000000

huggingface_hub library logo

공식 Huggingface Hub 파이썬 클라이언트

Documentation GitHub release PyPi version PyPI - Downloads Code coverage

English | Deutsch | हिंदी | 한국어 | 中文(简体)

--- **기술 문서**: https://hf.co/docs/huggingface_hub **소스 코드**: https://github.com/huggingface/huggingface_hub --- ## huggingface_hub 라이브러리 개요 `huggingface_hub` 라이브러리는 [Hugging Face Hub](https://huggingface.co/)와 상호작용할 수 있게 해줍니다. Hugging Face Hub는 창작자와 협업자를 위한 오픈소스 머신러닝 플랫폼입니다. 여러분의 프로젝트에 적합한 사전 훈련된 모델과 데이터셋을 발견하거나, Hub에 호스팅된 수천 개의 머신러닝 앱들을 사용해보세요. 또한, 여러분이 만든 모델, 데이터셋, 데모를 커뮤니티와 공유할 수도 있습니다. `huggingface_hub` 라이브러리는 파이썬으로 이 모든 것을 간단하게 할 수 있는 방법을 제공합니다. ## 주요 기능 - Hub에서 [파일을 다운로드](https://huggingface.co/docs/huggingface_hub/main/ko/guides/download) - Hub에 [파일을 업로드](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload) (영어) - [레포지토리를 관리](https://huggingface.co/docs/huggingface_hub/main/en/guides/repository) (영어) - 배포된 모델에 [추론을 실행](https://huggingface.co/docs/huggingface_hub/main/en/guides/inference) (영어) - 모델, 데이터셋, Space를 [검색](https://huggingface.co/docs/huggingface_hub/main/en/guides/search) (영어) - [모델 카드를 공유](https://huggingface.co/docs/huggingface_hub/main/en/guides/model-cards)하여 모델을 문서화 (영어) - PR과 댓글을 통해 [커뮤니티와 소통](https://huggingface.co/docs/huggingface_hub/main/en/guides/community) (영어) ## 설치 [pip](https://pypi.org/project/huggingface-hub/)로 `huggingface_hub` 패키지를 설치하세요: ```bash pip install huggingface_hub ``` 원한다면 [conda](https://huggingface.co/docs/huggingface_hub/ko/installation#install-with-conda)를 이용하여 설치할 수도 있습니다. 기본 패키지를 작게 유지하기 위해 `huggingface_hub`는 유용한 의존성을 추가적으로 제공합니다. 추론과 관련된 기능을 원한다면, 아래를 실행하세요: ```bash pip install huggingface_hub[inference] ``` 설치와 선택적 의존성에 대해 더 알아보려면, [설치 가이드](https://huggingface.co/docs/huggingface_hub/ko/installation)를 참고하세요. ## 맛보기 ### 파일 다운로드 파일 하나의 경우: ```py from huggingface_hub import hf_hub_download hf_hub_download(repo_id="tiiuae/falcon-7b-instruct", filename="config.json") ``` 레포지토리 전체의 경우: ```py from huggingface_hub import snapshot_download snapshot_download("stabilityai/stable-diffusion-2-1") ``` 파일은 로컬 캐시 폴더에 다운로드됩니다. 자세한 내용은 [이 가이드](https://huggingface.co/docs/huggingface_hub/ko/guides/manage-cache)를 참조하세요. ### 로그인 Hugging Face Hub는 토큰을 사용하여 애플리케이션을 인증합니다([문서](https://huggingface.co/docs/hub/security-tokens) 참조). 컴퓨터에서 로그인하려면 CLI를 사용하세요: ```bash huggingface-cli login # 또는 환경 변수로 지정해주세요 huggingface-cli login --token $HUGGINGFACE_TOKEN ``` ### 레포지토리 생성 ```py from huggingface_hub import create_repo create_repo(repo_id="super-cool-model") ``` ### 파일 업로드 파일 하나의 경우: ```py from huggingface_hub import upload_file upload_file( path_or_fileobj="/home/lysandre/dummy-test/README.md", path_in_repo="README.md", repo_id="lysandre/test-model", ) ``` 레포지토리 전체의 경우: ```py from huggingface_hub import upload_folder upload_folder( folder_path="/path/to/local/space", repo_id="username/my-cool-space", repo_type="space", ) ``` 자세한 내용은 [업로드 가이드](https://huggingface.co/docs/huggingface_hub/ko/guides/upload)를 참조하세요. ## Hugging Face Hub와 함께 성장하기 저희는 멋진 오픈소스 ML 라이브러리들과 협력하여, 모델 호스팅과 버전 관리를 무료로 제공하고 있습니다. 이미 통합된 라이브러리들은 [여기](https://huggingface.co/docs/hub/libraries)서 확인할 수 있습니다. 이렇게 하면 다음과 같은 장점이 있습니다: - 라이브러리 사용자들의 모델이나 데이터셋을 무료로 호스팅해줍니다. - git을 기반으로 한 방식으로, 아주 큰 파일들도 버전을 관리할 수 있습니다. - 공개된 모든 모델에 대해 추론 API를 호스팅해줍니다. - 업로드된 모델들을 브라우저에서 쉽게 사용할 수 있는 위젯을 제공합니다. - 누구나 여러분의 라이브러리에 새로운 모델을 업로드할 수 있습니다. 모델이 검색될 수 있도록 해당 태그만 추가하면 됩니다. - 다운로드 속도가 매우 빠릅니다! 왜냐하면 Cloudfront (CDN)를 이용하여 전 세계 어디에서나 빠르게 다운로드할 수 있도록 지역적으로 복제해뒀기 때문입니다. - 사용 통계와 더 많은 기능들을 제공합니다. 여러분의 라이브러리를 통합하고 싶다면, 이슈를 열어서 의견을 나눠주세요. 통합 과정을 안내하기 위해 ❤️을 담아 [단계별 가이드](https://huggingface.co/docs/hub/adding-a-library)를 작성했습니다. ## (기능 요청, 버그 패치 등의) 기여는 대환영입니다 💙💚💛💜🧡❤️ 모든 분들의 기여를 환영하며, 소중히 생각합니다. 
코드 작성만이 커뮤니티에 도움을 주는 유일한 방법이 아니에요. 질문에 답하거나, 다른 분들을 돕거나, 컨택하거나, 문서를 개선하는 것도 커뮤니티에 큰 도움이 됩니다. 지금 시작하려면 간단한 [기여 가이드](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md)를 참조해주세요. huggingface_hub-0.31.1/pyproject.toml000066400000000000000000000023311500667546600176330ustar00rootroot00000000000000[tool.mypy] ignore_missing_imports = true no_implicit_optional = true scripts_are_modules = true [tool.pytest.ini_options] # Add the specified `OPTS` to the set of command line arguments as if they had # been specified by the user. addopts = "-Werror::FutureWarning --log-cli-level=INFO -sv --durations=0" # The defined variables will be added to the environment before any tests are # run, part of pytest-env plugin env = [ "DISABLE_SYMLINKS_IN_WINDOWS_TESTS=1", "HF_TOKEN=", "HUGGINGFACE_CO_STAGING=1", "HUGGING_FACE_HUB_TOKEN=", ] [tool.ruff] exclude = [ ".eggs", ".git", ".git-rewrite", ".hg", ".mypy_cache", ".nox", ".pants.d", ".pytype", ".ruff_cache", ".svn", ".tox", ".venv", ".venv*", "__pypackages__", "_build", "build", "dist", "venv", ] line-length = 119 # Ignored rules: # "E501" -> line length violation lint.ignore = ["E501"] lint.select = ["E", "F", "I", "W"] [tool.ruff.lint.isort] known-first-party = ["huggingface_hub"] lines-after-imports = 2 [tool.tomlsort] all = true in_place = true spaces_before_inline_comment = 2 # Match Python PEP 8 spaces_indent_inline_array = 4 # Match Python PEP 8 trailing_comma_inline_array = true huggingface_hub-0.31.1/setup.py000066400000000000000000000105561500667546600164410ustar00rootroot00000000000000from setuptools import find_packages, setup def get_version() -> str: rel_path = "src/huggingface_hub/__init__.py" with open(rel_path, "r") as fp: for line in fp.read().splitlines(): if line.startswith("__version__"): delim = '"' if '"' in line else "'" return line.split(delim)[1] raise RuntimeError("Unable to find version string.") install_requires = [ "filelock", "fsspec>=2023.5.0", "hf-xet>=1.1.0,<2.0.0; platform_machine=='x86_64' or platform_machine=='amd64' or platform_machine=='arm64' or platform_machine=='aarch64'", "packaging>=20.9", "pyyaml>=5.1", "requests", "tqdm>=4.42.1", "typing-extensions>=3.7.4.3", # to be able to import TypeAlias ] extras = {} extras["cli"] = [ "InquirerPy==0.3.4", # Note: installs `prompt-toolkit` in the background ] extras["inference"] = [ "aiohttp", # for AsyncInferenceClient ] extras["torch"] = [ "torch", "safetensors[torch]", ] extras["hf_transfer"] = [ "hf_transfer>=0.1.4", # Pin for progress bars ] extras["fastai"] = [ "toml", "fastai>=2.4", "fastcore>=1.3.27", ] extras["tensorflow"] = [ "tensorflow", "pydot", "graphviz", ] extras["tensorflow-testing"] = [ "tensorflow", "keras<3.0", ] extras["hf_xet"] = ["hf_xet>=1.1.0,<2.0.0"] extras["testing"] = ( extras["cli"] + extras["inference"] + [ "jedi", "Jinja2", "pytest>=8.1.1,<8.2.2", # at least until 8.2.3 is released with https://github.com/pytest-dev/pytest/pull/12436 "pytest-cov", "pytest-env", "pytest-xdist", "pytest-vcr", # to mock Inference "pytest-asyncio", # for AsyncInferenceClient "pytest-rerunfailures", # to rerun flaky tests in CI "pytest-mock", "urllib3<2.0", # VCR.py broken with urllib3 2.0 (see https://urllib3.readthedocs.io/en/stable/v2-migration-guide.html) "soundfile", "Pillow", "gradio>=4.0.0", # to test webhooks # pin to avoid issue on Python3.12 "numpy", # for embeddings "fastapi", # To build the documentation ] ) # Typing extra dependencies list is duplicated in `.pre-commit-config.yaml` # Please make sure to update the list there when 
adding a new typing dependency. extras["typing"] = [ "typing-extensions>=4.8.0", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", ] extras["quality"] = [ "ruff>=0.9.0", "mypy==1.5.1", "libcst==1.4.0", ] extras["all"] = extras["testing"] + extras["quality"] + extras["typing"] extras["dev"] = extras["all"] setup( name="huggingface_hub", version=get_version(), author="Hugging Face, Inc.", author_email="julien@huggingface.co", description="Client library to download and publish models, datasets and other repos on the huggingface.co hub", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", keywords="model-hub machine-learning models natural-language-processing deep-learning pytorch pretrained-models", license="Apache", url="https://github.com/huggingface/huggingface_hub", package_dir={"": "src"}, packages=find_packages("src"), extras_require=extras, entry_points={ "console_scripts": ["huggingface-cli=huggingface_hub.commands.huggingface_cli:main"], "fsspec.specs": "hf=huggingface_hub.HfFileSystem", }, python_requires=">=3.8.0", install_requires=install_requires, classifiers=[ "Intended Audience :: Developers", "Intended Audience :: Education", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], include_package_data=True, package_data={"huggingface_hub": ["py.typed"]}, # Needed for wheel installation ) huggingface_hub-0.31.1/src/000077500000000000000000000000001500667546600155075ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/000077500000000000000000000000001500667546600206145ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/README.md000066400000000000000000000317341500667546600221030ustar00rootroot00000000000000# Hugging Face Hub Client library ## Download files from the Hub The `hf_hub_download()` function is the main function to download files from the Hub. One advantage of using it is that files are cached locally, so you won't have to download the files multiple times. If there are changes in the repository, the files will be automatically downloaded again. ### `hf_hub_download` The function takes the following parameters, downloads the remote file, stores it to disk (in a version-aware way) and returns its local file path. Parameters: - a `repo_id` (a user or organization name and a repo name, separated by `/`, like `julien-c/EsperBERTo-small`) - a `filename` (like `pytorch_model.bin`) - an optional Git revision id (can be a branch name, a tag, or a commit hash) - a `cache_dir` which you can specify if you want to control where on disk the files are cached. ```python from huggingface_hub import hf_hub_download hf_hub_download("lysandre/arxiv-nlp", filename="config.json") ``` ### `snapshot_download` Using `hf_hub_download()` works well when you know which files you want to download; for example a model file alongside a configuration file, both with static names. 
There are cases in which you will prefer to download all the files of the remote repository at a specified revision. That's what `snapshot_download()` does. It downloads and stores a remote repository to disk (in a version-aware way) and returns the path to the local snapshot folder. Parameters: - a `repo_id` in the format `namespace/repository` - a `revision` at which the repository will be downloaded - a `cache_dir` which you can specify if you want to control where on disk the files are cached ### `hf_hub_url` Internally, the library uses `hf_hub_url()` to return the URL to download the actual files: `https://huggingface.co/julien-c/EsperBERTo-small/resolve/main/pytorch_model.bin` Parameters: - a `repo_id` (a user or organization name and a repo name separated by a `/`, like `julien-c/EsperBERTo-small`) - a `filename` (like `pytorch_model.bin`) - an optional `subfolder`, corresponding to a folder inside the model repo - an optional `repo_type`, such as `dataset` or `space` - an optional Git revision id (can be a branch name, a tag, or a commit hash) If you check out this URL's headers with an HTTP `HEAD` request (which you can do from the command line with `curl -I`) for a few different files, you'll see that: - small files are returned directly - large files (i.e. the ones stored through [git-lfs](https://git-lfs.github.com/)) are returned via a redirect to a CloudFront URL. CloudFront is a Content Delivery Network, or CDN, that ensures that downloads are as fast as possible from anywhere on the globe. A short sketch combining both helpers is shown below.
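The sketch below reuses the public `lysandre/arxiv-nlp` repo from the example above; it is only meant as an illustration, and the `HEAD` request mirrors the `curl -I` check just described (the exact headers returned may vary).

```python
import requests

from huggingface_hub import hf_hub_url, snapshot_download

# Download the full repository at a pinned revision; the returned path points
# to the snapshot folder inside the local cache.
local_folder = snapshot_download("lysandre/arxiv-nlp", revision="main")
print(local_folder)

# Build the `resolve` URL for a single file and inspect its headers without
# downloading it (the Python equivalent of `curl -I`): LFS-tracked files answer
# with a redirect to the CDN, while small files are served directly.
url = hf_hub_url("lysandre/arxiv-nlp", filename="config.json", revision="main")
response = requests.head(url, allow_redirects=False)
print(response.status_code, response.headers.get("Location"))
```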
## Publish files to the Hub If you've used Git before, this will be very easy since Git is used to manage files in the Hub. You can find a step-by-step guide on how to upload your model to the Hub: https://huggingface.co/docs/hub/adding-a-model. ### API utilities in `hf_api.py` You don't need them for the standard publishing workflow (i.e. using the git command line). However, if you need a programmatic way of creating a repo, deleting it (`⚠️ caution`), pushing a single file to a repo, or listing models from the Hub, you'll find helpers in `hf_api.py`; a minimal end-to-end sketch using them is shown at the end of this section. Some example functionality available with the `HfApi` class: * `whoami()` * `create_repo()` * `list_repo_files()` * `list_repo_objects()` * `delete_repo()` * `update_repo_settings()` * `create_commit()` * `upload_file()` * `delete_file()` * `delete_folder()` Those API utilities are also exposed through the `huggingface-cli` CLI: ```bash huggingface-cli login huggingface-cli logout huggingface-cli whoami huggingface-cli repo create ``` With the `HfApi` class there are methods to query models, datasets, and Spaces by specific tags (e.g. if you want to list models compatible with your library): - **Models**: - `list_models()` - `model_info()` - `get_model_tags()` - **Datasets**: - `list_datasets()` - `dataset_info()` - `get_dataset_tags()` - **Spaces**: - `list_spaces()` - `space_info()` These lightly wrap around the API Endpoints. Documentation for valid parameters and descriptions can be found [here](https://huggingface.co/docs/hub/endpoints). ### Advanced programmatic repository management The `Repository` class helps manage both offline Git repositories and Hugging Face Hub repositories. Using the `Repository` class requires `git` and `git-lfs` to be installed. Instantiate a `Repository` object by calling it with a path to a local Git clone/repository: ```python >>> from huggingface_hub import Repository >>> repo = Repository("<path>/<to>/<folder>") ``` The `Repository` constructor takes a `clone_from` string as a parameter. This can stay as `None` for offline management, but can also be set to any URL pointing to a Git repo to clone that repository in the specified directory: ```python >>> repo = Repository("huggingface-hub", clone_from="https://github.com/huggingface/huggingface_hub") ``` The `clone_from` argument can also take any Hugging Face model ID as input, and will clone that repository: ```python >>> repo = Repository("w2v2", clone_from="facebook/wav2vec2-large-960h-lv60") ``` If the repository you're cloning is one of yours or one of your organisation's, then having the ability to commit and push to that repository is important.
To do that, make sure you are logged in with `huggingface-cli login` and that the `token` parameter is set to `True` (the default) when instantiating the `Repository` object: ```python >>> repo = Repository("my-model", clone_from="<user>/<model_id>", token=True) ``` This works for models, datasets, and Spaces repositories, but you will need to explicitly specify the repo type for the last two: ```python >>> repo = Repository("my-dataset", clone_from="<user>/<dataset_id>", token=True, repo_type="dataset") ``` You can also change between branches: ```python >>> repo = Repository("huggingface-hub", clone_from="<user>/<model_id>", revision='branch1') >>> repo.git_checkout("branch2") ``` Finally, you can choose to specify the Git username and email attributed to that clone directly by using the `git_user` and `git_email` parameters. When committing to that repository, Git will therefore be aware of who you are and who will be the author of the commits: ```python >>> repo = Repository( ... "my-dataset", ... clone_from="<user>/<dataset_id>", ... token=True, ... repo_type="dataset", ... git_user="MyName", ... git_email="me@cool.mail" ... ) ``` The repository can be managed through this object, through wrappers of traditional Git methods: - `git_add(pattern: str, auto_lfs_track: bool)`. The `auto_lfs_track` flag triggers auto tracking of large files (>10MB) with `git-lfs` - `git_commit(commit_message: str)` - `git_pull(rebase: bool)` - `git_push()` - `git_checkout(branch)` The `git_push` method has a parameter `blocking` which is `True` by default. When set to `False`, the push will happen behind the scenes, which can be helpful if you would like your script to continue while the push is happening. LFS-tracking methods: - `lfs_track(pattern: Union[str, List[str]], filename: bool)`. Setting `filename` to `True` will use the `--filename` parameter, which will consider the pattern(s) as filenames, even if they contain special glob characters. - `lfs_untrack()`. - `auto_track_large_files()`: automatically tracks files that are larger than 10MB. Make sure to call this after adding files to the index. On top of these unitary methods lie some useful additional methods: - `push_to_hub(commit_message)`: consecutively does `git_add`, `git_commit` and `git_push`. - `commit(commit_message: str, track_large_files: bool)`: this is a context manager utility that handles committing to a repository. This automatically tracks large files (>10MB) with `git-lfs`. The `track_large_files` argument can be set to `False` if you wish to ignore that behavior. These two methods also have support for the `blocking` parameter. Examples using the `commit` context manager: ```python >>> import json >>> with Repository("text-files", clone_from="<user>/text-files", token=True).commit("My first file :)"): ... with open("file.txt", "w+") as f: ... f.write(json.dumps({"hey": 8})) ``` ```python >>> import torch >>> model = torch.nn.Transformer() >>> with Repository("torch-model", clone_from="<user>/torch-model", token=True).commit("My cool model :)"): ... torch.save(model.state_dict(), "model.pt") ``` ### Non-blocking behavior The pushing methods have access to a `blocking` boolean parameter to indicate whether the push should happen asynchronously. To see whether the push has finished, or to check its status code (to spot a failure), use the `command_queue` property on the `Repository` object.
For example: ```python from huggingface_hub import Repository repo = Repository("<local_folder>", clone_from="<user>/<model_id>") with repo.commit("Commit message", blocking=False): # Save data last_command = repo.command_queue[-1] # Status of the push command last_command.status # Will return the status code # -> -1 will indicate the push is still ongoing # -> 0 will indicate the push has completed successfully # -> non-zero code indicates the error code if there was an error # if there was an error, the stderr may be inspected last_command.stderr # Whether the command finished or if it is still ongoing last_command.is_done # Whether the command errored out. last_command.failed ``` When using `blocking=False`, the commands will be tracked and your script will exit only when all pushes are done, even if other errors happen in your script (a failed push counts as done). ### Need to upload very large (>5GB) files? To upload large files (>5GB 🔥) from the git command line, you need to install the custom transfer agent for git-lfs, bundled in this package. To install, just run: ```bash $ huggingface-cli lfs-enable-largefiles ``` This should be executed once for each model repo that contains a model file >5GB. If you just try to push a file bigger than 5GB without running that command, you will get an error with a message reminding you to run it. Finally, there's a `huggingface-cli lfs-multipart-upload` command but that one is internal (called by lfs directly) and is not meant to be called by the user.
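To round off this section, here is a minimal sketch of the purely programmatic publishing path with the `HfApi` helpers listed earlier (no local git clone needed). The repo name `my-cool-model` and the local `weights.bin` file are placeholders, and the snippet assumes you are already logged in via `huggingface-cli login` (or have a token configured).

```python
from huggingface_hub import HfApi

api = HfApi()  # picks up the token stored by `huggingface-cli login`

# Create the repo (no-op if it already exists) and push a single file to it.
repo_url = api.create_repo("my-cool-model", private=True, exist_ok=True)
api.upload_file(
    path_or_fileobj="weights.bin",  # local file to upload (placeholder)
    path_in_repo="weights.bin",     # destination path inside the repo
    repo_id=repo_url.repo_id,
    commit_message="Upload weights",
)
print(api.list_repo_files(repo_url.repo_id))
```

The same calls accept `repo_type="dataset"` or `repo_type="space"` if you are publishing to a dataset or Space repository instead of a model.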
## Using the Inference API wrapper `huggingface_hub` comes with a wrapper client to make calls to the Inference API! You can find some examples below, but we encourage you to visit the Inference API [documentation](https://api-inference.huggingface.co/docs/python/html/detailed_parameters.html) to review the specific parameters for the different tasks. When you instantiate the wrapper to the Inference API, you specify the model repository id. The pipeline (`text-classification`, `text-to-speech`, etc) is automatically extracted from the [repository](https://huggingface.co/docs/hub/main#how-is-a-models-type-of-inference-api-and-widget-determined), but you can also override it as shown below. ### Examples Here is a basic example of calling the Inference API for a `fill-mask` task using the `bert-base-uncased` model. The `fill-mask` task only expects a string (or list of strings) as input. ```python from huggingface_hub.inference_api import InferenceApi inference = InferenceApi("bert-base-uncased", token=API_TOKEN) inference(inputs="The goal of life is [MASK].") >> [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}] ``` This is an example of a task (`question-answering`) which requires a dictionary as input thas has the `question` and `context` keys. ```python inference = InferenceApi("deepset/roberta-base-squad2", token=API_TOKEN) inputs = {"question":"What's my name?", "context":"My name is Clara and I live in Berkeley."} inference(inputs) >> {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'} ``` Some tasks might also require additional params in the request. Here is an example using a `zero-shot-classification` model. ```python inference = InferenceApi("typeform/distilbert-base-uncased-mnli", token=API_TOKEN) inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!" params = {"candidate_labels":["refund", "legal", "faq"]} inference(inputs, params) >> {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]} ``` Finally, there are some models that might support multiple tasks. For example, `sentence-transformers` models can do `sentence-similarity` and `feature-extraction`. You can override the configured task when initializing the API. ```python inference = InferenceApi("bert-base-uncased", task="feature-extraction", token=API_TOKEN) ``` huggingface_hub-0.31.1/src/huggingface_hub/__init__.py000066400000000000000000001403301500667546600227260ustar00rootroot00000000000000# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # *********** # `huggingface_hub` init has 2 modes: # - Normal usage: # If imported to use it, all modules and functions are lazy-loaded. 
This means # they exist at top level in module but are imported only the first time they are # used. This way, `from huggingface_hub import something` will import `something` # quickly without the hassle of importing all the features from `huggingface_hub`. # - Static check: # If statically analyzed, all modules and functions are loaded normally. This way # static typing check works properly as well as autocomplete in text editors and # IDEs. # # The static model imports are done inside the `if TYPE_CHECKING:` statement at # the bottom of this file. Since module/functions imports are duplicated, it is # mandatory to make sure to add them twice when adding one. This is checked in the # `make quality` command. # # To update the static imports, please run the following command and commit the changes. # ``` # # Use script # python utils/check_static_imports.py --update-file # # # Or run style on codebase # make style # ``` # # *********** # Lazy loader vendored from https://github.com/scientific-python/lazy_loader import importlib import os import sys from typing import TYPE_CHECKING __version__ = "0.31.1" # Alphabetical order of definitions is ensured in tests # WARNING: any comment added in this dictionary definition will be lost when # re-generating the file ! _SUBMOD_ATTRS = { "_commit_scheduler": [ "CommitScheduler", ], "_inference_endpoints": [ "InferenceEndpoint", "InferenceEndpointError", "InferenceEndpointStatus", "InferenceEndpointTimeoutError", "InferenceEndpointType", ], "_login": [ "auth_list", "auth_switch", "interpreter_login", "login", "logout", "notebook_login", ], "_snapshot_download": [ "snapshot_download", ], "_space_api": [ "SpaceHardware", "SpaceRuntime", "SpaceStage", "SpaceStorage", "SpaceVariable", ], "_tensorboard_logger": [ "HFSummaryWriter", ], "_webhooks_payload": [ "WebhookPayload", "WebhookPayloadComment", "WebhookPayloadDiscussion", "WebhookPayloadDiscussionChanges", "WebhookPayloadEvent", "WebhookPayloadMovedTo", "WebhookPayloadRepo", "WebhookPayloadUrl", "WebhookPayloadWebhook", ], "_webhooks_server": [ "WebhooksServer", "webhook_endpoint", ], "community": [ "Discussion", "DiscussionComment", "DiscussionCommit", "DiscussionEvent", "DiscussionStatusChange", "DiscussionTitleChange", "DiscussionWithDetails", ], "constants": [ "CONFIG_NAME", "FLAX_WEIGHTS_NAME", "HUGGINGFACE_CO_URL_HOME", "HUGGINGFACE_CO_URL_TEMPLATE", "PYTORCH_WEIGHTS_NAME", "REPO_TYPE_DATASET", "REPO_TYPE_MODEL", "REPO_TYPE_SPACE", "TF2_WEIGHTS_NAME", "TF_WEIGHTS_NAME", ], "fastai_utils": [ "_save_pretrained_fastai", "from_pretrained_fastai", "push_to_hub_fastai", ], "file_download": [ "HfFileMetadata", "_CACHED_NO_EXIST", "get_hf_file_metadata", "hf_hub_download", "hf_hub_url", "try_to_load_from_cache", ], "hf_api": [ "Collection", "CollectionItem", "CommitInfo", "CommitOperation", "CommitOperationAdd", "CommitOperationCopy", "CommitOperationDelete", "DatasetInfo", "GitCommitInfo", "GitRefInfo", "GitRefs", "HfApi", "ModelInfo", "RepoUrl", "SpaceInfo", "User", "UserLikes", "WebhookInfo", "WebhookWatchedItem", "accept_access_request", "add_collection_item", "add_space_secret", "add_space_variable", "auth_check", "cancel_access_request", "change_discussion_status", "comment_discussion", "create_branch", "create_collection", "create_commit", "create_discussion", "create_inference_endpoint", "create_inference_endpoint_from_catalog", "create_pull_request", "create_repo", "create_tag", "create_webhook", "dataset_info", "delete_branch", "delete_collection", "delete_collection_item", "delete_file", 
"delete_folder", "delete_inference_endpoint", "delete_repo", "delete_space_secret", "delete_space_storage", "delete_space_variable", "delete_tag", "delete_webhook", "disable_webhook", "duplicate_space", "edit_discussion_comment", "enable_webhook", "file_exists", "get_collection", "get_dataset_tags", "get_discussion_details", "get_full_repo_name", "get_inference_endpoint", "get_model_tags", "get_paths_info", "get_repo_discussions", "get_safetensors_metadata", "get_space_runtime", "get_space_variables", "get_token_permission", "get_user_overview", "get_webhook", "grant_access", "list_accepted_access_requests", "list_collections", "list_datasets", "list_inference_catalog", "list_inference_endpoints", "list_lfs_files", "list_liked_repos", "list_models", "list_organization_members", "list_papers", "list_pending_access_requests", "list_rejected_access_requests", "list_repo_commits", "list_repo_files", "list_repo_likers", "list_repo_refs", "list_repo_tree", "list_spaces", "list_user_followers", "list_user_following", "list_webhooks", "merge_pull_request", "model_info", "move_repo", "paper_info", "parse_safetensors_file_metadata", "pause_inference_endpoint", "pause_space", "permanently_delete_lfs_files", "preupload_lfs_files", "reject_access_request", "rename_discussion", "repo_exists", "repo_info", "repo_type_and_id_from_hf_id", "request_space_hardware", "request_space_storage", "restart_space", "resume_inference_endpoint", "revision_exists", "run_as_future", "scale_to_zero_inference_endpoint", "set_space_sleep_time", "space_info", "super_squash_history", "unlike", "update_collection_item", "update_collection_metadata", "update_inference_endpoint", "update_repo_settings", "update_repo_visibility", "update_webhook", "upload_file", "upload_folder", "upload_large_folder", "whoami", ], "hf_file_system": [ "HfFileSystem", "HfFileSystemFile", "HfFileSystemResolvedPath", "HfFileSystemStreamFile", ], "hub_mixin": [ "ModelHubMixin", "PyTorchModelHubMixin", ], "inference._client": [ "InferenceClient", "InferenceTimeoutError", ], "inference._generated._async_client": [ "AsyncInferenceClient", ], "inference._generated.types": [ "AudioClassificationInput", "AudioClassificationOutputElement", "AudioClassificationOutputTransform", "AudioClassificationParameters", "AudioToAudioInput", "AudioToAudioOutputElement", "AutomaticSpeechRecognitionEarlyStoppingEnum", "AutomaticSpeechRecognitionGenerationParameters", "AutomaticSpeechRecognitionInput", "AutomaticSpeechRecognitionOutput", "AutomaticSpeechRecognitionOutputChunk", "AutomaticSpeechRecognitionParameters", "ChatCompletionInput", "ChatCompletionInputFunctionDefinition", "ChatCompletionInputFunctionName", "ChatCompletionInputGrammarType", "ChatCompletionInputGrammarTypeType", "ChatCompletionInputMessage", "ChatCompletionInputMessageChunk", "ChatCompletionInputMessageChunkType", "ChatCompletionInputStreamOptions", "ChatCompletionInputTool", "ChatCompletionInputToolCall", "ChatCompletionInputToolChoiceClass", "ChatCompletionInputToolChoiceEnum", "ChatCompletionInputURL", "ChatCompletionOutput", "ChatCompletionOutputComplete", "ChatCompletionOutputFunctionDefinition", "ChatCompletionOutputLogprob", "ChatCompletionOutputLogprobs", "ChatCompletionOutputMessage", "ChatCompletionOutputToolCall", "ChatCompletionOutputTopLogprob", "ChatCompletionOutputUsage", "ChatCompletionStreamOutput", "ChatCompletionStreamOutputChoice", "ChatCompletionStreamOutputDelta", "ChatCompletionStreamOutputDeltaToolCall", "ChatCompletionStreamOutputFunction", 
"ChatCompletionStreamOutputLogprob", "ChatCompletionStreamOutputLogprobs", "ChatCompletionStreamOutputTopLogprob", "ChatCompletionStreamOutputUsage", "DepthEstimationInput", "DepthEstimationOutput", "DocumentQuestionAnsweringInput", "DocumentQuestionAnsweringInputData", "DocumentQuestionAnsweringOutputElement", "DocumentQuestionAnsweringParameters", "FeatureExtractionInput", "FeatureExtractionInputTruncationDirection", "FillMaskInput", "FillMaskOutputElement", "FillMaskParameters", "ImageClassificationInput", "ImageClassificationOutputElement", "ImageClassificationOutputTransform", "ImageClassificationParameters", "ImageSegmentationInput", "ImageSegmentationOutputElement", "ImageSegmentationParameters", "ImageSegmentationSubtask", "ImageToImageInput", "ImageToImageOutput", "ImageToImageParameters", "ImageToImageTargetSize", "ImageToTextEarlyStoppingEnum", "ImageToTextGenerationParameters", "ImageToTextInput", "ImageToTextOutput", "ImageToTextParameters", "ObjectDetectionBoundingBox", "ObjectDetectionInput", "ObjectDetectionOutputElement", "ObjectDetectionParameters", "Padding", "QuestionAnsweringInput", "QuestionAnsweringInputData", "QuestionAnsweringOutputElement", "QuestionAnsweringParameters", "SentenceSimilarityInput", "SentenceSimilarityInputData", "SummarizationInput", "SummarizationOutput", "SummarizationParameters", "SummarizationTruncationStrategy", "TableQuestionAnsweringInput", "TableQuestionAnsweringInputData", "TableQuestionAnsweringOutputElement", "TableQuestionAnsweringParameters", "Text2TextGenerationInput", "Text2TextGenerationOutput", "Text2TextGenerationParameters", "Text2TextGenerationTruncationStrategy", "TextClassificationInput", "TextClassificationOutputElement", "TextClassificationOutputTransform", "TextClassificationParameters", "TextGenerationInput", "TextGenerationInputGenerateParameters", "TextGenerationInputGrammarType", "TextGenerationOutput", "TextGenerationOutputBestOfSequence", "TextGenerationOutputDetails", "TextGenerationOutputFinishReason", "TextGenerationOutputPrefillToken", "TextGenerationOutputToken", "TextGenerationStreamOutput", "TextGenerationStreamOutputStreamDetails", "TextGenerationStreamOutputToken", "TextToAudioEarlyStoppingEnum", "TextToAudioGenerationParameters", "TextToAudioInput", "TextToAudioOutput", "TextToAudioParameters", "TextToImageInput", "TextToImageOutput", "TextToImageParameters", "TextToSpeechEarlyStoppingEnum", "TextToSpeechGenerationParameters", "TextToSpeechInput", "TextToSpeechOutput", "TextToSpeechParameters", "TextToVideoInput", "TextToVideoOutput", "TextToVideoParameters", "TokenClassificationAggregationStrategy", "TokenClassificationInput", "TokenClassificationOutputElement", "TokenClassificationParameters", "TranslationInput", "TranslationOutput", "TranslationParameters", "TranslationTruncationStrategy", "TypeEnum", "VideoClassificationInput", "VideoClassificationOutputElement", "VideoClassificationOutputTransform", "VideoClassificationParameters", "VisualQuestionAnsweringInput", "VisualQuestionAnsweringInputData", "VisualQuestionAnsweringOutputElement", "VisualQuestionAnsweringParameters", "ZeroShotClassificationInput", "ZeroShotClassificationOutputElement", "ZeroShotClassificationParameters", "ZeroShotImageClassificationInput", "ZeroShotImageClassificationOutputElement", "ZeroShotImageClassificationParameters", "ZeroShotObjectDetectionBoundingBox", "ZeroShotObjectDetectionInput", "ZeroShotObjectDetectionOutputElement", "ZeroShotObjectDetectionParameters", ], "inference_api": [ "InferenceApi", ], "keras_mixin": [ 
"KerasModelHubMixin", "from_pretrained_keras", "push_to_hub_keras", "save_pretrained_keras", ], "repocard": [ "DatasetCard", "ModelCard", "RepoCard", "SpaceCard", "metadata_eval_result", "metadata_load", "metadata_save", "metadata_update", ], "repocard_data": [ "CardData", "DatasetCardData", "EvalResult", "ModelCardData", "SpaceCardData", ], "repository": [ "Repository", ], "serialization": [ "StateDictSplit", "get_tf_storage_size", "get_torch_storage_id", "get_torch_storage_size", "load_state_dict_from_file", "load_torch_model", "save_torch_model", "save_torch_state_dict", "split_state_dict_into_shards_factory", "split_tf_state_dict_into_shards", "split_torch_state_dict_into_shards", ], "serialization._dduf": [ "DDUFEntry", "export_entries_as_dduf", "export_folder_as_dduf", "read_dduf_file", ], "utils": [ "CacheNotFound", "CachedFileInfo", "CachedRepoInfo", "CachedRevisionInfo", "CorruptedCacheException", "DeleteCacheStrategy", "HFCacheInfo", "HfFolder", "cached_assets_path", "configure_http_backend", "dump_environment_info", "get_session", "get_token", "logging", "scan_cache_dir", ], } # WARNING: __all__ is generated automatically, Any manual edit will be lost when re-generating this file ! # # To update the static imports, please run the following command and commit the changes. # ``` # # Use script # python utils/check_all_variable.py --update # # # Or run style on codebase # make style # ``` __all__ = [ "AsyncInferenceClient", "AudioClassificationInput", "AudioClassificationOutputElement", "AudioClassificationOutputTransform", "AudioClassificationParameters", "AudioToAudioInput", "AudioToAudioOutputElement", "AutomaticSpeechRecognitionEarlyStoppingEnum", "AutomaticSpeechRecognitionGenerationParameters", "AutomaticSpeechRecognitionInput", "AutomaticSpeechRecognitionOutput", "AutomaticSpeechRecognitionOutputChunk", "AutomaticSpeechRecognitionParameters", "CONFIG_NAME", "CacheNotFound", "CachedFileInfo", "CachedRepoInfo", "CachedRevisionInfo", "CardData", "ChatCompletionInput", "ChatCompletionInputFunctionDefinition", "ChatCompletionInputFunctionName", "ChatCompletionInputGrammarType", "ChatCompletionInputGrammarTypeType", "ChatCompletionInputMessage", "ChatCompletionInputMessageChunk", "ChatCompletionInputMessageChunkType", "ChatCompletionInputStreamOptions", "ChatCompletionInputTool", "ChatCompletionInputToolCall", "ChatCompletionInputToolChoiceClass", "ChatCompletionInputToolChoiceEnum", "ChatCompletionInputURL", "ChatCompletionOutput", "ChatCompletionOutputComplete", "ChatCompletionOutputFunctionDefinition", "ChatCompletionOutputLogprob", "ChatCompletionOutputLogprobs", "ChatCompletionOutputMessage", "ChatCompletionOutputToolCall", "ChatCompletionOutputTopLogprob", "ChatCompletionOutputUsage", "ChatCompletionStreamOutput", "ChatCompletionStreamOutputChoice", "ChatCompletionStreamOutputDelta", "ChatCompletionStreamOutputDeltaToolCall", "ChatCompletionStreamOutputFunction", "ChatCompletionStreamOutputLogprob", "ChatCompletionStreamOutputLogprobs", "ChatCompletionStreamOutputTopLogprob", "ChatCompletionStreamOutputUsage", "Collection", "CollectionItem", "CommitInfo", "CommitOperation", "CommitOperationAdd", "CommitOperationCopy", "CommitOperationDelete", "CommitScheduler", "CorruptedCacheException", "DDUFEntry", "DatasetCard", "DatasetCardData", "DatasetInfo", "DeleteCacheStrategy", "DepthEstimationInput", "DepthEstimationOutput", "Discussion", "DiscussionComment", "DiscussionCommit", "DiscussionEvent", "DiscussionStatusChange", "DiscussionTitleChange", "DiscussionWithDetails", 
"DocumentQuestionAnsweringInput", "DocumentQuestionAnsweringInputData", "DocumentQuestionAnsweringOutputElement", "DocumentQuestionAnsweringParameters", "EvalResult", "FLAX_WEIGHTS_NAME", "FeatureExtractionInput", "FeatureExtractionInputTruncationDirection", "FillMaskInput", "FillMaskOutputElement", "FillMaskParameters", "GitCommitInfo", "GitRefInfo", "GitRefs", "HFCacheInfo", "HFSummaryWriter", "HUGGINGFACE_CO_URL_HOME", "HUGGINGFACE_CO_URL_TEMPLATE", "HfApi", "HfFileMetadata", "HfFileSystem", "HfFileSystemFile", "HfFileSystemResolvedPath", "HfFileSystemStreamFile", "HfFolder", "ImageClassificationInput", "ImageClassificationOutputElement", "ImageClassificationOutputTransform", "ImageClassificationParameters", "ImageSegmentationInput", "ImageSegmentationOutputElement", "ImageSegmentationParameters", "ImageSegmentationSubtask", "ImageToImageInput", "ImageToImageOutput", "ImageToImageParameters", "ImageToImageTargetSize", "ImageToTextEarlyStoppingEnum", "ImageToTextGenerationParameters", "ImageToTextInput", "ImageToTextOutput", "ImageToTextParameters", "InferenceApi", "InferenceClient", "InferenceEndpoint", "InferenceEndpointError", "InferenceEndpointStatus", "InferenceEndpointTimeoutError", "InferenceEndpointType", "InferenceTimeoutError", "KerasModelHubMixin", "ModelCard", "ModelCardData", "ModelHubMixin", "ModelInfo", "ObjectDetectionBoundingBox", "ObjectDetectionInput", "ObjectDetectionOutputElement", "ObjectDetectionParameters", "PYTORCH_WEIGHTS_NAME", "Padding", "PyTorchModelHubMixin", "QuestionAnsweringInput", "QuestionAnsweringInputData", "QuestionAnsweringOutputElement", "QuestionAnsweringParameters", "REPO_TYPE_DATASET", "REPO_TYPE_MODEL", "REPO_TYPE_SPACE", "RepoCard", "RepoUrl", "Repository", "SentenceSimilarityInput", "SentenceSimilarityInputData", "SpaceCard", "SpaceCardData", "SpaceHardware", "SpaceInfo", "SpaceRuntime", "SpaceStage", "SpaceStorage", "SpaceVariable", "StateDictSplit", "SummarizationInput", "SummarizationOutput", "SummarizationParameters", "SummarizationTruncationStrategy", "TF2_WEIGHTS_NAME", "TF_WEIGHTS_NAME", "TableQuestionAnsweringInput", "TableQuestionAnsweringInputData", "TableQuestionAnsweringOutputElement", "TableQuestionAnsweringParameters", "Text2TextGenerationInput", "Text2TextGenerationOutput", "Text2TextGenerationParameters", "Text2TextGenerationTruncationStrategy", "TextClassificationInput", "TextClassificationOutputElement", "TextClassificationOutputTransform", "TextClassificationParameters", "TextGenerationInput", "TextGenerationInputGenerateParameters", "TextGenerationInputGrammarType", "TextGenerationOutput", "TextGenerationOutputBestOfSequence", "TextGenerationOutputDetails", "TextGenerationOutputFinishReason", "TextGenerationOutputPrefillToken", "TextGenerationOutputToken", "TextGenerationStreamOutput", "TextGenerationStreamOutputStreamDetails", "TextGenerationStreamOutputToken", "TextToAudioEarlyStoppingEnum", "TextToAudioGenerationParameters", "TextToAudioInput", "TextToAudioOutput", "TextToAudioParameters", "TextToImageInput", "TextToImageOutput", "TextToImageParameters", "TextToSpeechEarlyStoppingEnum", "TextToSpeechGenerationParameters", "TextToSpeechInput", "TextToSpeechOutput", "TextToSpeechParameters", "TextToVideoInput", "TextToVideoOutput", "TextToVideoParameters", "TokenClassificationAggregationStrategy", "TokenClassificationInput", "TokenClassificationOutputElement", "TokenClassificationParameters", "TranslationInput", "TranslationOutput", "TranslationParameters", "TranslationTruncationStrategy", "TypeEnum", "User", "UserLikes", 
"VideoClassificationInput", "VideoClassificationOutputElement", "VideoClassificationOutputTransform", "VideoClassificationParameters", "VisualQuestionAnsweringInput", "VisualQuestionAnsweringInputData", "VisualQuestionAnsweringOutputElement", "VisualQuestionAnsweringParameters", "WebhookInfo", "WebhookPayload", "WebhookPayloadComment", "WebhookPayloadDiscussion", "WebhookPayloadDiscussionChanges", "WebhookPayloadEvent", "WebhookPayloadMovedTo", "WebhookPayloadRepo", "WebhookPayloadUrl", "WebhookPayloadWebhook", "WebhookWatchedItem", "WebhooksServer", "ZeroShotClassificationInput", "ZeroShotClassificationOutputElement", "ZeroShotClassificationParameters", "ZeroShotImageClassificationInput", "ZeroShotImageClassificationOutputElement", "ZeroShotImageClassificationParameters", "ZeroShotObjectDetectionBoundingBox", "ZeroShotObjectDetectionInput", "ZeroShotObjectDetectionOutputElement", "ZeroShotObjectDetectionParameters", "_CACHED_NO_EXIST", "_save_pretrained_fastai", "accept_access_request", "add_collection_item", "add_space_secret", "add_space_variable", "auth_check", "auth_list", "auth_switch", "cached_assets_path", "cancel_access_request", "change_discussion_status", "comment_discussion", "configure_http_backend", "create_branch", "create_collection", "create_commit", "create_discussion", "create_inference_endpoint", "create_inference_endpoint_from_catalog", "create_pull_request", "create_repo", "create_tag", "create_webhook", "dataset_info", "delete_branch", "delete_collection", "delete_collection_item", "delete_file", "delete_folder", "delete_inference_endpoint", "delete_repo", "delete_space_secret", "delete_space_storage", "delete_space_variable", "delete_tag", "delete_webhook", "disable_webhook", "dump_environment_info", "duplicate_space", "edit_discussion_comment", "enable_webhook", "export_entries_as_dduf", "export_folder_as_dduf", "file_exists", "from_pretrained_fastai", "from_pretrained_keras", "get_collection", "get_dataset_tags", "get_discussion_details", "get_full_repo_name", "get_hf_file_metadata", "get_inference_endpoint", "get_model_tags", "get_paths_info", "get_repo_discussions", "get_safetensors_metadata", "get_session", "get_space_runtime", "get_space_variables", "get_tf_storage_size", "get_token", "get_token_permission", "get_torch_storage_id", "get_torch_storage_size", "get_user_overview", "get_webhook", "grant_access", "hf_hub_download", "hf_hub_url", "interpreter_login", "list_accepted_access_requests", "list_collections", "list_datasets", "list_inference_catalog", "list_inference_endpoints", "list_lfs_files", "list_liked_repos", "list_models", "list_organization_members", "list_papers", "list_pending_access_requests", "list_rejected_access_requests", "list_repo_commits", "list_repo_files", "list_repo_likers", "list_repo_refs", "list_repo_tree", "list_spaces", "list_user_followers", "list_user_following", "list_webhooks", "load_state_dict_from_file", "load_torch_model", "logging", "login", "logout", "merge_pull_request", "metadata_eval_result", "metadata_load", "metadata_save", "metadata_update", "model_info", "move_repo", "notebook_login", "paper_info", "parse_safetensors_file_metadata", "pause_inference_endpoint", "pause_space", "permanently_delete_lfs_files", "preupload_lfs_files", "push_to_hub_fastai", "push_to_hub_keras", "read_dduf_file", "reject_access_request", "rename_discussion", "repo_exists", "repo_info", "repo_type_and_id_from_hf_id", "request_space_hardware", "request_space_storage", "restart_space", "resume_inference_endpoint", "revision_exists", 
"run_as_future", "save_pretrained_keras", "save_torch_model", "save_torch_state_dict", "scale_to_zero_inference_endpoint", "scan_cache_dir", "set_space_sleep_time", "snapshot_download", "space_info", "split_state_dict_into_shards_factory", "split_tf_state_dict_into_shards", "split_torch_state_dict_into_shards", "super_squash_history", "try_to_load_from_cache", "unlike", "update_collection_item", "update_collection_metadata", "update_inference_endpoint", "update_repo_settings", "update_repo_visibility", "update_webhook", "upload_file", "upload_folder", "upload_large_folder", "webhook_endpoint", "whoami", ] def _attach(package_name, submodules=None, submod_attrs=None): """Attach lazily loaded submodules, functions, or other attributes. Typically, modules import submodules and attributes as follows: ```py import mysubmodule import anothersubmodule from .foo import someattr ``` The idea is to replace a package's `__getattr__`, `__dir__`, such that all imports work exactly the way they would with normal imports, except that the import occurs upon first use. The typical way to call this function, replacing the above imports, is: ```python __getattr__, __dir__ = lazy.attach( __name__, ['mysubmodule', 'anothersubmodule'], {'foo': ['someattr']} ) ``` This functionality requires Python 3.7 or higher. Args: package_name (`str`): Typically use `__name__`. submodules (`set`): List of submodules to attach. submod_attrs (`dict`): Dictionary of submodule -> list of attributes / functions. These attributes are imported as they are used. Returns: __getattr__, __dir__, __all__ """ if submod_attrs is None: submod_attrs = {} if submodules is None: submodules = set() else: submodules = set(submodules) attr_to_modules = {attr: mod for mod, attrs in submod_attrs.items() for attr in attrs} def __getattr__(name): if name in submodules: try: return importlib.import_module(f"{package_name}.{name}") except Exception as e: print(f"Error importing {package_name}.{name}: {e}") raise elif name in attr_to_modules: submod_path = f"{package_name}.{attr_to_modules[name]}" try: submod = importlib.import_module(submod_path) except Exception as e: print(f"Error importing {submod_path}: {e}") raise attr = getattr(submod, name) # If the attribute lives in a file (module) with the same # name as the attribute, ensure that the attribute and *not* # the module is accessible on the package. if name == attr_to_modules[name]: pkg = sys.modules[package_name] pkg.__dict__[name] = attr return attr else: raise AttributeError(f"No {package_name} attribute {name}") def __dir__(): return __all__ return __getattr__, __dir__ __getattr__, __dir__ = _attach(__name__, submodules=[], submod_attrs=_SUBMOD_ATTRS) if os.environ.get("EAGER_IMPORT", ""): for attr in __all__: __getattr__(attr) # WARNING: any content below this statement is generated automatically. Any manual edit # will be lost when re-generating this file ! # # To update the static imports, please run the following command and commit the changes. 
# ``` # # Use script # python utils/check_static_imports.py --update # # # Or run style on codebase # make style # ``` if TYPE_CHECKING: # pragma: no cover from ._commit_scheduler import CommitScheduler # noqa: F401 from ._inference_endpoints import ( InferenceEndpoint, # noqa: F401 InferenceEndpointError, # noqa: F401 InferenceEndpointStatus, # noqa: F401 InferenceEndpointTimeoutError, # noqa: F401 InferenceEndpointType, # noqa: F401 ) from ._login import ( auth_list, # noqa: F401 auth_switch, # noqa: F401 interpreter_login, # noqa: F401 login, # noqa: F401 logout, # noqa: F401 notebook_login, # noqa: F401 ) from ._snapshot_download import snapshot_download # noqa: F401 from ._space_api import ( SpaceHardware, # noqa: F401 SpaceRuntime, # noqa: F401 SpaceStage, # noqa: F401 SpaceStorage, # noqa: F401 SpaceVariable, # noqa: F401 ) from ._tensorboard_logger import HFSummaryWriter # noqa: F401 from ._webhooks_payload import ( WebhookPayload, # noqa: F401 WebhookPayloadComment, # noqa: F401 WebhookPayloadDiscussion, # noqa: F401 WebhookPayloadDiscussionChanges, # noqa: F401 WebhookPayloadEvent, # noqa: F401 WebhookPayloadMovedTo, # noqa: F401 WebhookPayloadRepo, # noqa: F401 WebhookPayloadUrl, # noqa: F401 WebhookPayloadWebhook, # noqa: F401 ) from ._webhooks_server import ( WebhooksServer, # noqa: F401 webhook_endpoint, # noqa: F401 ) from .community import ( Discussion, # noqa: F401 DiscussionComment, # noqa: F401 DiscussionCommit, # noqa: F401 DiscussionEvent, # noqa: F401 DiscussionStatusChange, # noqa: F401 DiscussionTitleChange, # noqa: F401 DiscussionWithDetails, # noqa: F401 ) from .constants import ( CONFIG_NAME, # noqa: F401 FLAX_WEIGHTS_NAME, # noqa: F401 HUGGINGFACE_CO_URL_HOME, # noqa: F401 HUGGINGFACE_CO_URL_TEMPLATE, # noqa: F401 PYTORCH_WEIGHTS_NAME, # noqa: F401 REPO_TYPE_DATASET, # noqa: F401 REPO_TYPE_MODEL, # noqa: F401 REPO_TYPE_SPACE, # noqa: F401 TF2_WEIGHTS_NAME, # noqa: F401 TF_WEIGHTS_NAME, # noqa: F401 ) from .fastai_utils import ( _save_pretrained_fastai, # noqa: F401 from_pretrained_fastai, # noqa: F401 push_to_hub_fastai, # noqa: F401 ) from .file_download import ( _CACHED_NO_EXIST, # noqa: F401 HfFileMetadata, # noqa: F401 get_hf_file_metadata, # noqa: F401 hf_hub_download, # noqa: F401 hf_hub_url, # noqa: F401 try_to_load_from_cache, # noqa: F401 ) from .hf_api import ( Collection, # noqa: F401 CollectionItem, # noqa: F401 CommitInfo, # noqa: F401 CommitOperation, # noqa: F401 CommitOperationAdd, # noqa: F401 CommitOperationCopy, # noqa: F401 CommitOperationDelete, # noqa: F401 DatasetInfo, # noqa: F401 GitCommitInfo, # noqa: F401 GitRefInfo, # noqa: F401 GitRefs, # noqa: F401 HfApi, # noqa: F401 ModelInfo, # noqa: F401 RepoUrl, # noqa: F401 SpaceInfo, # noqa: F401 User, # noqa: F401 UserLikes, # noqa: F401 WebhookInfo, # noqa: F401 WebhookWatchedItem, # noqa: F401 accept_access_request, # noqa: F401 add_collection_item, # noqa: F401 add_space_secret, # noqa: F401 add_space_variable, # noqa: F401 auth_check, # noqa: F401 cancel_access_request, # noqa: F401 change_discussion_status, # noqa: F401 comment_discussion, # noqa: F401 create_branch, # noqa: F401 create_collection, # noqa: F401 create_commit, # noqa: F401 create_discussion, # noqa: F401 create_inference_endpoint, # noqa: F401 create_inference_endpoint_from_catalog, # noqa: F401 create_pull_request, # noqa: F401 create_repo, # noqa: F401 create_tag, # noqa: F401 create_webhook, # noqa: F401 dataset_info, # noqa: F401 delete_branch, # noqa: F401 delete_collection, # noqa: F401 delete_collection_item, # 
noqa: F401 delete_file, # noqa: F401 delete_folder, # noqa: F401 delete_inference_endpoint, # noqa: F401 delete_repo, # noqa: F401 delete_space_secret, # noqa: F401 delete_space_storage, # noqa: F401 delete_space_variable, # noqa: F401 delete_tag, # noqa: F401 delete_webhook, # noqa: F401 disable_webhook, # noqa: F401 duplicate_space, # noqa: F401 edit_discussion_comment, # noqa: F401 enable_webhook, # noqa: F401 file_exists, # noqa: F401 get_collection, # noqa: F401 get_dataset_tags, # noqa: F401 get_discussion_details, # noqa: F401 get_full_repo_name, # noqa: F401 get_inference_endpoint, # noqa: F401 get_model_tags, # noqa: F401 get_paths_info, # noqa: F401 get_repo_discussions, # noqa: F401 get_safetensors_metadata, # noqa: F401 get_space_runtime, # noqa: F401 get_space_variables, # noqa: F401 get_token_permission, # noqa: F401 get_user_overview, # noqa: F401 get_webhook, # noqa: F401 grant_access, # noqa: F401 list_accepted_access_requests, # noqa: F401 list_collections, # noqa: F401 list_datasets, # noqa: F401 list_inference_catalog, # noqa: F401 list_inference_endpoints, # noqa: F401 list_lfs_files, # noqa: F401 list_liked_repos, # noqa: F401 list_models, # noqa: F401 list_organization_members, # noqa: F401 list_papers, # noqa: F401 list_pending_access_requests, # noqa: F401 list_rejected_access_requests, # noqa: F401 list_repo_commits, # noqa: F401 list_repo_files, # noqa: F401 list_repo_likers, # noqa: F401 list_repo_refs, # noqa: F401 list_repo_tree, # noqa: F401 list_spaces, # noqa: F401 list_user_followers, # noqa: F401 list_user_following, # noqa: F401 list_webhooks, # noqa: F401 merge_pull_request, # noqa: F401 model_info, # noqa: F401 move_repo, # noqa: F401 paper_info, # noqa: F401 parse_safetensors_file_metadata, # noqa: F401 pause_inference_endpoint, # noqa: F401 pause_space, # noqa: F401 permanently_delete_lfs_files, # noqa: F401 preupload_lfs_files, # noqa: F401 reject_access_request, # noqa: F401 rename_discussion, # noqa: F401 repo_exists, # noqa: F401 repo_info, # noqa: F401 repo_type_and_id_from_hf_id, # noqa: F401 request_space_hardware, # noqa: F401 request_space_storage, # noqa: F401 restart_space, # noqa: F401 resume_inference_endpoint, # noqa: F401 revision_exists, # noqa: F401 run_as_future, # noqa: F401 scale_to_zero_inference_endpoint, # noqa: F401 set_space_sleep_time, # noqa: F401 space_info, # noqa: F401 super_squash_history, # noqa: F401 unlike, # noqa: F401 update_collection_item, # noqa: F401 update_collection_metadata, # noqa: F401 update_inference_endpoint, # noqa: F401 update_repo_settings, # noqa: F401 update_repo_visibility, # noqa: F401 update_webhook, # noqa: F401 upload_file, # noqa: F401 upload_folder, # noqa: F401 upload_large_folder, # noqa: F401 whoami, # noqa: F401 ) from .hf_file_system import ( HfFileSystem, # noqa: F401 HfFileSystemFile, # noqa: F401 HfFileSystemResolvedPath, # noqa: F401 HfFileSystemStreamFile, # noqa: F401 ) from .hub_mixin import ( ModelHubMixin, # noqa: F401 PyTorchModelHubMixin, # noqa: F401 ) from .inference._client import ( InferenceClient, # noqa: F401 InferenceTimeoutError, # noqa: F401 ) from .inference._generated._async_client import AsyncInferenceClient # noqa: F401 from .inference._generated.types import ( AudioClassificationInput, # noqa: F401 AudioClassificationOutputElement, # noqa: F401 AudioClassificationOutputTransform, # noqa: F401 AudioClassificationParameters, # noqa: F401 AudioToAudioInput, # noqa: F401 AudioToAudioOutputElement, # noqa: F401 AutomaticSpeechRecognitionEarlyStoppingEnum, # noqa: 
F401 AutomaticSpeechRecognitionGenerationParameters, # noqa: F401 AutomaticSpeechRecognitionInput, # noqa: F401 AutomaticSpeechRecognitionOutput, # noqa: F401 AutomaticSpeechRecognitionOutputChunk, # noqa: F401 AutomaticSpeechRecognitionParameters, # noqa: F401 ChatCompletionInput, # noqa: F401 ChatCompletionInputFunctionDefinition, # noqa: F401 ChatCompletionInputFunctionName, # noqa: F401 ChatCompletionInputGrammarType, # noqa: F401 ChatCompletionInputGrammarTypeType, # noqa: F401 ChatCompletionInputMessage, # noqa: F401 ChatCompletionInputMessageChunk, # noqa: F401 ChatCompletionInputMessageChunkType, # noqa: F401 ChatCompletionInputStreamOptions, # noqa: F401 ChatCompletionInputTool, # noqa: F401 ChatCompletionInputToolCall, # noqa: F401 ChatCompletionInputToolChoiceClass, # noqa: F401 ChatCompletionInputToolChoiceEnum, # noqa: F401 ChatCompletionInputURL, # noqa: F401 ChatCompletionOutput, # noqa: F401 ChatCompletionOutputComplete, # noqa: F401 ChatCompletionOutputFunctionDefinition, # noqa: F401 ChatCompletionOutputLogprob, # noqa: F401 ChatCompletionOutputLogprobs, # noqa: F401 ChatCompletionOutputMessage, # noqa: F401 ChatCompletionOutputToolCall, # noqa: F401 ChatCompletionOutputTopLogprob, # noqa: F401 ChatCompletionOutputUsage, # noqa: F401 ChatCompletionStreamOutput, # noqa: F401 ChatCompletionStreamOutputChoice, # noqa: F401 ChatCompletionStreamOutputDelta, # noqa: F401 ChatCompletionStreamOutputDeltaToolCall, # noqa: F401 ChatCompletionStreamOutputFunction, # noqa: F401 ChatCompletionStreamOutputLogprob, # noqa: F401 ChatCompletionStreamOutputLogprobs, # noqa: F401 ChatCompletionStreamOutputTopLogprob, # noqa: F401 ChatCompletionStreamOutputUsage, # noqa: F401 DepthEstimationInput, # noqa: F401 DepthEstimationOutput, # noqa: F401 DocumentQuestionAnsweringInput, # noqa: F401 DocumentQuestionAnsweringInputData, # noqa: F401 DocumentQuestionAnsweringOutputElement, # noqa: F401 DocumentQuestionAnsweringParameters, # noqa: F401 FeatureExtractionInput, # noqa: F401 FeatureExtractionInputTruncationDirection, # noqa: F401 FillMaskInput, # noqa: F401 FillMaskOutputElement, # noqa: F401 FillMaskParameters, # noqa: F401 ImageClassificationInput, # noqa: F401 ImageClassificationOutputElement, # noqa: F401 ImageClassificationOutputTransform, # noqa: F401 ImageClassificationParameters, # noqa: F401 ImageSegmentationInput, # noqa: F401 ImageSegmentationOutputElement, # noqa: F401 ImageSegmentationParameters, # noqa: F401 ImageSegmentationSubtask, # noqa: F401 ImageToImageInput, # noqa: F401 ImageToImageOutput, # noqa: F401 ImageToImageParameters, # noqa: F401 ImageToImageTargetSize, # noqa: F401 ImageToTextEarlyStoppingEnum, # noqa: F401 ImageToTextGenerationParameters, # noqa: F401 ImageToTextInput, # noqa: F401 ImageToTextOutput, # noqa: F401 ImageToTextParameters, # noqa: F401 ObjectDetectionBoundingBox, # noqa: F401 ObjectDetectionInput, # noqa: F401 ObjectDetectionOutputElement, # noqa: F401 ObjectDetectionParameters, # noqa: F401 Padding, # noqa: F401 QuestionAnsweringInput, # noqa: F401 QuestionAnsweringInputData, # noqa: F401 QuestionAnsweringOutputElement, # noqa: F401 QuestionAnsweringParameters, # noqa: F401 SentenceSimilarityInput, # noqa: F401 SentenceSimilarityInputData, # noqa: F401 SummarizationInput, # noqa: F401 SummarizationOutput, # noqa: F401 SummarizationParameters, # noqa: F401 SummarizationTruncationStrategy, # noqa: F401 TableQuestionAnsweringInput, # noqa: F401 TableQuestionAnsweringInputData, # noqa: F401 TableQuestionAnsweringOutputElement, # noqa: F401 
TableQuestionAnsweringParameters, # noqa: F401 Text2TextGenerationInput, # noqa: F401 Text2TextGenerationOutput, # noqa: F401 Text2TextGenerationParameters, # noqa: F401 Text2TextGenerationTruncationStrategy, # noqa: F401 TextClassificationInput, # noqa: F401 TextClassificationOutputElement, # noqa: F401 TextClassificationOutputTransform, # noqa: F401 TextClassificationParameters, # noqa: F401 TextGenerationInput, # noqa: F401 TextGenerationInputGenerateParameters, # noqa: F401 TextGenerationInputGrammarType, # noqa: F401 TextGenerationOutput, # noqa: F401 TextGenerationOutputBestOfSequence, # noqa: F401 TextGenerationOutputDetails, # noqa: F401 TextGenerationOutputFinishReason, # noqa: F401 TextGenerationOutputPrefillToken, # noqa: F401 TextGenerationOutputToken, # noqa: F401 TextGenerationStreamOutput, # noqa: F401 TextGenerationStreamOutputStreamDetails, # noqa: F401 TextGenerationStreamOutputToken, # noqa: F401 TextToAudioEarlyStoppingEnum, # noqa: F401 TextToAudioGenerationParameters, # noqa: F401 TextToAudioInput, # noqa: F401 TextToAudioOutput, # noqa: F401 TextToAudioParameters, # noqa: F401 TextToImageInput, # noqa: F401 TextToImageOutput, # noqa: F401 TextToImageParameters, # noqa: F401 TextToSpeechEarlyStoppingEnum, # noqa: F401 TextToSpeechGenerationParameters, # noqa: F401 TextToSpeechInput, # noqa: F401 TextToSpeechOutput, # noqa: F401 TextToSpeechParameters, # noqa: F401 TextToVideoInput, # noqa: F401 TextToVideoOutput, # noqa: F401 TextToVideoParameters, # noqa: F401 TokenClassificationAggregationStrategy, # noqa: F401 TokenClassificationInput, # noqa: F401 TokenClassificationOutputElement, # noqa: F401 TokenClassificationParameters, # noqa: F401 TranslationInput, # noqa: F401 TranslationOutput, # noqa: F401 TranslationParameters, # noqa: F401 TranslationTruncationStrategy, # noqa: F401 TypeEnum, # noqa: F401 VideoClassificationInput, # noqa: F401 VideoClassificationOutputElement, # noqa: F401 VideoClassificationOutputTransform, # noqa: F401 VideoClassificationParameters, # noqa: F401 VisualQuestionAnsweringInput, # noqa: F401 VisualQuestionAnsweringInputData, # noqa: F401 VisualQuestionAnsweringOutputElement, # noqa: F401 VisualQuestionAnsweringParameters, # noqa: F401 ZeroShotClassificationInput, # noqa: F401 ZeroShotClassificationOutputElement, # noqa: F401 ZeroShotClassificationParameters, # noqa: F401 ZeroShotImageClassificationInput, # noqa: F401 ZeroShotImageClassificationOutputElement, # noqa: F401 ZeroShotImageClassificationParameters, # noqa: F401 ZeroShotObjectDetectionBoundingBox, # noqa: F401 ZeroShotObjectDetectionInput, # noqa: F401 ZeroShotObjectDetectionOutputElement, # noqa: F401 ZeroShotObjectDetectionParameters, # noqa: F401 ) from .inference_api import InferenceApi # noqa: F401 from .keras_mixin import ( KerasModelHubMixin, # noqa: F401 from_pretrained_keras, # noqa: F401 push_to_hub_keras, # noqa: F401 save_pretrained_keras, # noqa: F401 ) from .repocard import ( DatasetCard, # noqa: F401 ModelCard, # noqa: F401 RepoCard, # noqa: F401 SpaceCard, # noqa: F401 metadata_eval_result, # noqa: F401 metadata_load, # noqa: F401 metadata_save, # noqa: F401 metadata_update, # noqa: F401 ) from .repocard_data import ( CardData, # noqa: F401 DatasetCardData, # noqa: F401 EvalResult, # noqa: F401 ModelCardData, # noqa: F401 SpaceCardData, # noqa: F401 ) from .repository import Repository # noqa: F401 from .serialization import ( StateDictSplit, # noqa: F401 get_tf_storage_size, # noqa: F401 get_torch_storage_id, # noqa: F401 get_torch_storage_size, # noqa: F401 
load_state_dict_from_file, # noqa: F401 load_torch_model, # noqa: F401 save_torch_model, # noqa: F401 save_torch_state_dict, # noqa: F401 split_state_dict_into_shards_factory, # noqa: F401 split_tf_state_dict_into_shards, # noqa: F401 split_torch_state_dict_into_shards, # noqa: F401 ) from .serialization._dduf import ( DDUFEntry, # noqa: F401 export_entries_as_dduf, # noqa: F401 export_folder_as_dduf, # noqa: F401 read_dduf_file, # noqa: F401 ) from .utils import ( CachedFileInfo, # noqa: F401 CachedRepoInfo, # noqa: F401 CachedRevisionInfo, # noqa: F401 CacheNotFound, # noqa: F401 CorruptedCacheException, # noqa: F401 DeleteCacheStrategy, # noqa: F401 HFCacheInfo, # noqa: F401 HfFolder, # noqa: F401 cached_assets_path, # noqa: F401 configure_http_backend, # noqa: F401 dump_environment_info, # noqa: F401 get_session, # noqa: F401 get_token, # noqa: F401 logging, # noqa: F401 scan_cache_dir, # noqa: F401 ) huggingface_hub-0.31.1/src/huggingface_hub/_commit_api.py000066400000000000000000001150401500667546600234470ustar00rootroot00000000000000""" Type definitions and utilities for the `create_commit` API """ import base64 import io import math import os import warnings from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass, field from itertools import groupby from pathlib import Path, PurePosixPath from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union from tqdm.contrib.concurrent import thread_map from . import constants from .errors import EntryNotFoundError, HfHubHTTPError, XetAuthorizationError, XetRefreshTokenError from .file_download import hf_hub_url from .lfs import UploadInfo, lfs_upload, post_lfs_batch_info from .utils import ( FORBIDDEN_FOLDERS, XetTokenType, chunk_iterable, fetch_xet_connection_info_from_repo_info, get_session, hf_raise_for_status, logging, sha, tqdm_stream_file, validate_hf_hub_args, ) from .utils import tqdm as hf_tqdm from .utils.tqdm import _get_progress_bar_context if TYPE_CHECKING: from .hf_api import RepoFile logger = logging.get_logger(__name__) UploadMode = Literal["lfs", "regular"] # Max is 1,000 per request on the Hub for HfApi.get_paths_info # Otherwise we get: # HfHubHTTPError: 413 Client Error: Payload Too Large for url: https://huggingface.co/api/datasets/xxx (Request ID: xxx)\n\ntoo many parameters # See https://github.com/huggingface/huggingface_hub/issues/1503 FETCH_LFS_BATCH_SIZE = 500 UPLOAD_BATCH_MAX_NUM_FILES = 256 @dataclass class CommitOperationDelete: """ Data structure holding necessary info to delete a file or a folder from a repository on the Hub. Args: path_in_repo (`str`): Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"` for a file or `"checkpoints/1fec34a/"` for a folder. is_folder (`bool` or `Literal["auto"]`, *optional*) Whether the Delete Operation applies to a folder or not. If "auto", the path type (file or folder) is guessed automatically by looking if path ends with a "/" (folder) or not (file). To explicitly set the path type, you can set `is_folder=True` or `is_folder=False`. """ path_in_repo: str is_folder: Union[bool, Literal["auto"]] = "auto" def __post_init__(self): self.path_in_repo = _validate_path_in_repo(self.path_in_repo) if self.is_folder == "auto": self.is_folder = self.path_in_repo.endswith("/") if not isinstance(self.is_folder, bool): raise ValueError( f"Wrong value for `is_folder`. Must be one of [`True`, `False`, `'auto'`]. Got '{self.is_folder}'." 
) @dataclass class CommitOperationCopy: """ Data structure holding necessary info to copy a file in a repository on the Hub. Limitations: - Only LFS files can be copied. To copy a regular file, you need to download it locally and re-upload it - Cross-repository copies are not supported. Note: you can combine a [`CommitOperationCopy`] and a [`CommitOperationDelete`] to rename an LFS file on the Hub. Args: src_path_in_repo (`str`): Relative filepath in the repo of the file to be copied, e.g. `"checkpoints/1fec34a/weights.bin"`. path_in_repo (`str`): Relative filepath in the repo where to copy the file, e.g. `"checkpoints/1fec34a/weights_copy.bin"`. src_revision (`str`, *optional*): The git revision of the file to be copied. Can be any valid git revision. Default to the target commit revision. """ src_path_in_repo: str path_in_repo: str src_revision: Optional[str] = None # set to the OID of the file to be copied if it has already been uploaded # useful to determine if a commit will be empty or not. _src_oid: Optional[str] = None # set to the OID of the file to copy to if it has already been uploaded # useful to determine if a commit will be empty or not. _dest_oid: Optional[str] = None def __post_init__(self): self.src_path_in_repo = _validate_path_in_repo(self.src_path_in_repo) self.path_in_repo = _validate_path_in_repo(self.path_in_repo) @dataclass class CommitOperationAdd: """ Data structure holding necessary info to upload a file to a repository on the Hub. Args: path_in_repo (`str`): Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"` path_or_fileobj (`str`, `Path`, `bytes`, or `BinaryIO`): Either: - a path to a local file (as `str` or `pathlib.Path`) to upload - a buffer of bytes (`bytes`) holding the content of the file to upload - a "file object" (subclass of `io.BufferedIOBase`), typically obtained with `open(path, "rb")`. It must support `seek()` and `tell()` methods. Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If `path_or_fileobj` is not one of `str`, `Path`, `bytes` or `io.BufferedIOBase`. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If `path_or_fileobj` is a `str` or `Path` but not a path to an existing file. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If `path_or_fileobj` is a `io.BufferedIOBase` but it doesn't support both `seek()` and `tell()`. 
""" path_in_repo: str path_or_fileobj: Union[str, Path, bytes, BinaryIO] upload_info: UploadInfo = field(init=False, repr=False) # Internal attributes # set to "lfs" or "regular" once known _upload_mode: Optional[UploadMode] = field(init=False, repr=False, default=None) # set to True if .gitignore rules prevent the file from being uploaded as LFS # (server-side check) _should_ignore: Optional[bool] = field(init=False, repr=False, default=None) # set to the remote OID of the file if it has already been uploaded # useful to determine if a commit will be empty or not _remote_oid: Optional[str] = field(init=False, repr=False, default=None) # set to True once the file has been uploaded as LFS _is_uploaded: bool = field(init=False, repr=False, default=False) # set to True once the file has been committed _is_committed: bool = field(init=False, repr=False, default=False) def __post_init__(self) -> None: """Validates `path_or_fileobj` and compute `upload_info`.""" self.path_in_repo = _validate_path_in_repo(self.path_in_repo) # Validate `path_or_fileobj` value if isinstance(self.path_or_fileobj, Path): self.path_or_fileobj = str(self.path_or_fileobj) if isinstance(self.path_or_fileobj, str): path_or_fileobj = os.path.normpath(os.path.expanduser(self.path_or_fileobj)) if not os.path.isfile(path_or_fileobj): raise ValueError(f"Provided path: '{path_or_fileobj}' is not a file on the local file system") elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)): # ^^ Inspired from: https://stackoverflow.com/questions/44584829/how-to-determine-if-file-is-opened-in-binary-or-text-mode raise ValueError( "path_or_fileobj must be either an instance of str, bytes or" " io.BufferedIOBase. If you passed a file-like object, make sure it is" " in binary mode." ) if isinstance(self.path_or_fileobj, io.BufferedIOBase): try: self.path_or_fileobj.tell() self.path_or_fileobj.seek(0, os.SEEK_CUR) except (OSError, AttributeError) as exc: raise ValueError( "path_or_fileobj is a file-like object but does not implement seek() and tell()" ) from exc # Compute "upload_info" attribute if isinstance(self.path_or_fileobj, str): self.upload_info = UploadInfo.from_path(self.path_or_fileobj) elif isinstance(self.path_or_fileobj, bytes): self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj) else: self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj) @contextmanager def as_file(self, with_tqdm: bool = False) -> Iterator[BinaryIO]: """ A context manager that yields a file-like object allowing to read the underlying data behind `path_or_fileobj`. Args: with_tqdm (`bool`, *optional*, defaults to `False`): If True, iterating over the file object will display a progress bar. Only works if the file-like object is a path to a file. Pure bytes and buffers are not supported. Example: ```python >>> operation = CommitOperationAdd( ... path_in_repo="remote/dir/weights.h5", ... path_or_fileobj="./local/weights.h5", ... ) CommitOperationAdd(path_in_repo='remote/dir/weights.h5', path_or_fileobj='./local/weights.h5') >>> with operation.as_file() as file: ... content = file.read() >>> with operation.as_file(with_tqdm=True) as file: ... while True: ... data = file.read(1024) ... if not data: ... break config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s] >>> with operation.as_file(with_tqdm=True) as file: ... 
requests.put(..., data=file) config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s] ``` """ if isinstance(self.path_or_fileobj, str) or isinstance(self.path_or_fileobj, Path): if with_tqdm: with tqdm_stream_file(self.path_or_fileobj) as file: yield file else: with open(self.path_or_fileobj, "rb") as file: yield file elif isinstance(self.path_or_fileobj, bytes): yield io.BytesIO(self.path_or_fileobj) elif isinstance(self.path_or_fileobj, io.BufferedIOBase): prev_pos = self.path_or_fileobj.tell() yield self.path_or_fileobj self.path_or_fileobj.seek(prev_pos, io.SEEK_SET) def b64content(self) -> bytes: """ The base64-encoded content of `path_or_fileobj` Returns: `bytes` """ with self.as_file() as file: return base64.b64encode(file.read()) @property def _local_oid(self) -> Optional[str]: """Return the OID of the local file. This OID is then compared to `self._remote_oid` to check if the file has changed compared to the remote one. If the file did not change, we won't upload it again to prevent empty commits. For LFS files, the OID corresponds to the SHA256 of the file content (used a LFS ref). For regular files, the OID corresponds to the SHA1 of the file content. Note: this is slightly different to git OID computation since the oid of an LFS file is usually the git-SHA1 of the pointer file content (not the actual file content). However, using the SHA256 is enough to detect changes and more convenient client-side. """ if self._upload_mode is None: return None elif self._upload_mode == "lfs": return self.upload_info.sha256.hex() else: # Regular file => compute sha1 # => no need to read by chunk since the file is guaranteed to be <=5MB. with self.as_file() as file: return sha.git_hash(file.read()) def _validate_path_in_repo(path_in_repo: str) -> str: # Validate `path_in_repo` value to prevent a server-side issue if path_in_repo.startswith("/"): path_in_repo = path_in_repo[1:] if path_in_repo == "." or path_in_repo == ".." or path_in_repo.startswith("../"): raise ValueError(f"Invalid `path_in_repo` in CommitOperation: '{path_in_repo}'") if path_in_repo.startswith("./"): path_in_repo = path_in_repo[2:] for forbidden in FORBIDDEN_FOLDERS: if any(part == forbidden for part in path_in_repo.split("/")): raise ValueError( f"Invalid `path_in_repo` in CommitOperation: cannot update files under a '{forbidden}/' folder (path:" f" '{path_in_repo}')." ) return path_in_repo CommitOperation = Union[CommitOperationAdd, CommitOperationCopy, CommitOperationDelete] def _warn_on_overwriting_operations(operations: List[CommitOperation]) -> None: """ Warn user when a list of operations is expected to overwrite itself in a single commit. Rules: - If a filepath is updated by multiple `CommitOperationAdd` operations, a warning message is triggered. - If a filepath is updated at least once by a `CommitOperationAdd` and then deleted by a `CommitOperationDelete`, a warning is triggered. - If a `CommitOperationDelete` deletes a filepath that is then updated by a `CommitOperationAdd`, no warning is triggered. This is usually useless (no need to delete before upload) but can happen if a user deletes an entire folder and then add new files to it. """ nb_additions_per_path: Dict[str, int] = defaultdict(int) for operation in operations: path_in_repo = operation.path_in_repo if isinstance(operation, CommitOperationAdd): if nb_additions_per_path[path_in_repo] > 0: warnings.warn( "About to update multiple times the same file in the same commit:" f" '{path_in_repo}'. 
This can cause undesired inconsistencies in" " your repo." ) nb_additions_per_path[path_in_repo] += 1 for parent in PurePosixPath(path_in_repo).parents: # Also keep track of number of updated files per folder # => warns if deleting a folder overwrite some contained files nb_additions_per_path[str(parent)] += 1 if isinstance(operation, CommitOperationDelete): if nb_additions_per_path[str(PurePosixPath(path_in_repo))] > 0: if operation.is_folder: warnings.warn( "About to delete a folder containing files that have just been" f" updated within the same commit: '{path_in_repo}'. This can" " cause undesired inconsistencies in your repo." ) else: warnings.warn( "About to delete a file that have just been updated within the" f" same commit: '{path_in_repo}'. This can cause undesired" " inconsistencies in your repo." ) @validate_hf_hub_args def _upload_lfs_files( *, additions: List[CommitOperationAdd], repo_type: str, repo_id: str, headers: Dict[str, str], endpoint: Optional[str] = None, num_threads: int = 5, revision: Optional[str] = None, ): """ Uploads the content of `additions` to the Hub using the large file storage protocol. Relevant external documentation: - LFS Batch API: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md Args: additions (`List` of `CommitOperationAdd`): The files to be uploaded repo_type (`str`): Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. headers (`Dict[str, str]`): Headers to use for the request, including authorization headers and user agent. num_threads (`int`, *optional*): The number of concurrent threads to use when uploading. Defaults to 5. revision (`str`, *optional*): The git revision to upload to. Raises: [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) If an upload failed for any reason [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If the server returns malformed responses [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) If the LFS batch endpoint returned an HTTP error. """ # Step 1: retrieve upload instructions from the LFS batch endpoint. # Upload instructions are retrieved by chunk of 256 files to avoid reaching # the payload limit. batch_actions: List[Dict] = [] for chunk in chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES): batch_actions_chunk, batch_errors_chunk = post_lfs_batch_info( upload_infos=[op.upload_info for op in chunk], repo_id=repo_id, repo_type=repo_type, revision=revision, endpoint=endpoint, headers=headers, token=None, # already passed in 'headers' ) # If at least 1 error, we do not retrieve information for other chunks if batch_errors_chunk: message = "\n".join( [ f"Encountered error for file with OID {err.get('oid')}: `{err.get('error', {}).get('message')}" for err in batch_errors_chunk ] ) raise ValueError(f"LFS batch endpoint returned errors:\n{message}") batch_actions += batch_actions_chunk oid2addop = {add_op.upload_info.sha256.hex(): add_op for add_op in additions} # Step 2: ignore files that have already been uploaded filtered_actions = [] for action in batch_actions: if action.get("actions") is None: logger.debug( f"Content of file {oid2addop[action['oid']].path_in_repo} is already" " present upstream - skipping upload." 
) else: filtered_actions.append(action) if len(filtered_actions) == 0: logger.debug("No LFS files to upload.") return # Step 3: upload files concurrently according to these instructions def _wrapped_lfs_upload(batch_action) -> None: try: operation = oid2addop[batch_action["oid"]] lfs_upload(operation=operation, lfs_batch_action=batch_action, headers=headers, endpoint=endpoint) except Exception as exc: raise RuntimeError(f"Error while uploading '{operation.path_in_repo}' to the Hub.") from exc if constants.HF_HUB_ENABLE_HF_TRANSFER: logger.debug(f"Uploading {len(filtered_actions)} LFS files to the Hub using `hf_transfer`.") for action in hf_tqdm(filtered_actions, name="huggingface_hub.lfs_upload"): _wrapped_lfs_upload(action) elif len(filtered_actions) == 1: logger.debug("Uploading 1 LFS file to the Hub") _wrapped_lfs_upload(filtered_actions[0]) else: logger.debug( f"Uploading {len(filtered_actions)} LFS files to the Hub using up to {num_threads} threads concurrently" ) thread_map( _wrapped_lfs_upload, filtered_actions, desc=f"Upload {len(filtered_actions)} LFS files", max_workers=num_threads, tqdm_class=hf_tqdm, ) @validate_hf_hub_args def _upload_xet_files( *, additions: List[CommitOperationAdd], repo_type: str, repo_id: str, headers: Dict[str, str], endpoint: Optional[str] = None, revision: Optional[str] = None, create_pr: Optional[bool] = None, ): """ Uploads the content of `additions` to the Hub using the xet storage protocol. This chunks the files and deduplicates the chunks before uploading them to xetcas storage. Args: additions (`List` of `CommitOperationAdd`): The files to be uploaded. repo_type (`str`): Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. headers (`Dict[str, str]`): Headers to use for the request, including authorization headers and user agent. endpoint: (`str`, *optional*): The endpoint to use for the xetcas service. Defaults to `constants.ENDPOINT`. revision (`str`, *optional*): The git revision to upload to. create_pr (`bool`, *optional*): Whether or not to create a Pull Request with that commit. Raises: [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) If an upload failed for any reason. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If the server returns malformed responses or if the user is unauthorized to upload to xet storage. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) If the LFS batch endpoint returned an HTTP error. **How it works:** The file download system uses Xet storage, which is a content-addressable storage system that breaks files into chunks for efficient storage and transfer. `hf_xet.upload_files` manages uploading files by: - Taking a list of file paths to upload - Breaking files into smaller chunks for efficient storage - Avoiding duplicate storage by recognizing identical chunks across files - Connecting to a storage server (CAS server) that manages these chunks The upload process works like this: 1. Create a local folder at ~/.cache/huggingface/xet/chunk-cache to store file chunks for reuse. 2. Process files in parallel (up to 8 files at once): 2.1. Read the file content. 2.2. Split the file content into smaller chunks based on content patterns: each chunk gets a unique ID based on what's in it. 2.3. For each chunk: - Check if it already exists in storage. - Skip uploading chunks that already exist. 2.4. 
Group chunks into larger blocks for efficient transfer. 2.5. Upload these blocks to the storage server. 2.6. Create and upload information about how the file is structured. 3. Return reference files that contain information about the uploaded files, which can be used later to download them. """ if len(additions) == 0: return # at this point, we know that hf_xet is installed from hf_xet import upload_bytes, upload_files try: xet_connection_info = fetch_xet_connection_info_from_repo_info( token_type=XetTokenType.WRITE, repo_id=repo_id, repo_type=repo_type, revision=revision, headers=headers, endpoint=endpoint, params={"create_pr": "1"} if create_pr else None, ) except HfHubHTTPError as e: if e.response.status_code == 401: raise XetAuthorizationError( f"You are unauthorized to upload to xet storage for {repo_type}/{repo_id}. " f"Please check that you have configured your access token with write access to the repo." ) from e raise xet_endpoint = xet_connection_info.endpoint access_token_info = (xet_connection_info.access_token, xet_connection_info.expiration_unix_epoch) def token_refresher() -> Tuple[str, int]: new_xet_connection = fetch_xet_connection_info_from_repo_info( token_type=XetTokenType.WRITE, repo_id=repo_id, repo_type=repo_type, revision=revision, headers=headers, endpoint=endpoint, params={"create_pr": "1"} if create_pr else None, ) if new_xet_connection is None: raise XetRefreshTokenError("Failed to refresh xet token") return new_xet_connection.access_token, new_xet_connection.expiration_unix_epoch num_chunks = math.ceil(len(additions) / UPLOAD_BATCH_MAX_NUM_FILES) num_chunks_num_digits = int(math.log10(num_chunks)) + 1 for i, chunk in enumerate(chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES)): _chunk = [op for op in chunk] bytes_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, bytes)] paths_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, (str, Path))] expected_size = sum(op.upload_info.size for op in bytes_ops + paths_ops) if num_chunks > 1: description = f"Uploading Batch [{str(i + 1).zfill(num_chunks_num_digits)}/{num_chunks}]..." else: description = "Uploading..." 
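        # A single progress bar (total=`expected_size`) is shared by the `upload_files`
        # and `upload_bytes` calls below through the `update_progress` callback, so it
        # tracks bytes uploaded for both path-based and in-memory additions of this batch.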
progress_cm = _get_progress_bar_context( desc=description, total=expected_size, initial=0, unit="B", unit_scale=True, name="huggingface_hub.xet_put", log_level=logger.getEffectiveLevel(), ) with progress_cm as progress: def update_progress(increment: int): progress.update(increment) if len(paths_ops) > 0: upload_files( [str(op.path_or_fileobj) for op in paths_ops], xet_endpoint, access_token_info, token_refresher, update_progress, repo_type, ) if len(bytes_ops) > 0: upload_bytes( [op.path_or_fileobj for op in bytes_ops], xet_endpoint, access_token_info, token_refresher, update_progress, repo_type, ) return def _validate_preupload_info(preupload_info: dict): files = preupload_info.get("files") if not isinstance(files, list): raise ValueError("preupload_info is improperly formatted") for file_info in files: if not ( isinstance(file_info, dict) and isinstance(file_info.get("path"), str) and isinstance(file_info.get("uploadMode"), str) and (file_info["uploadMode"] in ("lfs", "regular")) ): raise ValueError("preupload_info is improperly formatted:") return preupload_info @validate_hf_hub_args def _fetch_upload_modes( additions: Iterable[CommitOperationAdd], repo_type: str, repo_id: str, headers: Dict[str, str], revision: str, endpoint: Optional[str] = None, create_pr: bool = False, gitignore_content: Optional[str] = None, ) -> None: """ Requests the Hub "preupload" endpoint to determine whether each input file should be uploaded as a regular git blob, as a git LFS blob, or as a XET file. Input `additions` are mutated in-place with the upload mode. Args: additions (`Iterable` of :class:`CommitOperationAdd`): Iterable of :class:`CommitOperationAdd` describing the files to upload to the Hub. repo_type (`str`): Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. headers (`Dict[str, str]`): Headers to use for the request, including authorization headers and user agent. revision (`str`): The git revision to upload the files to. Can be any valid git revision. gitignore_content (`str`, *optional*): The content of the `.gitignore` file to know which files should be ignored. The order of priority is to first check if `gitignore_content` is passed, then check if the `.gitignore` file is present in the list of files to commit and finally default to the `.gitignore` file already hosted on the Hub (if any). Raises: [`~utils.HfHubHTTPError`] If the Hub API returned an error. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If the Hub API response is improperly formatted. """ endpoint = endpoint if endpoint is not None else constants.ENDPOINT # Fetch upload mode (LFS or regular) chunk by chunk. 
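    # Illustrative shape of the exchange with the preupload endpoint (not an exact schema):
    #   request:  {"files": [{"path": "config.json", "sample": "<base64 sample>", "size": 128}, ...]}
    #   response: {"files": [{"path": "config.json", "uploadMode": "regular", "shouldIgnore": false, "oid": "..."}, ...]}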
upload_modes: Dict[str, UploadMode] = {} should_ignore_info: Dict[str, bool] = {} oid_info: Dict[str, Optional[str]] = {} for chunk in chunk_iterable(additions, 256): payload: Dict = { "files": [ { "path": op.path_in_repo, "sample": base64.b64encode(op.upload_info.sample).decode("ascii"), "size": op.upload_info.size, } for op in chunk ] } if gitignore_content is not None: payload["gitIgnore"] = gitignore_content resp = get_session().post( f"{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}", json=payload, headers=headers, params={"create_pr": "1"} if create_pr else None, ) hf_raise_for_status(resp) preupload_info = _validate_preupload_info(resp.json()) upload_modes.update(**{file["path"]: file["uploadMode"] for file in preupload_info["files"]}) should_ignore_info.update(**{file["path"]: file["shouldIgnore"] for file in preupload_info["files"]}) oid_info.update(**{file["path"]: file.get("oid") for file in preupload_info["files"]}) # Set upload mode for each addition operation for addition in additions: addition._upload_mode = upload_modes[addition.path_in_repo] addition._should_ignore = should_ignore_info[addition.path_in_repo] addition._remote_oid = oid_info[addition.path_in_repo] # Empty files cannot be uploaded as LFS (S3 would fail with a 501 Not Implemented) # => empty files are uploaded as "regular" to still allow users to commit them. for addition in additions: if addition.upload_info.size == 0: addition._upload_mode = "regular" @validate_hf_hub_args def _fetch_files_to_copy( copies: Iterable[CommitOperationCopy], repo_type: str, repo_id: str, headers: Dict[str, str], revision: str, endpoint: Optional[str] = None, ) -> Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]]: """ Fetch information about the files to copy. For LFS files, we only need their metadata (file size and sha256) while for regular files we need to download the raw content from the Hub. Args: copies (`Iterable` of :class:`CommitOperationCopy`): Iterable of :class:`CommitOperationCopy` describing the files to copy on the Hub. repo_type (`str`): Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. headers (`Dict[str, str]`): Headers to use for the request, including authorization headers and user agent. revision (`str`): The git revision to upload the files to. Can be any valid git revision. Returns: `Dict[Tuple[str, Optional[str]], Union[RepoFile, bytes]]]` Key is the file path and revision of the file to copy. Value is the raw content as bytes (for regular files) or the file information as a RepoFile (for LFS files). Raises: [`~utils.HfHubHTTPError`] If the Hub API returned an error. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If the Hub API response is improperly formatted. """ from .hf_api import HfApi, RepoFolder hf_api = HfApi(endpoint=endpoint, headers=headers) files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]] = {} # Store (path, revision) -> oid mapping oid_info: Dict[Tuple[str, Optional[str]], Optional[str]] = {} # 1. Fetch OIDs for destination paths in batches. 
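    # Batching by FETCH_LFS_BATCH_SIZE (500 paths per call) keeps each `get_paths_info`
    # request under the Hub's per-request parameter limit (see the note next to the
    # constant definition at the top of this module).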
dest_paths = [op.path_in_repo for op in copies] for offset in range(0, len(dest_paths), FETCH_LFS_BATCH_SIZE): dest_repo_files = hf_api.get_paths_info( repo_id=repo_id, paths=dest_paths[offset : offset + FETCH_LFS_BATCH_SIZE], revision=revision, repo_type=repo_type, ) for file in dest_repo_files: if not isinstance(file, RepoFolder): oid_info[(file.path, revision)] = file.blob_id # 2. Group by source revision and fetch source file info in batches. for src_revision, operations in groupby(copies, key=lambda op: op.src_revision): operations = list(operations) # type: ignore src_paths = [op.src_path_in_repo for op in operations] for offset in range(0, len(src_paths), FETCH_LFS_BATCH_SIZE): src_repo_files = hf_api.get_paths_info( repo_id=repo_id, paths=src_paths[offset : offset + FETCH_LFS_BATCH_SIZE], revision=src_revision or revision, repo_type=repo_type, ) for src_repo_file in src_repo_files: if isinstance(src_repo_file, RepoFolder): raise NotImplementedError("Copying a folder is not implemented.") oid_info[(src_repo_file.path, src_revision)] = src_repo_file.blob_id # If it's an LFS file, store the RepoFile object. Otherwise, download raw bytes. if src_repo_file.lfs: files_to_copy[(src_repo_file.path, src_revision)] = src_repo_file else: # TODO: (optimization) download regular files to copy concurrently url = hf_hub_url( endpoint=endpoint, repo_type=repo_type, repo_id=repo_id, revision=src_revision or revision, filename=src_repo_file.path, ) response = get_session().get(url, headers=headers) hf_raise_for_status(response) files_to_copy[(src_repo_file.path, src_revision)] = response.content # 3. Ensure all operations found a corresponding file in the Hub # and track src/dest OIDs for each operation. for operation in operations: if (operation.src_path_in_repo, src_revision) not in files_to_copy: raise EntryNotFoundError( f"Cannot copy {operation.src_path_in_repo} at revision " f"{src_revision or revision}: file is missing on repo." ) operation._src_oid = oid_info.get((operation.src_path_in_repo, operation.src_revision)) operation._dest_oid = oid_info.get((operation.path_in_repo, revision)) return files_to_copy def _prepare_commit_payload( operations: Iterable[CommitOperation], files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]], commit_message: str, commit_description: Optional[str] = None, parent_commit: Optional[str] = None, ) -> Iterable[Dict[str, Any]]: """ Builds the payload to POST to the `/commit` API of the Hub. Payload is returned as an iterator so that it can be streamed as a ndjson in the POST request. For more information, see: - https://github.com/huggingface/huggingface_hub/issues/1085#issuecomment-1265208073 - http://ndjson.org/ """ commit_description = commit_description if commit_description is not None else "" # 1. Send a header item with the commit metadata header_value = {"summary": commit_message, "description": commit_description} if parent_commit is not None: header_value["parentCommit"] = parent_commit yield {"key": "header", "value": header_value} nb_ignored_files = 0 # 2. Send operations, one per line for operation in operations: # Skip ignored files if isinstance(operation, CommitOperationAdd) and operation._should_ignore: logger.debug(f"Skipping file '{operation.path_in_repo}' in commit (ignored by gitignore file).") nb_ignored_files += 1 continue # 2.a. 
Case adding a regular file if isinstance(operation, CommitOperationAdd) and operation._upload_mode == "regular": yield { "key": "file", "value": { "content": operation.b64content().decode(), "path": operation.path_in_repo, "encoding": "base64", }, } # 2.b. Case adding an LFS file elif isinstance(operation, CommitOperationAdd) and operation._upload_mode == "lfs": yield { "key": "lfsFile", "value": { "path": operation.path_in_repo, "algo": "sha256", "oid": operation.upload_info.sha256.hex(), "size": operation.upload_info.size, }, } # 2.c. Case deleting a file or folder elif isinstance(operation, CommitOperationDelete): yield { "key": "deletedFolder" if operation.is_folder else "deletedFile", "value": {"path": operation.path_in_repo}, } # 2.d. Case copying a file or folder elif isinstance(operation, CommitOperationCopy): file_to_copy = files_to_copy[(operation.src_path_in_repo, operation.src_revision)] if isinstance(file_to_copy, bytes): yield { "key": "file", "value": { "content": base64.b64encode(file_to_copy).decode(), "path": operation.path_in_repo, "encoding": "base64", }, } elif file_to_copy.lfs: yield { "key": "lfsFile", "value": { "path": operation.path_in_repo, "algo": "sha256", "oid": file_to_copy.lfs.sha256, }, } else: raise ValueError( "Malformed files_to_copy (should be raw file content as bytes or RepoFile objects with LFS info." ) # 2.e. Never expected to happen else: raise ValueError( f"Unknown operation to commit. Operation: {operation}. Upload mode:" f" {getattr(operation, '_upload_mode', None)}" ) if nb_ignored_files > 0: logger.info(f"Skipped {nb_ignored_files} file(s) in commit (ignored by gitignore file).") huggingface_hub-0.31.1/src/huggingface_hub/_commit_scheduler.py000066400000000000000000000345271500667546600246660ustar00rootroot00000000000000import atexit import logging import os import time from concurrent.futures import Future from dataclasses import dataclass from io import SEEK_END, SEEK_SET, BytesIO from pathlib import Path from threading import Lock, Thread from typing import Dict, List, Optional, Union from .hf_api import DEFAULT_IGNORE_PATTERNS, CommitInfo, CommitOperationAdd, HfApi from .utils import filter_repo_objects logger = logging.getLogger(__name__) @dataclass(frozen=True) class _FileToUpload: """Temporary dataclass to store info about files to upload. Not meant to be used directly.""" local_path: Path path_in_repo: str size_limit: int last_modified: float class CommitScheduler: """ Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes). The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually with the `stop` method. Checkout the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads) to learn more about how to use it. Args: repo_id (`str`): The id of the repo to commit to. folder_path (`str` or `Path`): Path to the local folder to upload regularly. every (`int` or `float`, *optional*): The number of minutes between each commit. Defaults to 5 minutes. path_in_repo (`str`, *optional*): Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder of the repository. repo_type (`str`, *optional*): The type of the repo to commit to. Defaults to `model`. revision (`str`, *optional*): The revision of the repo to commit to. Defaults to `main`. 
private (`bool`, *optional*): Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. token (`str`, *optional*): The token to use to commit to the repo. Defaults to the token saved on the machine. allow_patterns (`List[str]` or `str`, *optional*): If provided, only files matching at least one pattern are uploaded. ignore_patterns (`List[str]` or `str`, *optional*): If provided, files matching any of the patterns are not uploaded. squash_history (`bool`, *optional*): Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is useful to avoid degraded performances on the repo when it grows too large. hf_api (`HfApi`, *optional*): The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...). Example: ```py >>> from pathlib import Path >>> from huggingface_hub import CommitScheduler # Scheduler uploads every 10 minutes >>> csv_path = Path("watched_folder/data.csv") >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10) >>> with csv_path.open("a") as f: ... f.write("first line") # Some time later (...) >>> with csv_path.open("a") as f: ... f.write("second line") ``` Example using a context manager: ```py >>> from pathlib import Path >>> from huggingface_hub import CommitScheduler >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler: ... csv_path = Path("watched_folder/data.csv") ... with csv_path.open("a") as f: ... f.write("first line") ... (...) ... with csv_path.open("a") as f: ... f.write("second line") # Scheduler is now stopped and last commit have been triggered ``` """ def __init__( self, *, repo_id: str, folder_path: Union[str, Path], every: Union[int, float] = 5, path_in_repo: Optional[str] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, private: Optional[bool] = None, token: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, squash_history: bool = False, hf_api: Optional["HfApi"] = None, ) -> None: self.api = hf_api or HfApi(token=token) # Folder self.folder_path = Path(folder_path).expanduser().resolve() self.path_in_repo = path_in_repo or "" self.allow_patterns = allow_patterns if ignore_patterns is None: ignore_patterns = [] elif isinstance(ignore_patterns, str): ignore_patterns = [ignore_patterns] self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS if self.folder_path.is_file(): raise ValueError(f"'folder_path' must be a directory, not a file: '{self.folder_path}'.") self.folder_path.mkdir(parents=True, exist_ok=True) # Repository repo_url = self.api.create_repo(repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True) self.repo_id = repo_url.repo_id self.repo_type = repo_type self.revision = revision self.token = token # Keep track of already uploaded files self.last_uploaded: Dict[Path, float] = {} # key is local path, value is timestamp # Scheduler if not every > 0: raise ValueError(f"'every' must be a positive integer, not '{every}'.") self.lock = Lock() self.every = every self.squash_history = squash_history logger.info(f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes.") self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True) self._scheduler_thread.start() 
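        # Register a final push at interpreter exit so that pending changes are still
        # committed even if `stop()` is never called explicitly.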
atexit.register(self._push_to_hub) self.__stopped = False def stop(self) -> None: """Stop the scheduler. A stopped scheduler cannot be restarted. Mostly for tests purposes. """ self.__stopped = True def __enter__(self) -> "CommitScheduler": return self def __exit__(self, exc_type, exc_value, traceback) -> None: # Upload last changes before exiting self.trigger().result() self.stop() return def _run_scheduler(self) -> None: """Dumb thread waiting between each scheduled push to Hub.""" while True: self.last_future = self.trigger() time.sleep(self.every * 60) if self.__stopped: break def trigger(self) -> Future: """Trigger a `push_to_hub` and return a future. This method is automatically called every `every` minutes. You can also call it manually to trigger a commit immediately, without waiting for the next scheduled commit. """ return self.api.run_as_future(self._push_to_hub) def _push_to_hub(self) -> Optional[CommitInfo]: if self.__stopped: # If stopped, already scheduled commits are ignored return None logger.info("(Background) scheduled commit triggered.") try: value = self.push_to_hub() if self.squash_history: logger.info("(Background) squashing repo history.") self.api.super_squash_history(repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision) return value except Exception as e: logger.error(f"Error while pushing to Hub: {e}") # Depending on the setup, error might be silenced raise def push_to_hub(self) -> Optional[CommitInfo]: """ Push folder to the Hub and return the commit info. This method is not meant to be called directly. It is run in the background by the scheduler, respecting a queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency issues. The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and uploads only changed files. If no changes are found, the method returns without committing anything. If you want to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful for example to compress data together in a single file before committing. For more details and examples, check out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads). 
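        Example of a custom scheduler (a minimal sketch, assuming you want to push the whole
        folder as a single zip archive instead of individual files; the `ZipScheduler` name
        and the `"data.zip"` path are purely illustrative):

        ```py
        >>> import shutil, tempfile
        >>> from huggingface_hub import CommitScheduler

        >>> class ZipScheduler(CommitScheduler):
        ...     def push_to_hub(self):
        ...         # Archive the watched folder and upload it as a single file.
        ...         with self.lock, tempfile.TemporaryDirectory() as tmpdir:
        ...             archive_path = shutil.make_archive(f"{tmpdir}/data", "zip", self.folder_path)
        ...             return self.api.upload_file(
        ...                 repo_id=self.repo_id,
        ...                 repo_type=self.repo_type,
        ...                 revision=self.revision,
        ...                 path_or_fileobj=archive_path,
        ...                 path_in_repo="data.zip",
        ...                 commit_message="Scheduled Commit",
        ...             )
        ```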
""" # Check files to upload (with lock) with self.lock: logger.debug("Listing files to upload for scheduled commit.") # List files from folder (taken from `_prepare_upload_folder_additions`) relpath_to_abspath = { path.relative_to(self.folder_path).as_posix(): path for path in sorted(self.folder_path.glob("**/*")) # sorted to be deterministic if path.is_file() } prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else "" # Filter with pattern + filter out unchanged files + retrieve current file size files_to_upload: List[_FileToUpload] = [] for relpath in filter_repo_objects( relpath_to_abspath.keys(), allow_patterns=self.allow_patterns, ignore_patterns=self.ignore_patterns ): local_path = relpath_to_abspath[relpath] stat = local_path.stat() if self.last_uploaded.get(local_path) is None or self.last_uploaded[local_path] != stat.st_mtime: files_to_upload.append( _FileToUpload( local_path=local_path, path_in_repo=prefix + relpath, size_limit=stat.st_size, last_modified=stat.st_mtime, ) ) # Return if nothing to upload if len(files_to_upload) == 0: logger.debug("Dropping schedule commit: no changed file to upload.") return None # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size) logger.debug("Removing unchanged files since previous scheduled commit.") add_operations = [ CommitOperationAdd( # Cap the file to its current size, even if the user append data to it while a scheduled commit is happening path_or_fileobj=PartialFileIO(file_to_upload.local_path, size_limit=file_to_upload.size_limit), path_in_repo=file_to_upload.path_in_repo, ) for file_to_upload in files_to_upload ] # Upload files (append mode expected - no need for lock) logger.debug("Uploading files for scheduled commit.") commit_info = self.api.create_commit( repo_id=self.repo_id, repo_type=self.repo_type, operations=add_operations, commit_message="Scheduled Commit", revision=self.revision, ) # Successful commit: keep track of the latest "last_modified" for each file for file in files_to_upload: self.last_uploaded[file.local_path] = file.last_modified return commit_info class PartialFileIO(BytesIO): """A file-like object that reads only the first part of a file. Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the file is uploaded (i.e. the part that was available when the filesystem was first scanned). In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal disturbance for the user. The object is passed to `CommitOperationAdd`. Only supports `read`, `tell` and `seek` methods. Args: file_path (`str` or `Path`): Path to the file to read. size_limit (`int`): The maximum number of bytes to read from the file. If the file is larger than this, only the first part will be read (and uploaded). 
""" def __init__(self, file_path: Union[str, Path], size_limit: int) -> None: self._file_path = Path(file_path) self._file = self._file_path.open("rb") self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size) def __del__(self) -> None: self._file.close() return super().__del__() def __repr__(self) -> str: return f"" def __len__(self) -> int: return self._size_limit def __getattribute__(self, name: str): if name.startswith("_") or name in ("read", "tell", "seek"): # only 3 public methods supported return super().__getattribute__(name) raise NotImplementedError(f"PartialFileIO does not support '{name}'.") def tell(self) -> int: """Return the current file position.""" return self._file.tell() def seek(self, __offset: int, __whence: int = SEEK_SET) -> int: """Change the stream position to the given offset. Behavior is the same as a regular file, except that the position is capped to the size limit. """ if __whence == SEEK_END: # SEEK_END => set from the truncated end __offset = len(self) + __offset __whence = SEEK_SET pos = self._file.seek(__offset, __whence) if pos > self._size_limit: return self._file.seek(self._size_limit) return pos def read(self, __size: Optional[int] = -1) -> bytes: """Read at most `__size` bytes from the file. Behavior is the same as a regular file, except that it is capped to the size limit. """ current = self._file.tell() if __size is None or __size < 0: # Read until file limit truncated_size = self._size_limit - current else: # Read until file limit or __size truncated_size = min(__size, self._size_limit - current) return self._file.read(truncated_size) huggingface_hub-0.31.1/src/huggingface_hub/_inference_endpoints.py000066400000000000000000000420051500667546600253470ustar00rootroot00000000000000import time from dataclasses import dataclass, field from datetime import datetime from enum import Enum from typing import TYPE_CHECKING, Dict, Optional, Union from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError from .utils import get_session, logging, parse_datetime if TYPE_CHECKING: from .hf_api import HfApi from .inference._client import InferenceClient from .inference._generated._async_client import AsyncInferenceClient logger = logging.get_logger(__name__) class InferenceEndpointStatus(str, Enum): PENDING = "pending" INITIALIZING = "initializing" UPDATING = "updating" UPDATE_FAILED = "updateFailed" RUNNING = "running" PAUSED = "paused" FAILED = "failed" SCALED_TO_ZERO = "scaledToZero" class InferenceEndpointType(str, Enum): PUBlIC = "public" PROTECTED = "protected" PRIVATE = "private" @dataclass class InferenceEndpoint: """ Contains information about a deployed Inference Endpoint. Args: name (`str`): The unique name of the Inference Endpoint. namespace (`str`): The namespace where the Inference Endpoint is located. repository (`str`): The name of the model repository deployed on this Inference Endpoint. status ([`InferenceEndpointStatus`]): The current status of the Inference Endpoint. url (`str`, *optional*): The URL of the Inference Endpoint, if available. Only a deployed Inference Endpoint will have a URL. framework (`str`): The machine learning framework used for the model. revision (`str`): The specific model revision deployed on the Inference Endpoint. task (`str`): The task associated with the deployed model. created_at (`datetime.datetime`): The timestamp when the Inference Endpoint was created. updated_at (`datetime.datetime`): The timestamp of the last update of the Inference Endpoint. 
type ([`InferenceEndpointType`]): The type of the Inference Endpoint (public, protected, private). raw (`Dict`): The raw dictionary data returned from the API. token (`str` or `bool`, *optional*): Authentication token for the Inference Endpoint, if set when requesting the API. Will default to the locally saved token if not provided. Pass `token=False` if you don't want to send your token to the server. Example: ```python >>> from huggingface_hub import get_inference_endpoint >>> endpoint = get_inference_endpoint("my-text-to-image") >>> endpoint InferenceEndpoint(name='my-text-to-image', ...) # Get status >>> endpoint.status 'running' >>> endpoint.url 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud' # Run inference >>> endpoint.client.text_to_image(...) # Pause endpoint to save $$$ >>> endpoint.pause() # ... # Resume and wait for deployment >>> endpoint.resume() >>> endpoint.wait() >>> endpoint.client.text_to_image(...) ``` """ # Field in __repr__ name: str = field(init=False) namespace: str repository: str = field(init=False) status: InferenceEndpointStatus = field(init=False) url: Optional[str] = field(init=False) # Other fields framework: str = field(repr=False, init=False) revision: str = field(repr=False, init=False) task: str = field(repr=False, init=False) created_at: datetime = field(repr=False, init=False) updated_at: datetime = field(repr=False, init=False) type: InferenceEndpointType = field(repr=False, init=False) # Raw dict from the API raw: Dict = field(repr=False) # Internal fields _token: Union[str, bool, None] = field(repr=False, compare=False) _api: "HfApi" = field(repr=False, compare=False) @classmethod def from_raw( cls, raw: Dict, namespace: str, token: Union[str, bool, None] = None, api: Optional["HfApi"] = None ) -> "InferenceEndpoint": """Initialize object from raw dictionary.""" if api is None: from .hf_api import HfApi api = HfApi() if token is None: token = api.token # All other fields are populated in __post_init__ return cls(raw=raw, namespace=namespace, _token=token, _api=api) def __post_init__(self) -> None: """Populate fields from raw dictionary.""" self._populate_from_raw() @property def client(self) -> "InferenceClient": """Returns a client to make predictions on this Inference Endpoint. Returns: [`InferenceClient`]: an inference client pointing to the deployed endpoint. Raises: [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed. """ if self.url is None: raise InferenceEndpointError( "Cannot create a client for this Inference Endpoint as it is not yet deployed. " "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again." ) from .inference._client import InferenceClient return InferenceClient( model=self.url, token=self._token, # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok. ) @property def async_client(self) -> "AsyncInferenceClient": """Returns a client to make predictions on this Inference Endpoint. Returns: [`AsyncInferenceClient`]: an asyncio-compatible inference client pointing to the deployed endpoint. Raises: [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed. """ if self.url is None: raise InferenceEndpointError( "Cannot create a client for this Inference Endpoint as it is not yet deployed. " "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again." 
) from .inference._generated._async_client import AsyncInferenceClient return AsyncInferenceClient( model=self.url, token=self._token, # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok. ) def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "InferenceEndpoint": """Wait for the Inference Endpoint to be deployed. Information from the server will be fetched every 1s. If the Inference Endpoint is not deployed after `timeout` seconds, a [`InferenceEndpointTimeoutError`] will be raised. The [`InferenceEndpoint`] will be mutated in place with the latest data. Args: timeout (`int`, *optional*): The maximum time to wait for the Inference Endpoint to be deployed, in seconds. If `None`, will wait indefinitely. refresh_every (`int`, *optional*): The time to wait between each fetch of the Inference Endpoint status, in seconds. Defaults to 5s. Returns: [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. Raises: [`InferenceEndpointError`] If the Inference Endpoint ended up in a failed state. [`InferenceEndpointTimeoutError`] If the Inference Endpoint is not deployed after `timeout` seconds. """ if timeout is not None and timeout < 0: raise ValueError("`timeout` cannot be negative.") if refresh_every <= 0: raise ValueError("`refresh_every` must be positive.") start = time.time() while True: if self.status == InferenceEndpointStatus.FAILED: raise InferenceEndpointError( f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information." ) if self.status == InferenceEndpointStatus.UPDATE_FAILED: raise InferenceEndpointError( f"Inference Endpoint {self.name} failed to update. Please check the logs for more information." ) if self.status == InferenceEndpointStatus.RUNNING and self.url is not None: # Verify the endpoint is actually reachable response = get_session().get(self.url, headers=self._api._build_hf_headers(token=self._token)) if response.status_code == 200: logger.info("Inference Endpoint is ready to be used.") return self if timeout is not None: if time.time() - start > timeout: raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.") logger.info(f"Inference Endpoint is not deployed yet ({self.status}). Waiting {refresh_every}s...") time.sleep(refresh_every) self.fetch() def fetch(self) -> "InferenceEndpoint": """Fetch latest information about the Inference Endpoint. Returns: [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. """ obj = self._api.get_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] self.raw = obj.raw self._populate_from_raw() return self def update( self, *, # Compute update accelerator: Optional[str] = None, instance_size: Optional[str] = None, instance_type: Optional[str] = None, min_replica: Optional[int] = None, max_replica: Optional[int] = None, scale_to_zero_timeout: Optional[int] = None, # Model update repository: Optional[str] = None, framework: Optional[str] = None, revision: Optional[str] = None, task: Optional[str] = None, custom_image: Optional[Dict] = None, secrets: Optional[Dict[str, str]] = None, ) -> "InferenceEndpoint": """Update the Inference Endpoint. This method allows the update of either the compute configuration, the deployed model, or both. All arguments are optional but at least one must be provided. This is an alias for [`HfApi.update_inference_endpoint`]. 
The current object is mutated in place with the latest data from the server. Args: accelerator (`str`, *optional*): The hardware accelerator to be used for inference (e.g. `"cpu"`). instance_size (`str`, *optional*): The size or type of the instance to be used for hosting the model (e.g. `"x4"`). instance_type (`str`, *optional*): The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`). min_replica (`int`, *optional*): The minimum number of replicas (instances) to keep running for the Inference Endpoint. max_replica (`int`, *optional*): The maximum number of replicas (instances) to scale to for the Inference Endpoint. scale_to_zero_timeout (`int`, *optional*): The duration in minutes before an inactive endpoint is scaled to zero. repository (`str`, *optional*): The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). framework (`str`, *optional*): The machine learning framework used for the model (e.g. `"custom"`). revision (`str`, *optional*): The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). task (`str`, *optional*): The task on which to deploy the model (e.g. `"text-classification"`). custom_image (`Dict`, *optional*): A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). secrets (`Dict[str, str]`, *optional*): Secret values to inject in the container environment. Returns: [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. """ # Make API call obj = self._api.update_inference_endpoint( name=self.name, namespace=self.namespace, accelerator=accelerator, instance_size=instance_size, instance_type=instance_type, min_replica=min_replica, max_replica=max_replica, scale_to_zero_timeout=scale_to_zero_timeout, repository=repository, framework=framework, revision=revision, task=task, custom_image=custom_image, secrets=secrets, token=self._token, # type: ignore [arg-type] ) # Mutate current object self.raw = obj.raw self._populate_from_raw() return self def pause(self) -> "InferenceEndpoint": """Pause the Inference Endpoint. A paused Inference Endpoint will not be charged. It can be resumed at any time using [`InferenceEndpoint.resume`]. This is different than scaling the Inference Endpoint to zero with [`InferenceEndpoint.scale_to_zero`], which would be automatically restarted when a request is made to it. This is an alias for [`HfApi.pause_inference_endpoint`]. The current object is mutated in place with the latest data from the server. Returns: [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. """ obj = self._api.pause_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] self.raw = obj.raw self._populate_from_raw() return self def resume(self, running_ok: bool = True) -> "InferenceEndpoint": """Resume the Inference Endpoint. This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the latest data from the server. Args: running_ok (`bool`, *optional*): If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to `True`. Returns: [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. 
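        Example (illustrative):

        ```py
        >>> from huggingface_hub import get_inference_endpoint
        >>> endpoint = get_inference_endpoint("my-endpoint-name")
        # `resume()` returns the endpoint itself, so it can be chained with `wait()`
        >>> endpoint.resume().wait()
        ```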
""" obj = self._api.resume_inference_endpoint( name=self.name, namespace=self.namespace, running_ok=running_ok, token=self._token ) # type: ignore [arg-type] self.raw = obj.raw self._populate_from_raw() return self def scale_to_zero(self) -> "InferenceEndpoint": """Scale Inference Endpoint to zero. An Inference Endpoint scaled to zero will not be charged. It will be resume on the next request to it, with a cold start delay. This is different than pausing the Inference Endpoint with [`InferenceEndpoint.pause`], which would require a manual resume with [`InferenceEndpoint.resume`]. This is an alias for [`HfApi.scale_to_zero_inference_endpoint`]. The current object is mutated in place with the latest data from the server. Returns: [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. """ obj = self._api.scale_to_zero_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] self.raw = obj.raw self._populate_from_raw() return self def delete(self) -> None: """Delete the Inference Endpoint. This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable to pause it with [`InferenceEndpoint.pause`] or scale it to zero with [`InferenceEndpoint.scale_to_zero`]. This is an alias for [`HfApi.delete_inference_endpoint`]. """ self._api.delete_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] def _populate_from_raw(self) -> None: """Populate fields from raw dictionary. Called in __post_init__ + each time the Inference Endpoint is updated. """ # Repr fields self.name = self.raw["name"] self.repository = self.raw["model"]["repository"] self.status = self.raw["status"]["state"] self.url = self.raw["status"].get("url") # Other fields self.framework = self.raw["model"]["framework"] self.revision = self.raw["model"]["revision"] self.task = self.raw["model"]["task"] self.created_at = parse_datetime(self.raw["status"]["createdAt"]) self.updated_at = parse_datetime(self.raw["status"]["updatedAt"]) self.type = self.raw["type"] huggingface_hub-0.31.1/src/huggingface_hub/_local_folder.py000066400000000000000000000402771500667546600237640ustar00rootroot00000000000000# coding=utf-8 # Copyright 2024-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains utilities to handle the `../.cache/huggingface` folder in local directories. First discussed in https://github.com/huggingface/huggingface_hub/issues/1738 to store download metadata when downloading files from the hub to a local directory (without using the cache). 
./.cache/huggingface folder structure: [4.0K] data ├── [4.0K] .cache │ └── [4.0K] huggingface │ └── [4.0K] download │ ├── [ 16] file.parquet.metadata │ ├── [ 16] file.txt.metadata │ └── [4.0K] folder │ └── [ 16] file.parquet.metadata │ ├── [6.5G] file.parquet ├── [1.5K] file.txt └── [4.0K] folder └── [ 16] file.parquet Download metadata file structure: ``` # file.txt.metadata 11c5a3d5811f50298f278a704980280950aedb10 a16a55fda99d2f2e7b69cce5cf93ff4ad3049930 1712656091.123 # file.parquet.metadata 11c5a3d5811f50298f278a704980280950aedb10 7c5d3f4b8b76583b422fcb9189ad6c89d5d97a094541ce8932dce3ecabde1421 1712656091.123 } ``` """ import base64 import hashlib import logging import os import time from dataclasses import dataclass from pathlib import Path from typing import Optional from .utils import WeakFileLock logger = logging.getLogger(__name__) @dataclass class LocalDownloadFilePaths: """ Paths to the files related to a download process in a local dir. Returned by [`get_local_download_paths`]. Attributes: file_path (`Path`): Path where the file will be saved. lock_path (`Path`): Path to the lock file used to ensure atomicity when reading/writing metadata. metadata_path (`Path`): Path to the metadata file. """ file_path: Path lock_path: Path metadata_path: Path def incomplete_path(self, etag: str) -> Path: """Return the path where a file will be temporarily downloaded before being moved to `file_path`.""" return self.metadata_path.parent / f"{_short_hash(self.metadata_path.name)}.{etag}.incomplete" @dataclass(frozen=True) class LocalUploadFilePaths: """ Paths to the files related to an upload process in a local dir. Returned by [`get_local_upload_paths`]. Attributes: path_in_repo (`str`): Path of the file in the repo. file_path (`Path`): Path where the file will be saved. lock_path (`Path`): Path to the lock file used to ensure atomicity when reading/writing metadata. metadata_path (`Path`): Path to the metadata file. """ path_in_repo: str file_path: Path lock_path: Path metadata_path: Path @dataclass class LocalDownloadFileMetadata: """ Metadata about a file in the local directory related to a download process. Attributes: filename (`str`): Path of the file in the repo. commit_hash (`str`): Commit hash of the file in the repo. etag (`str`): ETag of the file in the repo. Used to check if the file has changed. For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash. timestamp (`int`): Unix timestamp of when the metadata was saved i.e. when the metadata was accurate. """ filename: str commit_hash: str etag: str timestamp: float @dataclass class LocalUploadFileMetadata: """ Metadata about a file in the local directory related to an upload process. 
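    Only `size` is always known; the other fields default to "not known yet" values
    (`None` / `False`) and are filled in as the upload process progresses. See `save`
    for the on-disk serialization format.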
""" size: int # Default values correspond to "we don't know yet" timestamp: Optional[float] = None should_ignore: Optional[bool] = None sha256: Optional[str] = None upload_mode: Optional[str] = None is_uploaded: bool = False is_committed: bool = False def save(self, paths: LocalUploadFilePaths) -> None: """Save the metadata to disk.""" with WeakFileLock(paths.lock_path): with paths.metadata_path.open("w") as f: new_timestamp = time.time() f.write(str(new_timestamp) + "\n") f.write(str(self.size)) # never None f.write("\n") if self.should_ignore is not None: f.write(str(int(self.should_ignore))) f.write("\n") if self.sha256 is not None: f.write(self.sha256) f.write("\n") if self.upload_mode is not None: f.write(self.upload_mode) f.write("\n") f.write(str(int(self.is_uploaded)) + "\n") f.write(str(int(self.is_committed)) + "\n") self.timestamp = new_timestamp def get_local_download_paths(local_dir: Path, filename: str) -> LocalDownloadFilePaths: """Compute paths to the files related to a download process. Folders containing the paths are all guaranteed to exist. Args: local_dir (`Path`): Path to the local directory in which files are downloaded. filename (`str`): Path of the file in the repo. Return: [`LocalDownloadFilePaths`]: the paths to the files (file_path, lock_path, metadata_path, incomplete_path). """ # filename is the path in the Hub repository (separated by '/') # make sure to have a cross platform transcription sanitized_filename = os.path.join(*filename.split("/")) if os.name == "nt": if sanitized_filename.startswith("..\\") or "\\..\\" in sanitized_filename: raise ValueError( f"Invalid filename: cannot handle filename '{sanitized_filename}' on Windows. Please ask the repository" " owner to rename this file." ) file_path = local_dir / sanitized_filename metadata_path = _huggingface_dir(local_dir) / "download" / f"{sanitized_filename}.metadata" lock_path = metadata_path.with_suffix(".lock") # Some Windows versions do not allow for paths longer than 255 characters. # In this case, we must specify it as an extended path by using the "\\?\" prefix if os.name == "nt": if not str(local_dir).startswith("\\\\?\\") and len(os.path.abspath(lock_path)) > 255: file_path = Path("\\\\?\\" + os.path.abspath(file_path)) lock_path = Path("\\\\?\\" + os.path.abspath(lock_path)) metadata_path = Path("\\\\?\\" + os.path.abspath(metadata_path)) file_path.parent.mkdir(parents=True, exist_ok=True) metadata_path.parent.mkdir(parents=True, exist_ok=True) return LocalDownloadFilePaths(file_path=file_path, lock_path=lock_path, metadata_path=metadata_path) def get_local_upload_paths(local_dir: Path, filename: str) -> LocalUploadFilePaths: """Compute paths to the files related to an upload process. Folders containing the paths are all guaranteed to exist. Args: local_dir (`Path`): Path to the local directory that is uploaded. filename (`str`): Path of the file in the repo. Return: [`LocalUploadFilePaths`]: the paths to the files (file_path, lock_path, metadata_path). """ # filename is the path in the Hub repository (separated by '/') # make sure to have a cross platform transcription sanitized_filename = os.path.join(*filename.split("/")) if os.name == "nt": if sanitized_filename.startswith("..\\") or "\\..\\" in sanitized_filename: raise ValueError( f"Invalid filename: cannot handle filename '{sanitized_filename}' on Windows. Please ask the repository" " owner to rename this file." 
) file_path = local_dir / sanitized_filename metadata_path = _huggingface_dir(local_dir) / "upload" / f"{sanitized_filename}.metadata" lock_path = metadata_path.with_suffix(".lock") # Some Windows versions do not allow for paths longer than 255 characters. # In this case, we must specify it as an extended path by using the "\\?\" prefix if os.name == "nt": if not str(local_dir).startswith("\\\\?\\") and len(os.path.abspath(lock_path)) > 255: file_path = Path("\\\\?\\" + os.path.abspath(file_path)) lock_path = Path("\\\\?\\" + os.path.abspath(lock_path)) metadata_path = Path("\\\\?\\" + os.path.abspath(metadata_path)) file_path.parent.mkdir(parents=True, exist_ok=True) metadata_path.parent.mkdir(parents=True, exist_ok=True) return LocalUploadFilePaths( path_in_repo=filename, file_path=file_path, lock_path=lock_path, metadata_path=metadata_path ) def read_download_metadata(local_dir: Path, filename: str) -> Optional[LocalDownloadFileMetadata]: """Read metadata about a file in the local directory related to a download process. Args: local_dir (`Path`): Path to the local directory in which files are downloaded. filename (`str`): Path of the file in the repo. Return: `[LocalDownloadFileMetadata]` or `None`: the metadata if it exists, `None` otherwise. """ paths = get_local_download_paths(local_dir, filename) with WeakFileLock(paths.lock_path): if paths.metadata_path.exists(): try: with paths.metadata_path.open() as f: commit_hash = f.readline().strip() etag = f.readline().strip() timestamp = float(f.readline().strip()) metadata = LocalDownloadFileMetadata( filename=filename, commit_hash=commit_hash, etag=etag, timestamp=timestamp, ) except Exception as e: # remove the metadata file if it is corrupted / not the right format logger.warning( f"Invalid metadata file {paths.metadata_path}: {e}. Removing it from disk and continue." ) try: paths.metadata_path.unlink() except Exception as e: logger.warning(f"Could not remove corrupted metadata file {paths.metadata_path}: {e}") try: # check if the file exists and hasn't been modified since the metadata was saved stat = paths.file_path.stat() if ( stat.st_mtime - 1 <= metadata.timestamp ): # allow 1s difference as stat.st_mtime might not be precise return metadata logger.info(f"Ignored metadata for '{filename}' (outdated). Will re-compute hash.") except FileNotFoundError: # file does not exist => metadata is outdated return None return None def read_upload_metadata(local_dir: Path, filename: str) -> LocalUploadFileMetadata: """Read metadata about a file in the local directory related to an upload process. TODO: factorize logic with `read_download_metadata`. Args: local_dir (`Path`): Path to the local directory in which files are downloaded. filename (`str`): Path of the file in the repo. Return: `[LocalUploadFileMetadata]` or `None`: the metadata if it exists, `None` otherwise. 
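
    Note: unlike `read_download_metadata`, this function never returns `None` in practice:
    if no valid metadata file is found on disk, a fresh `LocalUploadFileMetadata` is
    returned with only the file size set.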
""" paths = get_local_upload_paths(local_dir, filename) with WeakFileLock(paths.lock_path): if paths.metadata_path.exists(): try: with paths.metadata_path.open() as f: timestamp = float(f.readline().strip()) size = int(f.readline().strip()) # never None _should_ignore = f.readline().strip() should_ignore = None if _should_ignore == "" else bool(int(_should_ignore)) _sha256 = f.readline().strip() sha256 = None if _sha256 == "" else _sha256 _upload_mode = f.readline().strip() upload_mode = None if _upload_mode == "" else _upload_mode if upload_mode not in (None, "regular", "lfs"): raise ValueError(f"Invalid upload mode in metadata {paths.path_in_repo}: {upload_mode}") is_uploaded = bool(int(f.readline().strip())) is_committed = bool(int(f.readline().strip())) metadata = LocalUploadFileMetadata( timestamp=timestamp, size=size, should_ignore=should_ignore, sha256=sha256, upload_mode=upload_mode, is_uploaded=is_uploaded, is_committed=is_committed, ) except Exception as e: # remove the metadata file if it is corrupted / not the right format logger.warning( f"Invalid metadata file {paths.metadata_path}: {e}. Removing it from disk and continue." ) try: paths.metadata_path.unlink() except Exception as e: logger.warning(f"Could not remove corrupted metadata file {paths.metadata_path}: {e}") # TODO: can we do better? if ( metadata.timestamp is not None and metadata.is_uploaded # file was uploaded and not metadata.is_committed # but not committed and time.time() - metadata.timestamp > 20 * 3600 # and it's been more than 20 hours ): # => we consider it as garbage-collected by S3 metadata.is_uploaded = False # check if the file exists and hasn't been modified since the metadata was saved try: if metadata.timestamp is not None and paths.file_path.stat().st_mtime <= metadata.timestamp: return metadata logger.info(f"Ignored metadata for '{filename}' (outdated). Will re-compute hash.") except FileNotFoundError: # file does not exist => metadata is outdated pass # empty metadata => we don't know anything expect its size return LocalUploadFileMetadata(size=paths.file_path.stat().st_size) def write_download_metadata(local_dir: Path, filename: str, commit_hash: str, etag: str) -> None: """Write metadata about a file in the local directory related to a download process. Args: local_dir (`Path`): Path to the local directory in which files are downloaded. """ paths = get_local_download_paths(local_dir, filename) with WeakFileLock(paths.lock_path): with paths.metadata_path.open("w") as f: f.write(f"{commit_hash}\n{etag}\n{time.time()}\n") def _huggingface_dir(local_dir: Path) -> Path: """Return the path to the `.cache/huggingface` directory in a local directory.""" # Wrap in lru_cache to avoid overwriting the .gitignore file if called multiple times path = local_dir / ".cache" / "huggingface" path.mkdir(exist_ok=True, parents=True) # Create a .gitignore file in the .cache/huggingface directory if it doesn't exist # Should be thread-safe enough like this. gitignore = path / ".gitignore" gitignore_lock = path / ".gitignore.lock" if not gitignore.exists(): try: with WeakFileLock(gitignore_lock, timeout=0.1): gitignore.write_text("*") except IndexError: pass except OSError: # TimeoutError, FileNotFoundError, PermissionError, etc. 
pass try: gitignore_lock.unlink() except OSError: pass return path def _short_hash(filename: str) -> str: return base64.urlsafe_b64encode(hashlib.sha1(filename.encode()).digest()).decode() huggingface_hub-0.31.1/src/huggingface_hub/_login.py000066400000000000000000000475121500667546600224460ustar00rootroot00000000000000# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains methods to log in to the Hub.""" import os import subprocess from getpass import getpass from pathlib import Path from typing import Optional from . import constants from .commands._cli_utils import ANSI from .utils import ( capture_output, get_token, is_google_colab, is_notebook, list_credential_helpers, logging, run_subprocess, set_git_credential, unset_git_credential, ) from .utils._auth import ( _get_token_by_name, _get_token_from_environment, _get_token_from_file, _get_token_from_google_colab, _save_stored_tokens, _save_token, get_stored_tokens, ) from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args logger = logging.get_logger(__name__) _HF_LOGO_ASCII = """ _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_| _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_| """ @_deprecate_arguments( version="1.0", deprecated_args="write_permission", custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.", ) @_deprecate_positional_args(version="1.0") def login( token: Optional[str] = None, *, add_to_git_credential: bool = False, new_session: bool = True, write_permission: bool = False, ) -> None: """Login the machine to access the Hub. The `token` is persisted in cache and set as a git credential. Once done, the machine is logged in and the access token will be available across all `huggingface_hub` components. If `token` is not provided, it will be prompted to the user either with a widget (in a notebook) or via the terminal. To log in from outside of a script, one can also use `huggingface-cli login` which is a cli command that wraps [`login`]. [`login`] is a drop-in replacement method for [`notebook_login`] as it wraps and extends its capabilities. When the token is not passed, [`login`] will automatically detect if the script runs in a notebook or not. However, this detection might not be accurate due to the variety of notebooks that exists nowadays. If that is the case, you can always force the UI by using [`notebook_login`] or [`interpreter_login`]. Args: token (`str`, *optional*): User access token to generate from https://huggingface.co/settings/token. add_to_git_credential (`bool`, defaults to `False`): If `True`, token will be set as git credential. If no git credential helper is configured, a warning will be displayed to the user. 
If `token` is `None`, the value of `add_to_git_credential` is ignored and will be prompted again to the end user. new_session (`bool`, defaults to `True`): If `True`, will request a token even if one is already saved on the machine. write_permission (`bool`): Ignored and deprecated argument. Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If an organization token is passed. Only personal account tokens are valid to log in. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If token is invalid. [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError) If running in a notebook but `ipywidgets` is not installed. """ if token is not None: if not add_to_git_credential: logger.info( "The token has not been saved to the git credentials helper. Pass " "`add_to_git_credential=True` in this function directly or " "`--add-to-git-credential` if using via `huggingface-cli` if " "you want to set the git credential as well." ) _login(token, add_to_git_credential=add_to_git_credential) elif is_notebook(): notebook_login(new_session=new_session) else: interpreter_login(new_session=new_session) def logout(token_name: Optional[str] = None) -> None: """Logout the machine from the Hub. Token is deleted from the machine and removed from git credential. Args: token_name (`str`, *optional*): Name of the access token to logout from. If `None`, will logout from all saved access tokens. Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError): If the access token name is not found. """ if get_token() is None and not get_stored_tokens(): # No active token and no saved access tokens logger.warning("Not logged in!") return if not token_name: # Delete all saved access tokens and token for file_path in (constants.HF_TOKEN_PATH, constants.HF_STORED_TOKENS_PATH): try: Path(file_path).unlink() except FileNotFoundError: pass logger.info("Successfully logged out from all access tokens.") else: _logout_from_token(token_name) logger.info(f"Successfully logged out from access token: {token_name}.") unset_git_credential() # Check if still logged in if _get_token_from_google_colab() is not None: raise EnvironmentError( "You are automatically logged in using a Google Colab secret.\n" "To log out, you must unset the `HF_TOKEN` secret in your Colab settings." ) if _get_token_from_environment() is not None: raise EnvironmentError( "Token has been deleted from your machine but you are still logged in.\n" "To log out, you must clear out both `HF_TOKEN` and `HUGGING_FACE_HUB_TOKEN` environment variables." ) def auth_switch(token_name: str, add_to_git_credential: bool = False) -> None: """Switch to a different access token. Args: token_name (`str`): Name of the access token to switch to. add_to_git_credential (`bool`, defaults to `False`): If `True`, token will be set as git credential. If no git credential helper is configured, a warning will be displayed to the user. If `token` is `None`, the value of `add_to_git_credential` is ignored and will be prompted again to the end user. Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError): If the access token name is not found. 
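
    Example (a minimal sketch; the token name `"my-read-token"` is assumed to have been
    saved previously with [`login`] or `huggingface-cli login`):
    ```py
    >>> from huggingface_hub import auth_switch

    >>> auth_switch("my-read-token")
    ```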
""" token = _get_token_by_name(token_name) if not token: raise ValueError(f"Access token {token_name} not found in {constants.HF_STORED_TOKENS_PATH}") # Write token to HF_TOKEN_PATH _set_active_token(token_name, add_to_git_credential) logger.info(f"The current active token is: {token_name}") token_from_environment = _get_token_from_environment() if token_from_environment is not None and token_from_environment != token: logger.warning( "The environment variable `HF_TOKEN` is set and will override the access token you've just switched to." ) def auth_list() -> None: """List all stored access tokens.""" tokens = get_stored_tokens() if not tokens: logger.info("No access tokens found.") return # Find current token current_token = get_token() current_token_name = None for token_name in tokens: if tokens.get(token_name) == current_token: current_token_name = token_name # Print header max_offset = max(len("token"), max(len(token) for token in tokens)) + 2 print(f" {{:<{max_offset}}}| {{:<15}}".format("name", "token")) print("-" * (max_offset + 2) + "|" + "-" * 15) # Print saved access tokens for token_name in tokens: token = tokens.get(token_name, "") masked_token = f"{token[:3]}****{token[-4:]}" if token != "" else token is_current = "*" if token == current_token else " " print(f"{is_current} {{:<{max_offset}}}| {{:<15}}".format(token_name, masked_token)) if _get_token_from_environment(): logger.warning( "\nNote: Environment variable `HF_TOKEN` is set and is the current active token independently from the stored tokens listed above." ) elif current_token_name is None: logger.warning( "\nNote: No active token is set and no environment variable `HF_TOKEN` is found. Use `huggingface-cli login` to log in." ) ### # Interpreter-based login (text) ### @_deprecate_arguments( version="1.0", deprecated_args="write_permission", custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.", ) @_deprecate_positional_args(version="1.0") def interpreter_login(*, new_session: bool = True, write_permission: bool = False) -> None: """ Displays a prompt to log in to the HF website and store the token. This is equivalent to [`login`] without passing a token when not run in a notebook. [`interpreter_login`] is useful if you want to force the use of the terminal prompt instead of a notebook widget. For more details, see [`login`]. Args: new_session (`bool`, defaults to `True`): If `True`, will request a token even if one is already saved on the machine. write_permission (`bool`): Ignored and deprecated argument. """ if not new_session and get_token() is not None: logger.info("User is already logged in.") return from .commands.delete_cache import _ask_for_confirmation_no_tui print(_HF_LOGO_ASCII) if get_token() is not None: logger.info( " A token is already saved on your machine. Run `huggingface-cli" " whoami` to get more information or `huggingface-cli logout` if you want" " to log out." ) logger.info(" Setting a new token will erase the existing one.") logger.info( " To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens ." ) if os.name == "nt": logger.info("Token can be pasted using 'Right-Click'.") token = getpass("Enter your token (input will not be visible): ") add_to_git_credential = _ask_for_confirmation_no_tui("Add token as git credential?") _login(token=token, add_to_git_credential=add_to_git_credential) ### # Notebook-based login (widget) ### NOTEBOOK_LOGIN_PASSWORD_HTML = """
<center> <img
src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg
alt='Hugging Face'> <br> Immediately click login after typing your password or
it might be stored in plain text in this notebook file. </center>"""

NOTEBOOK_LOGIN_TOKEN_HTML_START = """<center> <img
src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg
alt='Hugging Face'> <br> Copy a token from <a
href="https://huggingface.co/settings/tokens" target="_blank">your Hugging Face
tokens page</a> and paste it below. <br> Immediately click login after copying
your token or it might be stored in plain text in this notebook file. </center>
""" NOTEBOOK_LOGIN_TOKEN_HTML_END = """ Pro Tip: If you don't already have one, you can create a dedicated 'notebooks' token with 'write' access, that you can then easily reuse for all notebooks. """ @_deprecate_arguments( version="1.0", deprecated_args="write_permission", custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.", ) @_deprecate_positional_args(version="1.0") def notebook_login(*, new_session: bool = True, write_permission: bool = False) -> None: """ Displays a widget to log in to the HF website and store the token. This is equivalent to [`login`] without passing a token when run in a notebook. [`notebook_login`] is useful if you want to force the use of the notebook widget instead of a prompt in the terminal. For more details, see [`login`]. Args: new_session (`bool`, defaults to `True`): If `True`, will request a token even if one is already saved on the machine. write_permission (`bool`): Ignored and deprecated argument. """ try: import ipywidgets.widgets as widgets # type: ignore from IPython.display import display # type: ignore except ImportError: raise ImportError( "The `notebook_login` function can only be used in a notebook (Jupyter or" " Colab) and you need the `ipywidgets` module: `pip install ipywidgets`." ) if not new_session and get_token() is not None: logger.info("User is already logged in.") return box_layout = widgets.Layout(display="flex", flex_flow="column", align_items="center", width="50%") token_widget = widgets.Password(description="Token:") git_checkbox_widget = widgets.Checkbox(value=True, description="Add token as git credential?") token_finish_button = widgets.Button(description="Login") login_token_widget = widgets.VBox( [ widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_START), token_widget, git_checkbox_widget, token_finish_button, widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_END), ], layout=box_layout, ) display(login_token_widget) # On click events def login_token_event(t): """Event handler for the login button.""" token = token_widget.value add_to_git_credential = git_checkbox_widget.value # Erase token and clear value to make sure it's not saved in the notebook. token_widget.value = "" # Hide inputs login_token_widget.children = [widgets.Label("Connecting...")] try: with capture_output() as captured: _login(token, add_to_git_credential=add_to_git_credential) message = captured.getvalue() except Exception as error: message = str(error) # Print result (success message or error) login_token_widget.children = [widgets.Label(line) for line in message.split("\n") if line.strip()] token_finish_button.on_click(login_token_event) ### # Login private helpers ### def _login( token: str, add_to_git_credential: bool, ) -> None: from .hf_api import whoami # avoid circular import if token.startswith("api_org"): raise ValueError("You must use your personal account token, not an organization token.") token_info = whoami(token) permission = token_info["auth"]["accessToken"]["role"] logger.info(f"Token is valid (permission: {permission}).") token_name = token_info["auth"]["accessToken"]["displayName"] # Store token locally _save_token(token=token, token_name=token_name) # Set active token _set_active_token(token_name=token_name, add_to_git_credential=add_to_git_credential) logger.info("Login successful.") if _get_token_from_environment(): logger.warning( "Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured." 
) else: logger.info(f"The current active token is: `{token_name}`") def _logout_from_token(token_name: str) -> None: """Logout from a specific access token. Args: token_name (`str`): The name of the access token to logout from. Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError): If the access token name is not found. """ stored_tokens = get_stored_tokens() # If there is no access tokens saved or the access token name is not found, do nothing if not stored_tokens or token_name not in stored_tokens: return token = stored_tokens.pop(token_name) _save_stored_tokens(stored_tokens) if token == _get_token_from_file(): logger.warning(f"Active token '{token_name}' has been deleted.") Path(constants.HF_TOKEN_PATH).unlink(missing_ok=True) def _set_active_token( token_name: str, add_to_git_credential: bool, ) -> None: """Set the active access token. Args: token_name (`str`): The name of the token to set as active. """ token = _get_token_by_name(token_name) if not token: raise ValueError(f"Token {token_name} not found in {constants.HF_STORED_TOKENS_PATH}") if add_to_git_credential: if _is_git_credential_helper_configured(): set_git_credential(token) logger.info( "Your token has been saved in your configured git credential helpers" + f" ({','.join(list_credential_helpers())})." ) else: logger.warning("Token has not been saved to git credential helper.") # Write token to HF_TOKEN_PATH path = Path(constants.HF_TOKEN_PATH) path.parent.mkdir(parents=True, exist_ok=True) path.write_text(token) logger.info(f"Your token has been saved to {constants.HF_TOKEN_PATH}") def _is_git_credential_helper_configured() -> bool: """Check if a git credential helper is configured. Warns user if not the case (except for Google Colab where "store" is set by default by `huggingface_hub`). """ helpers = list_credential_helpers() if len(helpers) > 0: return True # Do not warn: at least 1 helper is set # Only in Google Colab to avoid the warning message # See https://github.com/huggingface/huggingface_hub/issues/1043#issuecomment-1247010710 if is_google_colab(): _set_store_as_git_credential_helper_globally() return True # Do not warn: "store" is used by default in Google Colab # Otherwise, warn user print( ANSI.red( "Cannot authenticate through git-credential as no helper is defined on your" " machine.\nYou might have to re-authenticate when pushing to the Hugging" " Face Hub.\nRun the following command in your terminal in case you want to" " set the 'store' credential helper as default.\n\ngit config --global" " credential.helper store\n\nRead" " https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more" " details." ) ) return False def _set_store_as_git_credential_helper_globally() -> None: """Set globally the credential.helper to `store`. To be used only in Google Colab as we assume the user doesn't care about the git credential config. It is the only particular case where we don't want to display the warning message in [`notebook_login()`]. 
Related: - https://github.com/huggingface/huggingface_hub/issues/1043 - https://github.com/huggingface/huggingface_hub/issues/1051 - https://git-scm.com/docs/git-credential-store """ try: run_subprocess("git config --global credential.helper store") except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) huggingface_hub-0.31.1/src/huggingface_hub/_snapshot_download.py000066400000000000000000000352471500667546600250660ustar00rootroot00000000000000import os from pathlib import Path from typing import Dict, List, Literal, Optional, Union import requests from tqdm.auto import tqdm as base_tqdm from tqdm.contrib.concurrent import thread_map from . import constants from .errors import GatedRepoError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args from .utils import tqdm as hf_tqdm logger = logging.get_logger(__name__) @validate_hf_hub_args def snapshot_download( repo_id: str, *, repo_type: Optional[str] = None, revision: Optional[str] = None, cache_dir: Union[str, Path, None] = None, local_dir: Union[str, Path, None] = None, library_name: Optional[str] = None, library_version: Optional[str] = None, user_agent: Optional[Union[Dict, str]] = None, proxies: Optional[Dict] = None, etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, force_download: bool = False, token: Optional[Union[bool, str]] = None, local_files_only: bool = False, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_workers: int = 8, tqdm_class: Optional[base_tqdm] = None, headers: Optional[Dict[str, str]] = None, endpoint: Optional[str] = None, # Deprecated args local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", resume_download: Optional[bool] = None, ) -> str: """Download repo files. Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order to keep their actual filename relative to that folder. You can also filter which files to download using `allow_patterns` and `ignore_patterns`. If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` to store some metadata related to the downloaded files. While this mechanism is not as robust as the main cache-system, it's optimized for regularly pulling the latest version of a repository. An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly configured. It is also not possible to filter which files to download when cloning a repository using git. Args: repo_id (`str`): A user or an organization name and a repo name separated by a `/`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if downloading from a model. Default is `None`. revision (`str`, *optional*): An optional Git revision id which can be a branch name, a tag, or a commit hash. cache_dir (`str`, `Path`, *optional*): Path to the folder where cached files are stored. 
local_dir (`str` or `Path`, *optional*): If provided, the downloaded files will be placed under this directory. library_name (`str`, *optional*): The name of the library to which the object corresponds. library_version (`str`, *optional*): The version of the library. user_agent (`str`, `dict`, *optional*): The user-agent info in the form of a dictionary or a string. proxies (`dict`, *optional*): Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. etag_timeout (`float`, *optional*, defaults to `10`): When fetching ETag, how many seconds to wait for the server to send data before giving up which is passed to `requests.request`. force_download (`bool`, *optional*, defaults to `False`): Whether the file should be downloaded even if it already exists in the local cache. token (`str`, `bool`, *optional*): A token to be used for the download. - If `True`, the token is read from the HuggingFace config folder. - If a string, it's used as the authentication token. headers (`dict`, *optional*): Additional headers to include in the request. Those headers take precedence over the others. local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the local cached file if it exists. allow_patterns (`List[str]` or `str`, *optional*): If provided, only files matching at least one pattern are downloaded. ignore_patterns (`List[str]` or `str`, *optional*): If provided, files matching any of the patterns are not downloaded. max_workers (`int`, *optional*): Number of concurrent threads to download files (1 thread = 1 file download). Defaults to 8. tqdm_class (`tqdm`, *optional*): If provided, overwrites the default behavior for the progress bar. Passed argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior. Note that the `tqdm_class` is not passed to each individual download. Defaults to the custom HF progress bar that can be disabled by setting `HF_HUB_DISABLE_PROGRESS_BARS` environment variable. Returns: `str`: folder path of the repo snapshot. Raises: [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) If `token=True` and the token cannot be found. [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if ETag cannot be determined. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid. """ if cache_dir is None: cache_dir = constants.HF_HUB_CACHE if revision is None: revision = constants.DEFAULT_REVISION if isinstance(cache_dir, Path): cache_dir = str(cache_dir) if repo_type is None: repo_type = "model" if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type: {repo_type}. 
Accepted repo types are: {str(constants.REPO_TYPES)}") storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type)) repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None api_call_error: Optional[Exception] = None if not local_files_only: # try/except logic to handle different errors => taken from `hf_hub_download` try: # if we have internet connection we want to list files to download api = HfApi( library_name=library_name, library_version=library_version, user_agent=user_agent, endpoint=endpoint, headers=headers, ) repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token) except (requests.exceptions.SSLError, requests.exceptions.ProxyError): # Actually raise for those subclasses of ConnectionError raise except ( requests.exceptions.ConnectionError, requests.exceptions.Timeout, OfflineModeIsEnabled, ) as error: # Internet connection is down # => will try to use local files only api_call_error = error pass except RevisionNotFoundError: # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted) raise except requests.HTTPError as error: # Multiple reasons for an http error: # - Repository is private and invalid/missing token sent # - Repository is gated and invalid/missing token sent # - Hub is down (error 500 or 504) # => let's switch to 'local_files_only=True' to check if the files are already cached. # (if it's not the case, the error will be re-raised) api_call_error = error pass # At this stage, if `repo_info` is None it means either: # - internet connection is down # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True) # - repo is private/gated and invalid/missing token sent # - Hub is down # => let's look if we can find the appropriate folder in the cache: # - if the specified revision is a commit hash, look inside "snapshots". # - f the specified revision is a branch or tag, look inside "refs". # => if local_dir is not None, we will return the path to the local folder if it exists. if repo_info is None: # Try to get which commit hash corresponds to the specified revision commit_hash = None if REGEX_COMMIT_HASH.match(revision): commit_hash = revision else: ref_path = os.path.join(storage_folder, "refs", revision) if os.path.exists(ref_path): # retrieve commit_hash from refs file with open(ref_path) as f: commit_hash = f.read() # Try to locate snapshot folder for this commit hash if commit_hash is not None and local_dir is None: snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) if os.path.exists(snapshot_folder): # Snapshot folder exists => let's return it # (but we can't check if all the files are actually there) return snapshot_folder # If local_dir is not None, return it if it exists and is not empty if local_dir is not None: local_dir = Path(local_dir) if local_dir.is_dir() and any(local_dir.iterdir()): logger.warning( f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})." ) return str(local_dir.resolve()) # If we couldn't find the appropriate folder on disk, raise an error. if local_files_only: raise LocalEntryNotFoundError( "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and " "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass " "'local_files_only=False' as input." 
) elif isinstance(api_call_error, OfflineModeIsEnabled): raise LocalEntryNotFoundError( "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and " "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set " "'HF_HUB_OFFLINE=0' as environment variable." ) from api_call_error elif isinstance(api_call_error, RepositoryNotFoundError) or isinstance(api_call_error, GatedRepoError): # Repo not found => let's raise the actual error raise api_call_error else: # Otherwise: most likely a connection issue or Hub downtime => let's warn the user raise LocalEntryNotFoundError( "An error happened while trying to locate the files on the Hub and we cannot find the appropriate" " snapshot folder for the specified revision on the local disk. Please check your internet connection" " and try again." ) from api_call_error # At this stage, internet connection is up and running # => let's download the files! assert repo_info.sha is not None, "Repo info returned from server must have a revision sha." assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list." filtered_repo_files = list( filter_repo_objects( items=[f.rfilename for f in repo_info.siblings], allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, ) ) commit_hash = repo_info.sha snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) # if passed revision is not identical to commit_hash # then revision has to be a branch name or tag name. # In that case store a ref. if revision != commit_hash: ref_path = os.path.join(storage_folder, "refs", revision) try: os.makedirs(os.path.dirname(ref_path), exist_ok=True) with open(ref_path, "w") as f: f.write(commit_hash) except OSError as e: logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.") # we pass the commit_hash to hf_hub_download # so no network call happens if we already # have the file locally. def _inner_hf_hub_download(repo_file: str): return hf_hub_download( repo_id, filename=repo_file, repo_type=repo_type, revision=commit_hash, endpoint=endpoint, cache_dir=cache_dir, local_dir=local_dir, local_dir_use_symlinks=local_dir_use_symlinks, library_name=library_name, library_version=library_version, user_agent=user_agent, proxies=proxies, etag_timeout=etag_timeout, resume_download=resume_download, force_download=force_download, token=token, headers=headers, ) if constants.HF_HUB_ENABLE_HF_TRANSFER: # when using hf_transfer we don't want extra parallelism # from the one hf_transfer provides for file in filtered_repo_files: _inner_hf_hub_download(file) else: thread_map( _inner_hf_hub_download, filtered_repo_files, desc=f"Fetching {len(filtered_repo_files)} files", max_workers=max_workers, # User can use its own tqdm class or the default one from `huggingface_hub.utils` tqdm_class=tqdm_class or hf_tqdm, ) if local_dir is not None: return str(os.path.realpath(local_dir)) return snapshot_folder huggingface_hub-0.31.1/src/huggingface_hub/_space_api.py000066400000000000000000000125361500667546600232600ustar00rootroot00000000000000# coding=utf-8 # Copyright 2019-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from datetime import datetime from enum import Enum from typing import Dict, Optional from huggingface_hub.utils import parse_datetime class SpaceStage(str, Enum): """ Enumeration of possible stage of a Space on the Hub. Value can be compared to a string: ```py assert SpaceStage.BUILDING == "BUILDING" ``` Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceInfo.ts#L61 (private url). """ # Copied from moon-landing > server > repo_types > SpaceInfo.ts (private repo) NO_APP_FILE = "NO_APP_FILE" CONFIG_ERROR = "CONFIG_ERROR" BUILDING = "BUILDING" BUILD_ERROR = "BUILD_ERROR" RUNNING = "RUNNING" RUNNING_BUILDING = "RUNNING_BUILDING" RUNTIME_ERROR = "RUNTIME_ERROR" DELETING = "DELETING" STOPPED = "STOPPED" PAUSED = "PAUSED" class SpaceHardware(str, Enum): """ Enumeration of hardwares available to run your Space on the Hub. Value can be compared to a string: ```py assert SpaceHardware.CPU_BASIC == "cpu-basic" ``` Taken from https://github.com/huggingface-internal/moon-landing/blob/main/server/repo_types/SpaceHardwareFlavor.ts (private url). """ # CPU CPU_BASIC = "cpu-basic" CPU_UPGRADE = "cpu-upgrade" CPU_XL = "cpu-xl" # ZeroGPU ZERO_A10G = "zero-a10g" # GPU T4_SMALL = "t4-small" T4_MEDIUM = "t4-medium" L4X1 = "l4x1" L4X4 = "l4x4" L40SX1 = "l40sx1" L40SX4 = "l40sx4" L40SX8 = "l40sx8" A10G_SMALL = "a10g-small" A10G_LARGE = "a10g-large" A10G_LARGEX2 = "a10g-largex2" A10G_LARGEX4 = "a10g-largex4" A100_LARGE = "a100-large" H100 = "h100" H100X8 = "h100x8" class SpaceStorage(str, Enum): """ Enumeration of persistent storage available for your Space on the Hub. Value can be compared to a string: ```py assert SpaceStorage.SMALL == "small" ``` Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceHardwareFlavor.ts#L24 (private url). """ SMALL = "small" MEDIUM = "medium" LARGE = "large" @dataclass class SpaceRuntime: """ Contains information about the current runtime of a Space. Args: stage (`str`): Current stage of the space. Example: RUNNING. hardware (`str` or `None`): Current hardware of the space. Example: "cpu-basic". Can be `None` if Space is `BUILDING` for the first time. requested_hardware (`str` or `None`): Requested hardware. Can be different than `hardware` especially if the request has just been made. Example: "t4-medium". Can be `None` if no hardware has been requested yet. sleep_time (`int` or `None`): Number of seconds the Space will be kept alive after the last request. By default (if value is `None`), the Space will never go to sleep if it's running on an upgraded hardware, while it will go to sleep after 48 hours on a free 'cpu-basic' hardware. For more details, see https://huggingface.co/docs/hub/spaces-gpus#sleep-time. raw (`dict`): Raw response from the server. Contains more information about the Space runtime like number of replicas, number of cpu, memory size,... 
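
    Example (a minimal sketch, assuming you have read access to the public Space `"gradio/hello_world"`):
    ```py
    >>> from huggingface_hub import HfApi
    >>> runtime = HfApi().get_space_runtime("gradio/hello_world")
    >>> runtime.stage  # e.g. 'RUNNING'
    ```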
""" stage: SpaceStage hardware: Optional[SpaceHardware] requested_hardware: Optional[SpaceHardware] sleep_time: Optional[int] storage: Optional[SpaceStorage] raw: Dict def __init__(self, data: Dict) -> None: self.stage = data["stage"] self.hardware = data.get("hardware", {}).get("current") self.requested_hardware = data.get("hardware", {}).get("requested") self.sleep_time = data.get("gcTimeout") self.storage = data.get("storage") self.raw = data @dataclass class SpaceVariable: """ Contains information about the current variables of a Space. Args: key (`str`): Variable key. Example: `"MODEL_REPO_ID"` value (`str`): Variable value. Example: `"the_model_repo_id"`. description (`str` or None): Description of the variable. Example: `"Model Repo ID of the implemented model"`. updatedAt (`datetime` or None): datetime of the last update of the variable (if the variable has been updated at least once). """ key: str value: str description: Optional[str] updated_at: Optional[datetime] def __init__(self, key: str, values: Dict) -> None: self.key = key self.value = values["value"] self.description = values.get("description") updated_at = values.get("updatedAt") self.updated_at = parse_datetime(updated_at) if updated_at is not None else None huggingface_hub-0.31.1/src/huggingface_hub/_tensorboard_logger.py000066400000000000000000000202461500667546600252120ustar00rootroot00000000000000# Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains a logger to push training logs to the Hub, using Tensorboard.""" from pathlib import Path from typing import TYPE_CHECKING, List, Optional, Union from ._commit_scheduler import CommitScheduler from .errors import EntryNotFoundError from .repocard import ModelCard from .utils import experimental # Depending on user's setup, SummaryWriter can come either from 'tensorboardX' # or from 'torch.utils.tensorboard'. Both are compatible so let's try to load # from either of them. try: from tensorboardX import SummaryWriter is_summary_writer_available = True except ImportError: try: from torch.utils.tensorboard import SummaryWriter is_summary_writer_available = False except ImportError: # Dummy class to avoid failing at import. Will raise on instance creation. SummaryWriter = object is_summary_writer_available = False if TYPE_CHECKING: from tensorboardX import SummaryWriter class HFSummaryWriter(SummaryWriter): """ Wrapper around the tensorboard's `SummaryWriter` to push training logs to the Hub. Data is logged locally and then pushed to the Hub asynchronously. Pushing data to the Hub is done in a separate thread to avoid blocking the training script. In particular, if the upload fails for any reason (e.g. a connection issue), the main script will not be interrupted. Data is automatically pushed to the Hub every `commit_every` minutes (default to every 5 minutes). `HFSummaryWriter` is experimental. Its API is subject to change in the future without prior notice. 
Args: repo_id (`str`): The id of the repo to which the logs will be pushed. logdir (`str`, *optional*): The directory where the logs will be written. If not specified, a local directory will be created by the underlying `SummaryWriter` object. commit_every (`int` or `float`, *optional*): The frequency (in minutes) at which the logs will be pushed to the Hub. Defaults to 5 minutes. squash_history (`bool`, *optional*): Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is useful to avoid degraded performances on the repo when it grows too large. repo_type (`str`, *optional*): The type of the repo to which the logs will be pushed. Defaults to "model". repo_revision (`str`, *optional*): The revision of the repo to which the logs will be pushed. Defaults to "main". repo_private (`bool`, *optional*): Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. path_in_repo (`str`, *optional*): The path to the folder in the repo where the logs will be pushed. Defaults to "tensorboard/". repo_allow_patterns (`List[str]` or `str`, *optional*): A list of patterns to include in the upload. Defaults to `"*.tfevents.*"`. Check out the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details. repo_ignore_patterns (`List[str]` or `str`, *optional*): A list of patterns to exclude in the upload. Check out the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details. token (`str`, *optional*): Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details kwargs: Additional keyword arguments passed to `SummaryWriter`. Examples: ```diff # Taken from https://pytorch.org/docs/stable/tensorboard.html - from torch.utils.tensorboard import SummaryWriter + from huggingface_hub import HFSummaryWriter import numpy as np - writer = SummaryWriter() + writer = HFSummaryWriter(repo_id="username/my-trained-model") for n_iter in range(100): writer.add_scalar('Loss/train', np.random.random(), n_iter) writer.add_scalar('Loss/test', np.random.random(), n_iter) writer.add_scalar('Accuracy/train', np.random.random(), n_iter) writer.add_scalar('Accuracy/test', np.random.random(), n_iter) ``` ```py >>> from huggingface_hub import HFSummaryWriter # Logs are automatically pushed every 15 minutes (5 by default) + when exiting the context manager >>> with HFSummaryWriter(repo_id="test_hf_logger", commit_every=15) as logger: ... logger.add_scalar("a", 1) ... logger.add_scalar("b", 2) ``` """ @experimental def __new__(cls, *args, **kwargs) -> "HFSummaryWriter": if not is_summary_writer_available: raise ImportError( "You must have `tensorboard` installed to use `HFSummaryWriter`. Please run `pip install --upgrade" " tensorboardX` first." 
) return super().__new__(cls) def __init__( self, repo_id: str, *, logdir: Optional[str] = None, commit_every: Union[int, float] = 5, squash_history: bool = False, repo_type: Optional[str] = None, repo_revision: Optional[str] = None, repo_private: Optional[bool] = None, path_in_repo: Optional[str] = "tensorboard", repo_allow_patterns: Optional[Union[List[str], str]] = "*.tfevents.*", repo_ignore_patterns: Optional[Union[List[str], str]] = None, token: Optional[str] = None, **kwargs, ): # Initialize SummaryWriter super().__init__(logdir=logdir, **kwargs) # Check logdir has been correctly initialized and fail early otherwise. In practice, SummaryWriter takes care of it. if not isinstance(self.logdir, str): raise ValueError(f"`self.logdir` must be a string. Got '{self.logdir}' of type {type(self.logdir)}.") # Append logdir name to `path_in_repo` if path_in_repo is None or path_in_repo == "": path_in_repo = Path(self.logdir).name else: path_in_repo = path_in_repo.strip("/") + "/" + Path(self.logdir).name # Initialize scheduler self.scheduler = CommitScheduler( folder_path=self.logdir, path_in_repo=path_in_repo, repo_id=repo_id, repo_type=repo_type, revision=repo_revision, private=repo_private, token=token, allow_patterns=repo_allow_patterns, ignore_patterns=repo_ignore_patterns, every=commit_every, squash_history=squash_history, ) # Exposing some high-level info at root level self.repo_id = self.scheduler.repo_id self.repo_type = self.scheduler.repo_type self.repo_revision = self.scheduler.revision # Add `hf-summary-writer` tag to the model card metadata try: card = ModelCard.load(repo_id_or_path=self.repo_id, repo_type=self.repo_type) except EntryNotFoundError: card = ModelCard("") tags = card.data.get("tags", []) if "hf-summary-writer" not in tags: tags.append("hf-summary-writer") card.data["tags"] = tags card.push_to_hub(repo_id=self.repo_id, repo_type=self.repo_type) def __exit__(self, exc_type, exc_val, exc_tb): """Push to hub in a non-blocking way when exiting the logger's context manager.""" super().__exit__(exc_type, exc_val, exc_tb) future = self.scheduler.trigger() future.result() huggingface_hub-0.31.1/src/huggingface_hub/_upload_large_folder.py000066400000000000000000000601551500667546600253250ustar00rootroot00000000000000# coding=utf-8 # Copyright 2024-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import enum import logging import os import queue import shutil import sys import threading import time import traceback from datetime import datetime from pathlib import Path from threading import Lock from typing import TYPE_CHECKING, List, Optional, Tuple, Union from urllib.parse import quote from . 
import constants from ._commit_api import CommitOperationAdd, UploadInfo, _fetch_upload_modes from ._local_folder import LocalUploadFileMetadata, LocalUploadFilePaths, get_local_upload_paths, read_upload_metadata from .constants import DEFAULT_REVISION, REPO_TYPES from .utils import DEFAULT_IGNORE_PATTERNS, filter_repo_objects, tqdm from .utils._cache_manager import _format_size from .utils.sha import sha_fileobj if TYPE_CHECKING: from .hf_api import HfApi logger = logging.getLogger(__name__) WAITING_TIME_IF_NO_TASKS = 10 # seconds MAX_NB_REGULAR_FILES_PER_COMMIT = 75 MAX_NB_LFS_FILES_PER_COMMIT = 150 COMMIT_SIZE_SCALE: List[int] = [20, 50, 75, 100, 125, 200, 250, 400, 600, 1000] def upload_large_folder_internal( api: "HfApi", repo_id: str, folder_path: Union[str, Path], *, repo_type: str, # Repo type is required! revision: Optional[str] = None, private: Optional[bool] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, num_workers: Optional[int] = None, print_report: bool = True, print_report_every: int = 60, ): """Upload a large folder to the Hub in the most resilient way possible. See [`HfApi.upload_large_folder`] for the full documentation. """ # 1. Check args and setup if repo_type is None: raise ValueError( "For large uploads, `repo_type` is explicitly required. Please set it to `model`, `dataset` or `space`." " If you are using the CLI, pass it as `--repo-type=model`." ) if repo_type not in REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") if revision is None: revision = DEFAULT_REVISION folder_path = Path(folder_path).expanduser().resolve() if not folder_path.is_dir(): raise ValueError(f"Provided path: '{folder_path}' is not a directory") if ignore_patterns is None: ignore_patterns = [] elif isinstance(ignore_patterns, str): ignore_patterns = [ignore_patterns] ignore_patterns += DEFAULT_IGNORE_PATTERNS if num_workers is None: nb_cores = os.cpu_count() or 1 num_workers = max(nb_cores - 2, 2) # Use all but 2 cores, or at least 2 cores # 2. Create repo if missing repo_url = api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private, exist_ok=True) logger.info(f"Repo created: {repo_url}") repo_id = repo_url.repo_id # 3. List files to upload filtered_paths_list = filter_repo_objects( (path.relative_to(folder_path).as_posix() for path in folder_path.glob("**/*") if path.is_file()), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, ) paths_list = [get_local_upload_paths(folder_path, relpath) for relpath in filtered_paths_list] logger.info(f"Found {len(paths_list)} candidate files to upload") # Read metadata for each file items = [ (paths, read_upload_metadata(folder_path, paths.path_in_repo)) for paths in tqdm(paths_list, desc="Recovering from metadata files") ] # 4. Start workers status = LargeUploadStatus(items) threads = [ threading.Thread( target=_worker_job, kwargs={ "status": status, "api": api, "repo_id": repo_id, "repo_type": repo_type, "revision": revision, }, ) for _ in range(num_workers) ] for thread in threads: thread.start() # 5. 
Print regular reports if print_report: print("\n\n" + status.current_report()) last_report_ts = time.time() while True: time.sleep(1) if time.time() - last_report_ts >= print_report_every: if print_report: _print_overwrite(status.current_report()) last_report_ts = time.time() if status.is_done(): logging.info("Is done: exiting main loop") break for thread in threads: thread.join() logger.info(status.current_report()) logging.info("Upload is complete!") #################### # Logic to manage workers and synchronize tasks #################### class WorkerJob(enum.Enum): SHA256 = enum.auto() GET_UPLOAD_MODE = enum.auto() PREUPLOAD_LFS = enum.auto() COMMIT = enum.auto() WAIT = enum.auto() # if no tasks are available but we don't want to exit JOB_ITEM_T = Tuple[LocalUploadFilePaths, LocalUploadFileMetadata] class LargeUploadStatus: """Contains information, queues and tasks for a large upload process.""" def __init__(self, items: List[JOB_ITEM_T]): self.items = items self.queue_sha256: "queue.Queue[JOB_ITEM_T]" = queue.Queue() self.queue_get_upload_mode: "queue.Queue[JOB_ITEM_T]" = queue.Queue() self.queue_preupload_lfs: "queue.Queue[JOB_ITEM_T]" = queue.Queue() self.queue_commit: "queue.Queue[JOB_ITEM_T]" = queue.Queue() self.lock = Lock() self.nb_workers_sha256: int = 0 self.nb_workers_get_upload_mode: int = 0 self.nb_workers_preupload_lfs: int = 0 self.nb_workers_commit: int = 0 self.nb_workers_waiting: int = 0 self.last_commit_attempt: Optional[float] = None self._started_at = datetime.now() self._chunk_idx: int = 1 self._chunk_lock: Lock = Lock() # Setup queues for item in self.items: paths, metadata = item if metadata.sha256 is None: self.queue_sha256.put(item) elif metadata.upload_mode is None: self.queue_get_upload_mode.put(item) elif metadata.upload_mode == "lfs" and not metadata.is_uploaded: self.queue_preupload_lfs.put(item) elif not metadata.is_committed: self.queue_commit.put(item) else: logger.debug(f"Skipping file {paths.path_in_repo} (already uploaded and committed)") def target_chunk(self) -> int: with self._chunk_lock: return COMMIT_SIZE_SCALE[self._chunk_idx] def update_chunk(self, success: bool, nb_items: int, duration: float) -> None: with self._chunk_lock: if not success: logger.warning(f"Failed to commit {nb_items} files at once. Will retry with less files in next batch.") self._chunk_idx -= 1 elif nb_items >= COMMIT_SIZE_SCALE[self._chunk_idx] and duration < 40: logger.info(f"Successfully committed {nb_items} at once. 
Increasing the limit for next batch.") self._chunk_idx += 1 self._chunk_idx = max(0, min(self._chunk_idx, len(COMMIT_SIZE_SCALE) - 1)) def current_report(self) -> str: """Generate a report of the current status of the large upload.""" nb_hashed = 0 size_hashed = 0 nb_preuploaded = 0 nb_lfs = 0 nb_lfs_unsure = 0 size_preuploaded = 0 nb_committed = 0 size_committed = 0 total_size = 0 ignored_files = 0 total_files = 0 with self.lock: for _, metadata in self.items: if metadata.should_ignore: ignored_files += 1 continue total_size += metadata.size total_files += 1 if metadata.sha256 is not None: nb_hashed += 1 size_hashed += metadata.size if metadata.upload_mode == "lfs": nb_lfs += 1 if metadata.upload_mode is None: nb_lfs_unsure += 1 if metadata.is_uploaded: nb_preuploaded += 1 size_preuploaded += metadata.size if metadata.is_committed: nb_committed += 1 size_committed += metadata.size total_size_str = _format_size(total_size) now = datetime.now() now_str = now.strftime("%Y-%m-%d %H:%M:%S") elapsed = now - self._started_at elapsed_str = str(elapsed).split(".")[0] # remove milliseconds message = "\n" + "-" * 10 message += f" {now_str} ({elapsed_str}) " message += "-" * 10 + "\n" message += "Files: " message += f"hashed {nb_hashed}/{total_files} ({_format_size(size_hashed)}/{total_size_str}) | " message += f"pre-uploaded: {nb_preuploaded}/{nb_lfs} ({_format_size(size_preuploaded)}/{total_size_str})" if nb_lfs_unsure > 0: message += f" (+{nb_lfs_unsure} unsure)" message += f" | committed: {nb_committed}/{total_files} ({_format_size(size_committed)}/{total_size_str})" message += f" | ignored: {ignored_files}\n" message += "Workers: " message += f"hashing: {self.nb_workers_sha256} | " message += f"get upload mode: {self.nb_workers_get_upload_mode} | " message += f"pre-uploading: {self.nb_workers_preupload_lfs} | " message += f"committing: {self.nb_workers_commit} | " message += f"waiting: {self.nb_workers_waiting}\n" message += "-" * 51 return message def is_done(self) -> bool: with self.lock: return all(metadata.is_committed or metadata.should_ignore for _, metadata in self.items) def _worker_job( status: LargeUploadStatus, api: "HfApi", repo_id: str, repo_type: str, revision: str, ): """ Main process for a worker. The worker will perform tasks based on the priority list until all files are uploaded and committed. If no tasks are available, the worker will wait for 10 seconds before checking again. If a task fails for any reason, the item(s) are put back in the queue for another worker to pick up. Read `upload_large_folder` docstring for more information on how tasks are prioritized. 
""" while True: next_job: Optional[Tuple[WorkerJob, List[JOB_ITEM_T]]] = None # Determine next task next_job = _determine_next_job(status) if next_job is None: return job, items = next_job # Perform task if job == WorkerJob.SHA256: item = items[0] # single item try: _compute_sha256(item) status.queue_get_upload_mode.put(item) except KeyboardInterrupt: raise except Exception as e: logger.error(f"Failed to compute sha256: {e}") traceback.format_exc() status.queue_sha256.put(item) with status.lock: status.nb_workers_sha256 -= 1 elif job == WorkerJob.GET_UPLOAD_MODE: try: _get_upload_mode(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision) except KeyboardInterrupt: raise except Exception as e: logger.error(f"Failed to get upload mode: {e}") traceback.format_exc() # Items are either: # - dropped (if should_ignore) # - put in LFS queue (if LFS) # - put in commit queue (if regular) # - or put back (if error occurred). for item in items: _, metadata = item if metadata.should_ignore: continue if metadata.upload_mode == "lfs": status.queue_preupload_lfs.put(item) elif metadata.upload_mode == "regular": status.queue_commit.put(item) else: status.queue_get_upload_mode.put(item) with status.lock: status.nb_workers_get_upload_mode -= 1 elif job == WorkerJob.PREUPLOAD_LFS: item = items[0] # single item try: _preupload_lfs(item, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision) status.queue_commit.put(item) except KeyboardInterrupt: raise except Exception as e: logger.error(f"Failed to preupload LFS: {e}") traceback.format_exc() status.queue_preupload_lfs.put(item) with status.lock: status.nb_workers_preupload_lfs -= 1 elif job == WorkerJob.COMMIT: start_ts = time.time() success = True try: _commit(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision) except KeyboardInterrupt: raise except Exception as e: logger.error(f"Failed to commit: {e}") traceback.format_exc() for item in items: status.queue_commit.put(item) success = False duration = time.time() - start_ts status.update_chunk(success, len(items), duration) with status.lock: status.last_commit_attempt = time.time() status.nb_workers_commit -= 1 elif job == WorkerJob.WAIT: time.sleep(WAITING_TIME_IF_NO_TASKS) with status.lock: status.nb_workers_waiting -= 1 def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob, List[JOB_ITEM_T]]]: with status.lock: # 1. Commit if more than 5 minutes since last commit attempt (and at least 1 file) if ( status.nb_workers_commit == 0 and status.queue_commit.qsize() > 0 and status.last_commit_attempt is not None and time.time() - status.last_commit_attempt > 5 * 60 ): status.nb_workers_commit += 1 logger.debug("Job: commit (more than 5 minutes since last commit attempt)") return (WorkerJob.COMMIT, _get_items_to_commit(status.queue_commit)) # 2. Commit if at least 100 files are ready to commit elif status.nb_workers_commit == 0 and status.queue_commit.qsize() >= 150: status.nb_workers_commit += 1 logger.debug("Job: commit (>100 files ready)") return (WorkerJob.COMMIT, _get_items_to_commit(status.queue_commit)) # 3. Get upload mode if at least 10 files elif status.queue_get_upload_mode.qsize() >= 10: status.nb_workers_get_upload_mode += 1 logger.debug("Job: get upload mode (>10 files ready)") return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, status.target_chunk())) # 4. 
Preupload LFS file if at least 1 file and no worker is preuploading LFS elif status.queue_preupload_lfs.qsize() > 0 and status.nb_workers_preupload_lfs == 0: status.nb_workers_preupload_lfs += 1 logger.debug("Job: preupload LFS (no other worker preuploading LFS)") return (WorkerJob.PREUPLOAD_LFS, _get_one(status.queue_preupload_lfs)) # 5. Compute sha256 if at least 1 file and no worker is computing sha256 elif status.queue_sha256.qsize() > 0 and status.nb_workers_sha256 == 0: status.nb_workers_sha256 += 1 logger.debug("Job: sha256 (no other worker computing sha256)") return (WorkerJob.SHA256, _get_one(status.queue_sha256)) # 6. Get upload mode if at least 1 file and no worker is getting upload mode elif status.queue_get_upload_mode.qsize() > 0 and status.nb_workers_get_upload_mode == 0: status.nb_workers_get_upload_mode += 1 logger.debug("Job: get upload mode (no other worker getting upload mode)") return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, status.target_chunk())) # 7. Preupload LFS file if at least 1 file # Skip if hf_transfer is enabled and there is already a worker preuploading LFS elif status.queue_preupload_lfs.qsize() > 0 and ( status.nb_workers_preupload_lfs == 0 or not constants.HF_HUB_ENABLE_HF_TRANSFER ): status.nb_workers_preupload_lfs += 1 logger.debug("Job: preupload LFS") return (WorkerJob.PREUPLOAD_LFS, _get_one(status.queue_preupload_lfs)) # 8. Compute sha256 if at least 1 file elif status.queue_sha256.qsize() > 0: status.nb_workers_sha256 += 1 logger.debug("Job: sha256") return (WorkerJob.SHA256, _get_one(status.queue_sha256)) # 9. Get upload mode if at least 1 file elif status.queue_get_upload_mode.qsize() > 0: status.nb_workers_get_upload_mode += 1 logger.debug("Job: get upload mode") return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, status.target_chunk())) # 10. Commit if at least 1 file and 1 min since last commit attempt elif ( status.nb_workers_commit == 0 and status.queue_commit.qsize() > 0 and status.last_commit_attempt is not None and time.time() - status.last_commit_attempt > 1 * 60 ): status.nb_workers_commit += 1 logger.debug("Job: commit (1 min since last commit attempt)") return (WorkerJob.COMMIT, _get_items_to_commit(status.queue_commit)) # 11. Commit if at least 1 file all other queues are empty and all workers are waiting # e.g. when it's the last commit elif ( status.nb_workers_commit == 0 and status.queue_commit.qsize() > 0 and status.queue_sha256.qsize() == 0 and status.queue_get_upload_mode.qsize() == 0 and status.queue_preupload_lfs.qsize() == 0 and status.nb_workers_sha256 == 0 and status.nb_workers_get_upload_mode == 0 and status.nb_workers_preupload_lfs == 0 ): status.nb_workers_commit += 1 logger.debug("Job: commit") return (WorkerJob.COMMIT, _get_items_to_commit(status.queue_commit)) # 12. If all queues are empty, exit elif all(metadata.is_committed or metadata.should_ignore for _, metadata in status.items): logger.info("All files have been processed! Exiting worker.") return None # 13. If no task is available, wait else: status.nb_workers_waiting += 1 logger.debug(f"No task available, waiting... 
({WAITING_TIME_IF_NO_TASKS}s)") return (WorkerJob.WAIT, []) #################### # Atomic jobs (sha256, get_upload_mode, preupload_lfs, commit) #################### def _compute_sha256(item: JOB_ITEM_T) -> None: """Compute sha256 of a file and save it in metadata.""" paths, metadata = item if metadata.sha256 is None: with paths.file_path.open("rb") as f: metadata.sha256 = sha_fileobj(f).hex() metadata.save(paths) def _get_upload_mode(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: """Get upload mode for each file and update metadata. Also receive info if the file should be ignored. """ additions = [_build_hacky_operation(item) for item in items] _fetch_upload_modes( additions=additions, repo_type=repo_type, repo_id=repo_id, headers=api._build_hf_headers(), revision=quote(revision, safe=""), ) for item, addition in zip(items, additions): paths, metadata = item metadata.upload_mode = addition._upload_mode metadata.should_ignore = addition._should_ignore metadata.save(paths) def _preupload_lfs(item: JOB_ITEM_T, api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: """Preupload LFS file and update metadata.""" paths, metadata = item addition = _build_hacky_operation(item) api.preupload_lfs_files( repo_id=repo_id, repo_type=repo_type, revision=revision, additions=[addition], ) metadata.is_uploaded = True metadata.save(paths) def _commit(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: """Commit files to the repo.""" additions = [_build_hacky_operation(item) for item in items] api.create_commit( repo_id=repo_id, repo_type=repo_type, revision=revision, operations=additions, commit_message="Add files using upload-large-folder tool", ) for paths, metadata in items: metadata.is_committed = True metadata.save(paths) #################### # Hacks with CommitOperationAdd to bypass checks/sha256 calculation #################### class HackyCommitOperationAdd(CommitOperationAdd): def __post_init__(self) -> None: if isinstance(self.path_or_fileobj, Path): self.path_or_fileobj = str(self.path_or_fileobj) def _build_hacky_operation(item: JOB_ITEM_T) -> HackyCommitOperationAdd: paths, metadata = item operation = HackyCommitOperationAdd(path_in_repo=paths.path_in_repo, path_or_fileobj=paths.file_path) with paths.file_path.open("rb") as file: sample = file.peek(512)[:512] if metadata.sha256 is None: raise ValueError("sha256 must have been computed by now!") operation.upload_info = UploadInfo(sha256=bytes.fromhex(metadata.sha256), size=metadata.size, sample=sample) return operation #################### # Misc helpers #################### def _get_one(queue: "queue.Queue[JOB_ITEM_T]") -> List[JOB_ITEM_T]: return [queue.get()] def _get_n(queue: "queue.Queue[JOB_ITEM_T]", n: int) -> List[JOB_ITEM_T]: return [queue.get() for _ in range(min(queue.qsize(), n))] def _get_items_to_commit(queue: "queue.Queue[JOB_ITEM_T]") -> List[JOB_ITEM_T]: """Special case for commit job: the number of items to commit depends on the type of files.""" # Can take at most 50 regular files and/or 100 LFS files in a single commit items: List[JOB_ITEM_T] = [] nb_lfs, nb_regular = 0, 0 while True: # If empty queue => commit everything if queue.qsize() == 0: return items # If we have enough items => commit them if nb_lfs >= MAX_NB_LFS_FILES_PER_COMMIT or nb_regular >= MAX_NB_REGULAR_FILES_PER_COMMIT: return items # Else, get a new item and increase counter item = queue.get() items.append(item) _, metadata = item if metadata.upload_mode == 
"lfs": nb_lfs += 1 else: nb_regular += 1 def _print_overwrite(report: str) -> None: """Print a report, overwriting the previous lines. Since tqdm in using `sys.stderr` to (re-)write progress bars, we need to use `sys.stdout` to print the report. Note: works well only if no other process is writing to `sys.stdout`! """ report += "\n" # Get terminal width terminal_width = shutil.get_terminal_size().columns # Count number of lines that should be cleared nb_lines = sum(len(line) // terminal_width + 1 for line in report.splitlines()) # Clear previous lines based on the number of lines in the report for _ in range(nb_lines): sys.stdout.write("\r\033[K") # Clear line sys.stdout.write("\033[F") # Move cursor up one line # Print the new report, filling remaining space with whitespace sys.stdout.write(report) sys.stdout.write(" " * (terminal_width - len(report.splitlines()[-1]))) sys.stdout.flush() huggingface_hub-0.31.1/src/huggingface_hub/_webhooks_payload.py000066400000000000000000000070411500667546600246610ustar00rootroot00000000000000# coding=utf-8 # Copyright 2023-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains data structures to parse the webhooks payload.""" from typing import List, Literal, Optional from .utils import is_pydantic_available if is_pydantic_available(): from pydantic import BaseModel else: # Define a dummy BaseModel to avoid import errors when pydantic is not installed # Import error will be raised when trying to use the class class BaseModel: # type: ignore [no-redef] def __init__(self, *args, **kwargs) -> None: raise ImportError( "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that" " should be installed separately. Please run `pip install --upgrade pydantic` and retry." ) # This is an adaptation of the ReportV3 interface implemented in moon-landing. V0, V1 and V2 have been ignored as they # are not in used anymore. To keep in sync when format is updated in # https://github.com/huggingface/moon-landing/blob/main/server/lib/HFWebhooks.ts (internal link). 
WebhookEvent_T = Literal[ "create", "delete", "move", "update", ] RepoChangeEvent_T = Literal[ "add", "move", "remove", "update", ] RepoType_T = Literal[ "dataset", "model", "space", ] DiscussionStatus_T = Literal[ "closed", "draft", "open", "merged", ] SupportedWebhookVersion = Literal[3] class ObjectId(BaseModel): id: str class WebhookPayloadUrl(BaseModel): web: str api: Optional[str] = None class WebhookPayloadMovedTo(BaseModel): name: str owner: ObjectId class WebhookPayloadWebhook(ObjectId): version: SupportedWebhookVersion class WebhookPayloadEvent(BaseModel): action: WebhookEvent_T scope: str class WebhookPayloadDiscussionChanges(BaseModel): base: str mergeCommitId: Optional[str] = None class WebhookPayloadComment(ObjectId): author: ObjectId hidden: bool content: Optional[str] = None url: WebhookPayloadUrl class WebhookPayloadDiscussion(ObjectId): num: int author: ObjectId url: WebhookPayloadUrl title: str isPullRequest: bool status: DiscussionStatus_T changes: Optional[WebhookPayloadDiscussionChanges] = None pinned: Optional[bool] = None class WebhookPayloadRepo(ObjectId): owner: ObjectId head_sha: Optional[str] = None name: str private: bool subdomain: Optional[str] = None tags: Optional[List[str]] = None type: Literal["dataset", "model", "space"] url: WebhookPayloadUrl class WebhookPayloadUpdatedRef(BaseModel): ref: str oldSha: Optional[str] = None newSha: Optional[str] = None class WebhookPayload(BaseModel): event: WebhookPayloadEvent repo: WebhookPayloadRepo discussion: Optional[WebhookPayloadDiscussion] = None comment: Optional[WebhookPayloadComment] = None webhook: WebhookPayloadWebhook movedTo: Optional[WebhookPayloadMovedTo] = None updatedRefs: Optional[List[WebhookPayloadUpdatedRef]] = None huggingface_hub-0.31.1/src/huggingface_hub/_webhooks_server.py000066400000000000000000000366271500667546600245520ustar00rootroot00000000000000# coding=utf-8 # Copyright 2023-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily.""" import atexit import inspect import os from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Dict, Optional from .utils import experimental, is_fastapi_available, is_gradio_available if TYPE_CHECKING: import gradio as gr from fastapi import Request if is_fastapi_available(): from fastapi import FastAPI, Request from fastapi.responses import JSONResponse else: # Will fail at runtime if FastAPI is not available FastAPI = Request = JSONResponse = None # type: ignore [misc, assignment] _global_app: Optional["WebhooksServer"] = None _is_local = os.environ.get("SPACE_ID") is None @experimental class WebhooksServer: """ The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Huggingface webhooks. These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to the app as a POST endpoint to the FastAPI router. 
Once all the webhooks are registered, the `launch` method has to be called to start the app. It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic model that contains all the information about the webhook event. The data will be parsed automatically for you. Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your WebhooksServer and deploy it on a Space. `WebhooksServer` is experimental. Its API is subject to change in the future. You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`). Args: ui (`gradio.Blocks`, optional): A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions about the configured webhooks is created. webhook_secret (`str`, optional): A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You can also set this value as the `WEBHOOK_SECRET` environment variable. If no secret is provided, the webhook endpoints are opened without any security. Example: ```python import gradio as gr from huggingface_hub import WebhooksServer, WebhookPayload with gr.Blocks() as ui: ... app = WebhooksServer(ui=ui, webhook_secret="my_secret_key") @app.add_webhook("/say_hello") async def hello(payload: WebhookPayload): return {"message": "hello"} app.launch() ``` """ def __new__(cls, *args, **kwargs) -> "WebhooksServer": if not is_gradio_available(): raise ImportError( "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio`" " first." ) if not is_fastapi_available(): raise ImportError( "You must have `fastapi` installed to use `WebhooksServer`. Please run `pip install --upgrade fastapi`" " first." ) return super().__new__(cls) def __init__( self, ui: Optional["gr.Blocks"] = None, webhook_secret: Optional[str] = None, ) -> None: self._ui = ui self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET") self.registered_webhooks: Dict[str, Callable] = {} _warn_on_empty_secret(self.webhook_secret) def add_webhook(self, path: Optional[str] = None) -> Callable: """ Decorator to add a webhook to the [`WebhooksServer`] server. Args: path (`str`, optional): The URL path to register the webhook function. If not provided, the function name will be used as the path. In any case, all webhooks are registered under `/webhooks`. Raises: ValueError: If the provided path is already registered as a webhook. Example: ```python from huggingface_hub import WebhooksServer, WebhookPayload app = WebhooksServer() @app.add_webhook async def trigger_training(payload: WebhookPayload): if payload.repo.type == "dataset" and payload.event.action == "update": # Trigger a training job if a dataset is updated ... app.launch() ``` """ # Usage: directly as decorator. Example: `@app.add_webhook` if callable(path): # If path is a function, it means it was used as a decorator without arguments return self.add_webhook()(path) # Usage: provide a path. 
Example: `@app.add_webhook(...)` @wraps(FastAPI.post) def _inner_post(*args, **kwargs): func = args[0] abs_path = f"/webhooks/{(path or func.__name__).strip('/')}" if abs_path in self.registered_webhooks: raise ValueError(f"Webhook {abs_path} already exists.") self.registered_webhooks[abs_path] = func return _inner_post def launch(self, prevent_thread_lock: bool = False, **launch_kwargs: Any) -> None: """Launch the Gradio app and register webhooks to the underlying FastAPI server. Input parameters are forwarded to Gradio when launching the app. """ ui = self._ui or self._get_default_ui() # Start Gradio App # - as non-blocking so that webhooks can be added afterwards # - as shared if launch locally (to debug webhooks) launch_kwargs.setdefault("share", _is_local) self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, **launch_kwargs) # Register webhooks to FastAPI app for path, func in self.registered_webhooks.items(): # Add secret check if required if self.webhook_secret is not None: func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret) # Add route to FastAPI app self.fastapi_app.post(path)(func) # Print instructions and block main thread space_host = os.environ.get("SPACE_HOST") url = "https://" + space_host if space_host is not None else (ui.share_url or ui.local_url) if url is None: raise ValueError("Cannot find the URL of the app. Please provide a valid `ui` or update `gradio` version.") url = url.strip("/") message = "\nWebhooks are correctly setup and ready to use:" message += "\n" + "\n".join(f" - POST {url}{webhook}" for webhook in self.registered_webhooks) message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks." print(message) if not prevent_thread_lock: ui.block_thread() def _get_default_ui(self) -> "gr.Blocks": """Default UI if not provided (lists webhooks and provides basic instructions).""" import gradio as gr with gr.Blocks() as ui: gr.Markdown("# This is an app to process 🤗 Webhooks") gr.Markdown( "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on" " specific repos or to all repos belonging to particular set of users/organizations (not just your" " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to" " know more about webhooks on the Huggingface Hub." ) gr.Markdown( f"{len(self.registered_webhooks)} webhook(s) are registered:" + "\n\n" + "\n ".join( f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})" for webhook_path, webhook in self.registered_webhooks.items() ) ) gr.Markdown( "Go to https://huggingface.co/settings/webhooks to setup your webhooks." + "\nYou app is running locally. Please look at the logs to check the full URL you need to set." if _is_local else ( "\nThis app is running on a Space. You can find the corresponding URL in the options menu" " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'." ) ) return ui @experimental def webhook_endpoint(path: Optional[str] = None) -> Callable: """Decorator to start a [`WebhooksServer`] and register the decorated function as a webhook endpoint. This is a helper to get started quickly. If you need more flexibility (custom landing page or webhook secret), you can use [`WebhooksServer`] directly. You can register multiple webhook endpoints (to the same server) by using this decorator multiple times. 
Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your server and deploy it on a Space. `webhook_endpoint` is experimental. Its API is subject to change in the future. You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`). Args: path (`str`, optional): The URL path to register the webhook function. If not provided, the function name will be used as the path. In any case, all webhooks are registered under `/webhooks`. Examples: The default usage is to register a function as a webhook endpoint. The function name will be used as the path. The server will be started automatically at exit (i.e. at the end of the script). ```python from huggingface_hub import webhook_endpoint, WebhookPayload @webhook_endpoint async def trigger_training(payload: WebhookPayload): if payload.repo.type == "dataset" and payload.event.action == "update": # Trigger a training job if a dataset is updated ... # Server is automatically started at the end of the script. ``` Advanced usage: register a function as a webhook endpoint and start the server manually. This is useful if you are running it in a notebook. ```python from huggingface_hub import webhook_endpoint, WebhookPayload @webhook_endpoint async def trigger_training(payload: WebhookPayload): if payload.repo.type == "dataset" and payload.event.action == "update": # Trigger a training job if a dataset is updated ... # Start the server manually trigger_training.launch() ``` """ if callable(path): # If path is a function, it means it was used as a decorator without arguments return webhook_endpoint()(path) @wraps(WebhooksServer.add_webhook) def _inner(func: Callable) -> Callable: app = _get_global_app() app.add_webhook(path)(func) if len(app.registered_webhooks) == 1: # Register `app.launch` to run at exit (only once) atexit.register(app.launch) @wraps(app.launch) def _launch_now(): # Run the app directly (without waiting atexit) atexit.unregister(app.launch) app.launch() func.launch = _launch_now # type: ignore return func return _inner def _get_global_app() -> WebhooksServer: global _global_app if _global_app is None: _global_app = WebhooksServer() return _global_app def _warn_on_empty_secret(webhook_secret: Optional[str]) -> None: if webhook_secret is None: print("Webhook secret is not defined. This means your webhook endpoints will be open to everyone.") print( "To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: " "\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`" ) print( "For more details about webhook secrets, please refer to" " https://huggingface.co/docs/hub/webhooks#webhook-secret." ) else: print("Webhook secret is correctly defined.") def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str: """Returns the anchor to a given webhook in the docs (experimental)""" return "/docs#/default/" + webhook_name + webhook_path.replace("/", "_") + "_post" def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable: """Wraps a webhook function to check the webhook secret before calling the function. This is a hacky way to add the `request` parameter to the function signature. Since FastAPI based itself on route parameters to inject the values to the function, we need to hack the function signature to retrieve the `Request` object (and hence the headers). A far cleaner solution would be to use a middleware. However, since `fastapi==0.90.1`, a middleware cannot be added once the app has started. 
And since the FastAPI app is started by Gradio internals (and not by us), we cannot add a middleware. This method is called only when a secret has been defined by the user. If a request is sent without the "x-webhook-secret", the function will return a 401 error (unauthorized). If the header is sent but is incorrect, the function will return a 403 error (forbidden). Inspired by https://stackoverflow.com/a/33112180. """ initial_sig = inspect.signature(func) @wraps(func) async def _protected_func(request: Request, **kwargs): request_secret = request.headers.get("x-webhook-secret") if request_secret is None: return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401) if request_secret != webhook_secret: return JSONResponse({"error": "Invalid webhook secret."}, status_code=403) # Inject `request` in kwargs if required if "request" in initial_sig.parameters: kwargs["request"] = request # Handle both sync and async routes if inspect.iscoroutinefunction(func): return await func(**kwargs) else: return func(**kwargs) # Update signature to include request if "request" not in initial_sig.parameters: _protected_func.__signature__ = initial_sig.replace( # type: ignore parameters=( inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request), ) + tuple(initial_sig.parameters.values()) ) # Return protected route return _protected_func huggingface_hub-0.31.1/src/huggingface_hub/commands/000077500000000000000000000000001500667546600224155ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/commands/__init__.py000066400000000000000000000016401500667546600245270ustar00rootroot00000000000000# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod from argparse import _SubParsersAction class BaseHuggingfaceCLICommand(ABC): @staticmethod @abstractmethod def register_subcommand(parser: _SubParsersAction): raise NotImplementedError() @abstractmethod def run(self): raise NotImplementedError() huggingface_hub-0.31.1/src/huggingface_hub/commands/_cli_utils.py000066400000000000000000000040571500667546600251230ustar00rootroot00000000000000# Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
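# Note (added illustration, not part of the original sources): the secret check implemented
# in `_wrap_webhook_to_check_secret` above expects callers to pass the secret in the
# "x-webhook-secret" header. A minimal manual test of a deployed, secret-protected endpoint
# could look like this -- the URL and secret below are placeholders:
#
#     import requests
#
#     response = requests.post(
#         "https://my-user-my-space.hf.space/webhooks/trigger_training",  # hypothetical URL
#         json={},  # webhook body
#         headers={"x-webhook-secret": "my_secret_key"},
#     )
#     # 401 is returned if the header is missing, 403 if it does not match the secret.
#     print(response.status_code)
#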
"""Contains a utility for good-looking prints.""" import os from typing import List, Union class ANSI: """ Helper for en.wikipedia.org/wiki/ANSI_escape_code """ _bold = "\u001b[1m" _gray = "\u001b[90m" _red = "\u001b[31m" _reset = "\u001b[0m" _yellow = "\u001b[33m" @classmethod def bold(cls, s: str) -> str: return cls._format(s, cls._bold) @classmethod def gray(cls, s: str) -> str: return cls._format(s, cls._gray) @classmethod def red(cls, s: str) -> str: return cls._format(s, cls._bold + cls._red) @classmethod def yellow(cls, s: str) -> str: return cls._format(s, cls._yellow) @classmethod def _format(cls, s: str, code: str) -> str: if os.environ.get("NO_COLOR"): # See https://no-color.org/ return s return f"{code}{s}{cls._reset}" def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: """ Inspired by: - stackoverflow.com/a/8356620/593036 - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data """ col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)] row_format = ("{{:{}}} " * len(headers)).format(*col_widths) lines = [] lines.append(row_format.format(*headers)) lines.append(row_format.format(*["-" * w for w in col_widths])) for row in rows: lines.append(row_format.format(*row)) return "\n".join(lines) huggingface_hub-0.31.1/src/huggingface_hub/commands/delete_cache.py000066400000000000000000000423271500667546600253640ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains command to delete some revisions from the HF cache directory. Usage: huggingface-cli delete-cache huggingface-cli delete-cache --disable-tui huggingface-cli delete-cache --dir ~/.cache/huggingface/hub huggingface-cli delete-cache --sort=size NOTE: This command is based on `InquirerPy` to build the multiselect menu in the terminal. This dependency has to be installed with `pip install huggingface_hub[cli]`. Since we want to avoid as much as possible cross-platform issues, I chose a library that is built on top of `python-prompt-toolkit` which seems to be a reference in terminal GUI (actively maintained on both Unix and Windows, 7.9k stars). For the moment, the TUI feature is in beta. See: - https://github.com/kazhala/InquirerPy - https://inquirerpy.readthedocs.io/en/latest/ - https://github.com/prompt-toolkit/python-prompt-toolkit Other solutions could have been: - `simple_term_menu`: would be good as well for our use case but some issues suggest that Windows is less supported. See: https://github.com/IngoMeyer441/simple-term-menu - `PyInquirer`: very similar to `InquirerPy` but older and not maintained anymore. In particular, no support of Python3.10. See: https://github.com/CITGuru/PyInquirer - `pick` (or `pickpack`): easy to use and flexible but built on top of Python's standard library `curses` that is specific to Unix (not implemented on Windows). See https://github.com/wong2/pick and https://github.com/anafvana/pickpack. 
- `inquirer`: lot of traction (700 stars) but explicitly states "experimental support of Windows". Not built on top of `python-prompt-toolkit`. See https://github.com/magmax/python-inquirer TODO: add support for `huggingface-cli delete-cache aaaaaa bbbbbb cccccc (...)` ? TODO: add "--keep-last" arg to delete revisions that are not on `main` ref TODO: add "--filter" arg to filter repositories by name ? TODO: add "--limit" arg to limit to X repos ? TODO: add "-y" arg for immediate deletion ? See discussions in https://github.com/huggingface/huggingface_hub/issues/1025. """ import os from argparse import Namespace, _SubParsersAction from functools import wraps from tempfile import mkstemp from typing import Any, Callable, Iterable, List, Literal, Optional, Union from ..utils import CachedRepoInfo, CachedRevisionInfo, HFCacheInfo, scan_cache_dir from . import BaseHuggingfaceCLICommand from ._cli_utils import ANSI try: from InquirerPy import inquirer from InquirerPy.base.control import Choice from InquirerPy.separator import Separator _inquirer_py_available = True except ImportError: _inquirer_py_available = False SortingOption_T = Literal["alphabetical", "lastUpdated", "lastUsed", "size"] def require_inquirer_py(fn: Callable) -> Callable: """Decorator to flag methods that require `InquirerPy`.""" # TODO: refactor this + imports in a unified pattern across codebase @wraps(fn) def _inner(*args, **kwargs): if not _inquirer_py_available: raise ImportError( "The `delete-cache` command requires extra dependencies to work with" " the TUI.\nPlease run `pip install huggingface_hub[cli]` to install" " them.\nOtherwise, disable TUI using the `--disable-tui` flag." ) return fn(*args, **kwargs) return _inner # Possibility for the user to cancel deletion _CANCEL_DELETION_STR = "CANCEL_DELETION" class DeleteCacheCommand(BaseHuggingfaceCLICommand): @staticmethod def register_subcommand(parser: _SubParsersAction): delete_cache_parser = parser.add_parser("delete-cache", help="Delete revisions from the cache directory.") delete_cache_parser.add_argument( "--dir", type=str, default=None, help="cache directory (optional). Default to the default HuggingFace cache.", ) delete_cache_parser.add_argument( "--disable-tui", action="store_true", help=( "Disable Terminal User Interface (TUI) mode. Useful if your" " platform/terminal doesn't support the multiselect menu." ), ) delete_cache_parser.add_argument( "--sort", nargs="?", choices=["alphabetical", "lastUpdated", "lastUsed", "size"], help=( "Sort repositories by the specified criteria. Options: " "'alphabetical' (A-Z), " "'lastUpdated' (newest first), " "'lastUsed' (most recent first), " "'size' (largest first)." ), ) delete_cache_parser.set_defaults(func=DeleteCacheCommand) def __init__(self, args: Namespace) -> None: self.cache_dir: Optional[str] = args.dir self.disable_tui: bool = args.disable_tui self.sort_by: Optional[SortingOption_T] = args.sort def run(self): """Run `delete-cache` command with or without TUI.""" # Scan cache directory hf_cache_info = scan_cache_dir(self.cache_dir) # Manual review from the user if self.disable_tui: selected_hashes = _manual_review_no_tui(hf_cache_info, preselected=[], sort_by=self.sort_by) else: selected_hashes = _manual_review_tui(hf_cache_info, preselected=[], sort_by=self.sort_by) # If deletion is not cancelled if len(selected_hashes) > 0 and _CANCEL_DELETION_STR not in selected_hashes: confirm_message = _get_expectations_str(hf_cache_info, selected_hashes) + " Confirm deletion ?" 
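            # e.g. "7 revisions selected counting for 4.3G. Confirm deletion ?"
            # (illustrative message; the size estimate comes from `hf_cache_info.delete_revisions`)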
# Confirm deletion if self.disable_tui: confirmed = _ask_for_confirmation_no_tui(confirm_message) else: confirmed = _ask_for_confirmation_tui(confirm_message) # Deletion is confirmed if confirmed: strategy = hf_cache_info.delete_revisions(*selected_hashes) print("Start deletion.") strategy.execute() print( f"Done. Deleted {len(strategy.repos)} repo(s) and" f" {len(strategy.snapshots)} revision(s) for a total of" f" {strategy.expected_freed_size_str}." ) return # Deletion is cancelled print("Deletion is cancelled. Do nothing.") def _get_repo_sorting_key(repo: CachedRepoInfo, sort_by: Optional[SortingOption_T] = None): if sort_by == "alphabetical": return (repo.repo_type, repo.repo_id.lower()) # by type then name elif sort_by == "lastUpdated": return -max(rev.last_modified for rev in repo.revisions) # newest first elif sort_by == "lastUsed": return -repo.last_accessed # most recently used first elif sort_by == "size": return -repo.size_on_disk # largest first else: return (repo.repo_type, repo.repo_id) # default stable order @require_inquirer_py def _manual_review_tui( hf_cache_info: HFCacheInfo, preselected: List[str], sort_by: Optional[SortingOption_T] = None, ) -> List[str]: """Ask the user for a manual review of the revisions to delete. Displays a multi-select menu in the terminal (TUI). """ # Define multiselect list choices = _get_tui_choices_from_scan( repos=hf_cache_info.repos, preselected=preselected, sort_by=sort_by, ) checkbox = inquirer.checkbox( message="Select revisions to delete:", choices=choices, # List of revisions with some pre-selection cycle=False, # No loop between top and bottom height=100, # Large list if possible # We use the instruction to display to the user the expected effect of the # deletion. instruction=_get_expectations_str( hf_cache_info, selected_hashes=[c.value for c in choices if isinstance(c, Choice) and c.enabled], ), # We use the long instruction to should keybindings instructions to the user long_instruction="Press to select, to validate and to quit without modification.", # Message that is displayed once the user validates its selection. transformer=lambda result: f"{len(result)} revision(s) selected.", ) # Add a callback to update the information line when a revision is # selected/unselected def _update_expectations(_) -> None: # Hacky way to dynamically set an instruction message to the checkbox when # a revision hash is selected/unselected. checkbox._instruction = _get_expectations_str( hf_cache_info, selected_hashes=[choice["value"] for choice in checkbox.content_control.choices if choice["enabled"]], ) checkbox.kb_func_lookup["toggle"].append({"func": _update_expectations}) # Finally display the form to the user. try: return checkbox.execute() except KeyboardInterrupt: return [] # Quit without deletion @require_inquirer_py def _ask_for_confirmation_tui(message: str, default: bool = True) -> bool: """Ask for confirmation using Inquirer.""" return inquirer.confirm(message, default=default).execute() def _get_tui_choices_from_scan( repos: Iterable[CachedRepoInfo], preselected: List[str], sort_by: Optional[SortingOption_T] = None, ) -> List: """Build a list of choices from the scanned repos. Args: repos (*Iterable[`CachedRepoInfo`]*): List of scanned repos on which we want to delete revisions. preselected (*List[`str`]*): List of revision hashes that will be preselected. sort_by (*Optional[SortingOption_T]*): Sorting direction. Choices: "alphabetical", "lastUpdated", "lastUsed", "size". Return: The list of choices to pass to `inquirer.checkbox`. 
""" choices: List[Union[Choice, Separator]] = [] # First choice is to cancel the deletion choices.append( Choice( _CANCEL_DELETION_STR, name="None of the following (if selected, nothing will be deleted).", enabled=False, ) ) # Sort repos based on specified criteria sorted_repos = sorted(repos, key=lambda repo: _get_repo_sorting_key(repo, sort_by)) for repo in sorted_repos: # Repo as separator choices.append( Separator( f"\n{repo.repo_type.capitalize()} {repo.repo_id} ({repo.size_on_disk_str}," f" used {repo.last_accessed_str})" ) ) for revision in sorted(repo.revisions, key=_revision_sorting_order): # Revision as choice choices.append( Choice( revision.commit_hash, name=( f"{revision.commit_hash[:8]}:" f" {', '.join(sorted(revision.refs)) or '(detached)'} #" f" modified {revision.last_modified_str}" ), enabled=revision.commit_hash in preselected, ) ) # Return choices return choices def _manual_review_no_tui( hf_cache_info: HFCacheInfo, preselected: List[str], sort_by: Optional[SortingOption_T] = None, ) -> List[str]: """Ask the user for a manual review of the revisions to delete. Used when TUI is disabled. Manual review happens in a separate tmp file that the user can manually edit. """ # 1. Generate temporary file with delete commands. fd, tmp_path = mkstemp(suffix=".txt") # suffix to make it easier to find by editors os.close(fd) lines = [] sorted_repos = sorted(hf_cache_info.repos, key=lambda repo: _get_repo_sorting_key(repo, sort_by)) for repo in sorted_repos: lines.append( f"\n# {repo.repo_type.capitalize()} {repo.repo_id} ({repo.size_on_disk_str}," f" used {repo.last_accessed_str})" ) for revision in sorted(repo.revisions, key=_revision_sorting_order): lines.append( # Deselect by prepending a '#' f"{'' if revision.commit_hash in preselected else '#'} " f" {revision.commit_hash} # Refs:" # Print `refs` as comment on same line f" {', '.join(sorted(revision.refs)) or '(detached)'} # modified" # Print `last_modified` as comment on same line f" {revision.last_modified_str}" ) with open(tmp_path, "w") as f: f.write(_MANUAL_REVIEW_NO_TUI_INSTRUCTIONS) f.write("\n".join(lines)) # 2. Prompt instructions to user. instructions = f""" TUI is disabled. In order to select which revisions you want to delete, please edit the following file using the text editor of your choice. Instructions for manual editing are located at the beginning of the file. Edit the file, save it and confirm to continue. File to edit: {ANSI.bold(tmp_path)} """ print("\n".join(line.strip() for line in instructions.strip().split("\n"))) # 3. Wait for user confirmation. while True: selected_hashes = _read_manual_review_tmp_file(tmp_path) if _ask_for_confirmation_no_tui( _get_expectations_str(hf_cache_info, selected_hashes) + " Continue ?", default=False, ): break # 4. Return selected_hashes sorted to maintain stable order os.remove(tmp_path) return sorted(selected_hashes) # Sort to maintain stable order def _ask_for_confirmation_no_tui(message: str, default: bool = True) -> bool: """Ask for confirmation using pure-python.""" YES = ("y", "yes", "1") NO = ("n", "no", "0") DEFAULT = "" ALL = YES + NO + (DEFAULT,) full_message = message + (" (Y/n) " if default else " (y/N) ") while True: answer = input(full_message).lower() if answer == DEFAULT: return default if answer in YES: return True if answer in NO: return False print(f"Invalid input. Must be one of {ALL}") def _get_expectations_str(hf_cache_info: HFCacheInfo, selected_hashes: List[str]) -> str: """Format a string to display to the user how much space would be saved. 
Example: ``` >>> _get_expectations_str(hf_cache_info, selected_hashes) '7 revisions selected counting for 4.3G.' ``` """ if _CANCEL_DELETION_STR in selected_hashes: return "Nothing will be deleted." strategy = hf_cache_info.delete_revisions(*selected_hashes) return f"{len(selected_hashes)} revisions selected counting for {strategy.expected_freed_size_str}." def _read_manual_review_tmp_file(tmp_path: str) -> List[str]: """Read the manually reviewed instruction file and return a list of revision hash. Example: ```txt # This is the tmp file content ### # Commented out line 123456789 # revision hash # Something else # a_newer_hash # 2 days ago an_older_hash # 3 days ago ``` ```py >>> _read_manual_review_tmp_file(tmp_path) ['123456789', 'an_older_hash'] ``` """ with open(tmp_path) as f: content = f.read() # Split lines lines = [line.strip() for line in content.split("\n")] # Filter commented lines selected_lines = [line for line in lines if not line.startswith("#")] # Select only before comment selected_hashes = [line.split("#")[0].strip() for line in selected_lines] # Return revision hashes return [hash for hash in selected_hashes if len(hash) > 0] _MANUAL_REVIEW_NO_TUI_INSTRUCTIONS = f""" # INSTRUCTIONS # ------------ # This is a temporary file created by running `huggingface-cli delete-cache` with the # `--disable-tui` option. It contains a set of revisions that can be deleted from your # local cache directory. # # Please manually review the revisions you want to delete: # - Revision hashes can be commented out with '#'. # - Only non-commented revisions in this file will be deleted. # - Revision hashes that are removed from this file are ignored as well. # - If `{_CANCEL_DELETION_STR}` line is uncommented, the all cache deletion is cancelled and # no changes will be applied. # # Once you've manually reviewed this file, please confirm deletion in the terminal. This # file will be automatically removed once done. # ------------ # KILL SWITCH # ------------ # Un-comment following line to completely cancel the deletion process # {_CANCEL_DELETION_STR} # ------------ # REVISIONS # ------------ """.strip() def _revision_sorting_order(revision: CachedRevisionInfo) -> Any: # Sort by last modified (oldest first) return revision.last_modified huggingface_hub-0.31.1/src/huggingface_hub/commands/download.py000066400000000000000000000177671500667546600246200ustar00rootroot00000000000000# coding=utf-8 # Copyright 2023-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains command to download files from the Hub with the CLI. 
Usage: huggingface-cli download --help # Download file huggingface-cli download gpt2 config.json # Download entire repo huggingface-cli download fffiloni/zeroscope --repo-type=space --revision=refs/pr/78 # Download repo with filters huggingface-cli download gpt2 --include="*.safetensors" # Download with token huggingface-cli download Wauplin/private-model --token=hf_*** # Download quietly (no progress bar, no warnings, only the returned path) huggingface-cli download gpt2 config.json --quiet # Download to local dir huggingface-cli download gpt2 --local-dir=./models/gpt2 """ import warnings from argparse import Namespace, _SubParsersAction from typing import List, Optional from huggingface_hub import logging from huggingface_hub._snapshot_download import snapshot_download from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.file_download import hf_hub_download from huggingface_hub.utils import disable_progress_bars, enable_progress_bars logger = logging.get_logger(__name__) class DownloadCommand(BaseHuggingfaceCLICommand): @staticmethod def register_subcommand(parser: _SubParsersAction): download_parser = parser.add_parser("download", help="Download files from the Hub") download_parser.add_argument( "repo_id", type=str, help="ID of the repo to download from (e.g. `username/repo-name`)." ) download_parser.add_argument( "filenames", type=str, nargs="*", help="Files to download (e.g. `config.json`, `data/metadata.jsonl`)." ) download_parser.add_argument( "--repo-type", choices=["model", "dataset", "space"], default="model", help="Type of repo to download from (defaults to 'model').", ) download_parser.add_argument( "--revision", type=str, help="An optional Git revision id which can be a branch name, a tag, or a commit hash.", ) download_parser.add_argument( "--include", nargs="*", type=str, help="Glob patterns to match files to download." ) download_parser.add_argument( "--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to download." ) download_parser.add_argument( "--cache-dir", type=str, help="Path to the directory where to save the downloaded files." ) download_parser.add_argument( "--local-dir", type=str, help=( "If set, the downloaded file will be placed under this directory. Check out" " https://huggingface.co/docs/huggingface_hub/guides/download#download-files-to-local-folder for more" " details." ), ) download_parser.add_argument( "--local-dir-use-symlinks", choices=["auto", "True", "False"], help=("Deprecated and ignored. Downloading to a local directory does not use symlinks anymore."), ) download_parser.add_argument( "--force-download", action="store_true", help="If True, the files will be downloaded even if they are already cached.", ) download_parser.add_argument( "--resume-download", action="store_true", help="Deprecated and ignored. Downloading a file to local dir always attempts to resume previously interrupted downloads (unless hf-transfer is enabled).", ) download_parser.add_argument( "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" ) download_parser.add_argument( "--quiet", action="store_true", help="If True, progress bars are disabled and only the path to the download files is printed.", ) download_parser.add_argument( "--max-workers", type=int, default=8, help="Maximum number of workers to use for downloading files. 
Default is 8.", ) download_parser.set_defaults(func=DownloadCommand) def __init__(self, args: Namespace) -> None: self.token = args.token self.repo_id: str = args.repo_id self.filenames: List[str] = args.filenames self.repo_type: str = args.repo_type self.revision: Optional[str] = args.revision self.include: Optional[List[str]] = args.include self.exclude: Optional[List[str]] = args.exclude self.cache_dir: Optional[str] = args.cache_dir self.local_dir: Optional[str] = args.local_dir self.force_download: bool = args.force_download self.resume_download: Optional[bool] = args.resume_download or None self.quiet: bool = args.quiet self.max_workers: int = args.max_workers if args.local_dir_use_symlinks is not None: warnings.warn( "Ignoring --local-dir-use-symlinks. Downloading to a local directory does not use symlinks anymore.", FutureWarning, ) def run(self) -> None: if self.quiet: disable_progress_bars() with warnings.catch_warnings(): warnings.simplefilter("ignore") print(self._download()) # Print path to downloaded files enable_progress_bars() else: logging.set_verbosity_info() print(self._download()) # Print path to downloaded files logging.set_verbosity_warning() def _download(self) -> str: # Warn user if patterns are ignored if len(self.filenames) > 0: if self.include is not None and len(self.include) > 0: warnings.warn("Ignoring `--include` since filenames have being explicitly set.") if self.exclude is not None and len(self.exclude) > 0: warnings.warn("Ignoring `--exclude` since filenames have being explicitly set.") # Single file to download: use `hf_hub_download` if len(self.filenames) == 1: return hf_hub_download( repo_id=self.repo_id, repo_type=self.repo_type, revision=self.revision, filename=self.filenames[0], cache_dir=self.cache_dir, resume_download=self.resume_download, force_download=self.force_download, token=self.token, local_dir=self.local_dir, library_name="huggingface-cli", ) # Otherwise: use `snapshot_download` to ensure all files comes from same revision elif len(self.filenames) == 0: allow_patterns = self.include ignore_patterns = self.exclude else: allow_patterns = self.filenames ignore_patterns = None return snapshot_download( repo_id=self.repo_id, repo_type=self.repo_type, revision=self.revision, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, resume_download=self.resume_download, force_download=self.force_download, cache_dir=self.cache_dir, token=self.token, local_dir=self.local_dir, library_name="huggingface-cli", max_workers=self.max_workers, ) huggingface_hub-0.31.1/src/huggingface_hub/commands/env.py000066400000000000000000000023121500667546600235550ustar00rootroot00000000000000# Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains command to print information about the environment. Usage: huggingface-cli env """ from argparse import _SubParsersAction from ..utils import dump_environment_info from . 
import BaseHuggingfaceCLICommand class EnvironmentCommand(BaseHuggingfaceCLICommand): def __init__(self, args): self.args = args @staticmethod def register_subcommand(parser: _SubParsersAction): env_parser = parser.add_parser("env", help="Print information about the environment.") env_parser.set_defaults(func=EnvironmentCommand) def run(self) -> None: dump_environment_info() huggingface_hub-0.31.1/src/huggingface_hub/commands/huggingface_cli.py000066400000000000000000000045561500667546600260770ustar00rootroot00000000000000# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from argparse import ArgumentParser from huggingface_hub.commands.delete_cache import DeleteCacheCommand from huggingface_hub.commands.download import DownloadCommand from huggingface_hub.commands.env import EnvironmentCommand from huggingface_hub.commands.lfs import LfsCommands from huggingface_hub.commands.repo_files import RepoFilesCommand from huggingface_hub.commands.scan_cache import ScanCacheCommand from huggingface_hub.commands.tag import TagCommands from huggingface_hub.commands.upload import UploadCommand from huggingface_hub.commands.upload_large_folder import UploadLargeFolderCommand from huggingface_hub.commands.user import UserCommands from huggingface_hub.commands.version import VersionCommand def main(): parser = ArgumentParser("huggingface-cli", usage="huggingface-cli []") commands_parser = parser.add_subparsers(help="huggingface-cli command helpers") # Register commands DownloadCommand.register_subcommand(commands_parser) UploadCommand.register_subcommand(commands_parser) RepoFilesCommand.register_subcommand(commands_parser) EnvironmentCommand.register_subcommand(commands_parser) UserCommands.register_subcommand(commands_parser) LfsCommands.register_subcommand(commands_parser) ScanCacheCommand.register_subcommand(commands_parser) DeleteCacheCommand.register_subcommand(commands_parser) TagCommands.register_subcommand(commands_parser) VersionCommand.register_subcommand(commands_parser) # Experimental UploadLargeFolderCommand.register_subcommand(commands_parser) # Let's go args = parser.parse_args() if not hasattr(args, "func"): parser.print_help() exit(1) # Run service = args.func(args) service.run() if __name__ == "__main__": main() huggingface_hub-0.31.1/src/huggingface_hub/commands/lfs.py000066400000000000000000000162561500667546600235650ustar00rootroot00000000000000""" Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs. 
Inspired by: github.com/cbartz/git-lfs-swift-transfer-agent/blob/master/git_lfs_swift_transfer.py Spec is: github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md To launch debugger while developing: ``` [lfs "customtransfer.multipart"] path = /path/to/huggingface_hub/.env/bin/python args = -m debugpy --listen 5678 --wait-for-client /path/to/huggingface_hub/src/huggingface_hub/commands/huggingface_cli.py lfs-multipart-upload ```""" import json import os import subprocess import sys from argparse import _SubParsersAction from typing import Dict, List, Optional from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.lfs import LFS_MULTIPART_UPLOAD_COMMAND from ..utils import get_session, hf_raise_for_status, logging from ..utils._lfs import SliceFileObj logger = logging.get_logger(__name__) class LfsCommands(BaseHuggingfaceCLICommand): """ Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs. This lets users upload large files >5GB 🔥. Spec for LFS custom transfer agent is: https://github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md This introduces two commands to the CLI: 1. $ huggingface-cli lfs-enable-largefiles This should be executed once for each model repo that contains a model file >5GB. It's documented in the error message you get if you just try to git push a 5GB file without having enabled it before. 2. $ huggingface-cli lfs-multipart-upload This command is called by lfs directly and is not meant to be called by the user. """ @staticmethod def register_subcommand(parser: _SubParsersAction): enable_parser = parser.add_parser( "lfs-enable-largefiles", help="Configure your repository to enable upload of files > 5GB." ) enable_parser.add_argument("path", type=str, help="Local path to repository you want to configure.") enable_parser.set_defaults(func=lambda args: LfsEnableCommand(args)) # Command will get called by git-lfs, do not call it directly. upload_parser = parser.add_parser(LFS_MULTIPART_UPLOAD_COMMAND, add_help=False) upload_parser.set_defaults(func=lambda args: LfsUploadCommand(args)) class LfsEnableCommand: def __init__(self, args): self.args = args def run(self): local_path = os.path.abspath(self.args.path) if not os.path.isdir(local_path): print("This does not look like a valid git repo.") exit(1) subprocess.run( "git config lfs.customtransfer.multipart.path huggingface-cli".split(), check=True, cwd=local_path, ) subprocess.run( f"git config lfs.customtransfer.multipart.args {LFS_MULTIPART_UPLOAD_COMMAND}".split(), check=True, cwd=local_path, ) print("Local repo set up for largefiles") def write_msg(msg: Dict): """Write out the message in Line delimited JSON.""" msg_str = json.dumps(msg) + "\n" sys.stdout.write(msg_str) sys.stdout.flush() def read_msg() -> Optional[Dict]: """Read Line delimited JSON from stdin.""" msg = json.loads(sys.stdin.readline().strip()) if "terminate" in (msg.get("type"), msg.get("event")): # terminate message received return None if msg.get("event") not in ("download", "upload"): logger.critical("Received unexpected message") sys.exit(1) return msg class LfsUploadCommand: def __init__(self, args) -> None: self.args = args def run(self) -> None: # Immediately after invoking a custom transfer process, git-lfs # sends initiation data to the process over stdin. # This tells the process useful information about the configuration. 
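        # For orientation, the initiation line received on stdin looks roughly like this
        # (field names come from the git-lfs custom-transfer spec; the values below are only examples):
        #   {"event": "init", "operation": "upload", "remote": "origin", "concurrent": true, "concurrenttransfers": 8}
        # Each subsequent transfer request then follows the same line-delimited JSON framing, e.g.:
        #   {"event": "upload", "oid": "<sha256>", "size": 12345, "path": "/tmp/file.bin", "action": {"href": "...", "header": {...}}}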
init_msg = json.loads(sys.stdin.readline().strip()) if not (init_msg.get("event") == "init" and init_msg.get("operation") == "upload"): write_msg({"error": {"code": 32, "message": "Wrong lfs init operation"}}) sys.exit(1) # The transfer process should use the information it needs from the # initiation structure, and also perform any one-off setup tasks it # needs to do. It should then respond on stdout with a simple empty # confirmation structure, as follows: write_msg({}) # After the initiation exchange, git-lfs will send any number of # transfer requests to the stdin of the transfer process, in a serial sequence. while True: msg = read_msg() if msg is None: # When all transfers have been processed, git-lfs will send # a terminate event to the stdin of the transfer process. # On receiving this message the transfer process should # clean up and terminate. No response is expected. sys.exit(0) oid = msg["oid"] filepath = msg["path"] completion_url = msg["action"]["href"] header = msg["action"]["header"] chunk_size = int(header.pop("chunk_size")) presigned_urls: List[str] = list(header.values()) # Send a "started" progress event to allow other workers to start. # Otherwise they're delayed until first "progress" event is reported, # i.e. after the first 5GB by default (!) write_msg( { "event": "progress", "oid": oid, "bytesSoFar": 1, "bytesSinceLast": 0, } ) parts = [] with open(filepath, "rb") as file: for i, presigned_url in enumerate(presigned_urls): with SliceFileObj( file, seek_from=i * chunk_size, read_limit=chunk_size, ) as data: r = get_session().put(presigned_url, data=data) hf_raise_for_status(r) parts.append( { "etag": r.headers.get("etag"), "partNumber": i + 1, } ) # In order to support progress reporting while data is uploading / downloading, # the transfer process should post messages to stdout write_msg( { "event": "progress", "oid": oid, "bytesSoFar": (i + 1) * chunk_size, "bytesSinceLast": chunk_size, } ) # Not precise but that's ok. r = get_session().post( completion_url, json={ "oid": oid, "parts": parts, }, ) hf_raise_for_status(r) write_msg({"event": "complete", "oid": oid}) huggingface_hub-0.31.1/src/huggingface_hub/commands/repo_files.py000066400000000000000000000114731500667546600251240ustar00rootroot00000000000000# coding=utf-8 # Copyright 2023-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains command to update or delete files in a repository using the CLI. 
Usage: # delete all huggingface-cli repo-files delete "*" # delete single file huggingface-cli repo-files delete file.txt # delete single folder huggingface-cli repo-files delete folder/ # delete multiple huggingface-cli repo-files delete file.txt folder/ file2.txt # delete multiple patterns huggingface-cli repo-files delete file.txt "*.json" "folder/*.parquet" # delete from different revision / repo-type huggingface-cli repo-files delete file.txt --revision=refs/pr/1 --repo-type=dataset """ from argparse import _SubParsersAction from typing import List, Optional from huggingface_hub import logging from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.hf_api import HfApi logger = logging.get_logger(__name__) class DeleteFilesSubCommand: def __init__(self, args) -> None: self.args = args self.repo_id: str = args.repo_id self.repo_type: Optional[str] = args.repo_type self.revision: Optional[str] = args.revision self.api: HfApi = HfApi(token=args.token, library_name="huggingface-cli") self.patterns: List[str] = args.patterns self.commit_message: Optional[str] = args.commit_message self.commit_description: Optional[str] = args.commit_description self.create_pr: bool = args.create_pr self.token: Optional[str] = args.token def run(self) -> None: logging.set_verbosity_info() url = self.api.delete_files( delete_patterns=self.patterns, repo_id=self.repo_id, repo_type=self.repo_type, revision=self.revision, commit_message=self.commit_message, commit_description=self.commit_description, create_pr=self.create_pr, ) print(f"Files correctly deleted from repo. Commit: {url}.") logging.set_verbosity_warning() class RepoFilesCommand(BaseHuggingfaceCLICommand): @staticmethod def register_subcommand(parser: _SubParsersAction): repo_files_parser = parser.add_parser("repo-files", help="Manage files in a repo on the Hub") repo_files_parser.add_argument( "repo_id", type=str, help="The ID of the repo to manage (e.g. `username/repo-name`)." ) repo_files_subparsers = repo_files_parser.add_subparsers( help="Action to execute against the files.", required=True, ) delete_subparser = repo_files_subparsers.add_parser( "delete", help="Delete files from a repo on the Hub", ) delete_subparser.set_defaults(func=lambda args: DeleteFilesSubCommand(args)) delete_subparser.add_argument( "patterns", nargs="+", type=str, help="Glob patterns to match files to delete.", ) delete_subparser.add_argument( "--repo-type", choices=["model", "dataset", "space"], default="model", help="Type of the repo to upload to (e.g. `dataset`).", ) delete_subparser.add_argument( "--revision", type=str, help=( "An optional Git revision to push to. It can be a branch name " "or a PR reference. If revision does not" " exist and `--create-pr` is not set, a branch will be automatically created." ), ) delete_subparser.add_argument( "--commit-message", type=str, help="The summary / title / first line of the generated commit." ) delete_subparser.add_argument( "--commit-description", type=str, help="The description of the generated commit." ) delete_subparser.add_argument( "--create-pr", action="store_true", help="Whether to create a new Pull Request for these changes." 
) repo_files_parser.add_argument( "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens", ) repo_files_parser.set_defaults(func=RepoFilesCommand) huggingface_hub-0.31.1/src/huggingface_hub/commands/scan_cache.py000066400000000000000000000205631500667546600250440ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains command to scan the HF cache directory. Usage: huggingface-cli scan-cache huggingface-cli scan-cache -v huggingface-cli scan-cache -vvv huggingface-cli scan-cache --dir ~/.cache/huggingface/hub """ import time from argparse import Namespace, _SubParsersAction from typing import Optional from ..utils import CacheNotFound, HFCacheInfo, scan_cache_dir from . import BaseHuggingfaceCLICommand from ._cli_utils import ANSI, tabulate class ScanCacheCommand(BaseHuggingfaceCLICommand): @staticmethod def register_subcommand(parser: _SubParsersAction): scan_cache_parser = parser.add_parser("scan-cache", help="Scan cache directory.") scan_cache_parser.add_argument( "--dir", type=str, default=None, help="cache directory to scan (optional). Default to the default HuggingFace cache.", ) scan_cache_parser.add_argument( "-v", "--verbose", action="count", default=0, help="show a more verbose output", ) scan_cache_parser.set_defaults(func=ScanCacheCommand) def __init__(self, args: Namespace) -> None: self.verbosity: int = args.verbose self.cache_dir: Optional[str] = args.dir def run(self): try: t0 = time.time() hf_cache_info = scan_cache_dir(self.cache_dir) t1 = time.time() except CacheNotFound as exc: cache_dir = exc.cache_dir print(f"Cache directory not found: {cache_dir}") return self._print_hf_cache_info_as_table(hf_cache_info) print( f"\nDone in {round(t1 - t0, 1)}s. Scanned {len(hf_cache_info.repos)} repo(s)" f" for a total of {ANSI.red(hf_cache_info.size_on_disk_str)}." ) if len(hf_cache_info.warnings) > 0: message = f"Got {len(hf_cache_info.warnings)} warning(s) while scanning." if self.verbosity >= 3: print(ANSI.gray(message)) for warning in hf_cache_info.warnings: print(ANSI.gray(warning)) else: print(ANSI.gray(message + " Use -vvv to print details.")) def _print_hf_cache_info_as_table(self, hf_cache_info: HFCacheInfo) -> None: print(get_table(hf_cache_info, verbosity=self.verbosity)) def get_table(hf_cache_info: HFCacheInfo, *, verbosity: int = 0) -> str: """Generate a table from the [`HFCacheInfo`] object. Pass `verbosity=0` to get a table with a single row per repo, with columns "repo_id", "repo_type", "size_on_disk", "nb_files", "last_accessed", "last_modified", "refs", "local_path". Pass `verbosity=1` to get a table with a row per repo and revision (thus multiple rows can appear for a single repo), with columns "repo_id", "repo_type", "revision", "size_on_disk", "nb_files", "last_modified", "refs", "local_path". 
Example: ```py >>> from huggingface_hub.utils import scan_cache_dir >>> from huggingface_hub.commands.scan_cache import get_table >>> hf_cache_info = scan_cache_dir() HFCacheInfo(...) >>> print(get_table(hf_cache_info, verbosity=0)) REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH --------------------------------------------------- --------- ------------ -------- ------------- ------------- ---- -------------------------------------------------------------------------------------------------- roberta-base model 2.7M 5 1 day ago 1 week ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--roberta-base suno/bark model 8.8K 1 1 week ago 1 week ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--suno--bark t5-base model 893.8M 4 4 days ago 7 months ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--t5-base t5-large model 3.0G 4 5 weeks ago 5 months ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--t5-large >>> print(get_table(hf_cache_info, verbosity=1)) REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH --------------------------------------------------- --------- ---------------------------------------- ------------ -------- ------------- ---- ----------------------------------------------------------------------------------------------------------------------------------------------------- roberta-base model e2da8e2f811d1448a5b465c236feacd80ffbac7b 2.7M 5 1 week ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--roberta-base\\snapshots\\e2da8e2f811d1448a5b465c236feacd80ffbac7b suno/bark model 70a8a7d34168586dc5d028fa9666aceade177992 8.8K 1 1 week ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--suno--bark\\snapshots\\70a8a7d34168586dc5d028fa9666aceade177992 t5-base model a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1 893.8M 4 7 months ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--t5-base\\snapshots\\a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1 t5-large model 150ebc2c4b72291e770f58e6057481c8d2ed331a 3.0G 4 5 months ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--t5-large\\snapshots\\150ebc2c4b72291e770f58e6057481c8d2ed331a ``` ``` Args: hf_cache_info ([`HFCacheInfo`]): The HFCacheInfo object to print. verbosity (`int`, *optional*): The verbosity level. Defaults to 0. Returns: `str`: The table as a string. """ if verbosity == 0: return tabulate( rows=[ [ repo.repo_id, repo.repo_type, "{:>12}".format(repo.size_on_disk_str), repo.nb_files, repo.last_accessed_str, repo.last_modified_str, ", ".join(sorted(repo.refs)), str(repo.repo_path), ] for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path) ], headers=[ "REPO ID", "REPO TYPE", "SIZE ON DISK", "NB FILES", "LAST_ACCESSED", "LAST_MODIFIED", "REFS", "LOCAL PATH", ], ) else: return tabulate( rows=[ [ repo.repo_id, repo.repo_type, revision.commit_hash, "{:>12}".format(revision.size_on_disk_str), revision.nb_files, revision.last_modified_str, ", ".join(sorted(revision.refs)), str(revision.snapshot_path), ] for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path) for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash) ], headers=[ "REPO ID", "REPO TYPE", "REVISION", "SIZE ON DISK", "NB FILES", "LAST_MODIFIED", "REFS", "LOCAL PATH", ], ) huggingface_hub-0.31.1/src/huggingface_hub/commands/tag.py000066400000000000000000000142201500667546600235410ustar00rootroot00000000000000# coding=utf-8 # Copyright 2024-present, the HuggingFace Inc. team. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains commands to perform tag management with the CLI. Usage Examples: - Create a tag: $ huggingface-cli tag user/my-model 1.0 --message "First release" $ huggingface-cli tag user/my-model 1.0 -m "First release" --revision develop $ huggingface-cli tag user/my-dataset 1.0 -m "First release" --repo-type dataset $ huggingface-cli tag user/my-space 1.0 - List all tags: $ huggingface-cli tag -l user/my-model $ huggingface-cli tag --list user/my-dataset --repo-type dataset - Delete a tag: $ huggingface-cli tag -d user/my-model 1.0 $ huggingface-cli tag --delete user/my-dataset 1.0 --repo-type dataset $ huggingface-cli tag -d user/my-space 1.0 -y """ from argparse import Namespace, _SubParsersAction from requests.exceptions import HTTPError from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.constants import ( REPO_TYPES, ) from huggingface_hub.hf_api import HfApi from ..errors import HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError from ._cli_utils import ANSI class TagCommands(BaseHuggingfaceCLICommand): @staticmethod def register_subcommand(parser: _SubParsersAction): tag_parser = parser.add_parser("tag", help="(create, list, delete) tags for a repo in the hub") tag_parser.add_argument("repo_id", type=str, help="The ID of the repo to tag (e.g. `username/repo-name`).") tag_parser.add_argument("tag", nargs="?", type=str, help="The name of the tag for creation or deletion.") tag_parser.add_argument("-m", "--message", type=str, help="The description of the tag to create.") tag_parser.add_argument("--revision", type=str, help="The git revision to tag.") tag_parser.add_argument( "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens." 
) tag_parser.add_argument( "--repo-type", choices=["model", "dataset", "space"], default="model", help="Set the type of repository (model, dataset, or space).", ) tag_parser.add_argument("-y", "--yes", action="store_true", help="Answer Yes to prompts automatically.") tag_parser.add_argument("-l", "--list", action="store_true", help="List tags for a repository.") tag_parser.add_argument("-d", "--delete", action="store_true", help="Delete a tag for a repository.") tag_parser.set_defaults(func=lambda args: handle_commands(args)) def handle_commands(args: Namespace): if args.list: return TagListCommand(args) elif args.delete: return TagDeleteCommand(args) else: return TagCreateCommand(args) class TagCommand: def __init__(self, args: Namespace): self.args = args self.api = HfApi(token=self.args.token) self.repo_id = self.args.repo_id self.repo_type = self.args.repo_type if self.repo_type not in REPO_TYPES: print("Invalid repo --repo-type") exit(1) class TagCreateCommand(TagCommand): def run(self): print(f"You are about to create tag {ANSI.bold(self.args.tag)} on {self.repo_type} {ANSI.bold(self.repo_id)}") try: self.api.create_tag( repo_id=self.repo_id, tag=self.args.tag, tag_message=self.args.message, revision=self.args.revision, repo_type=self.repo_type, ) except RepositoryNotFoundError: print(f"{self.repo_type.capitalize()} {ANSI.bold(self.repo_id)} not found.") exit(1) except RevisionNotFoundError: print(f"Revision {ANSI.bold(self.args.revision)} not found.") exit(1) except HfHubHTTPError as e: if e.response.status_code == 409: print(f"Tag {ANSI.bold(self.args.tag)} already exists on {ANSI.bold(self.repo_id)}") exit(1) raise e print(f"Tag {ANSI.bold(self.args.tag)} created on {ANSI.bold(self.repo_id)}") class TagListCommand(TagCommand): def run(self): try: refs = self.api.list_repo_refs( repo_id=self.repo_id, repo_type=self.repo_type, ) except RepositoryNotFoundError: print(f"{self.repo_type.capitalize()} {ANSI.bold(self.repo_id)} not found.") exit(1) except HTTPError as e: print(e) print(ANSI.red(e.response.text)) exit(1) if len(refs.tags) == 0: print("No tags found") exit(0) print(f"Tags for {self.repo_type} {ANSI.bold(self.repo_id)}:") for tag in refs.tags: print(tag.name) class TagDeleteCommand(TagCommand): def run(self): print(f"You are about to delete tag {ANSI.bold(self.args.tag)} on {self.repo_type} {ANSI.bold(self.repo_id)}") if not self.args.yes: choice = input("Proceed? [Y/n] ").lower() if choice not in ("", "y", "yes"): print("Abort") exit() try: self.api.delete_tag(repo_id=self.repo_id, tag=self.args.tag, repo_type=self.repo_type) except RepositoryNotFoundError: print(f"{self.repo_type.capitalize()} {ANSI.bold(self.repo_id)} not found.") exit(1) except RevisionNotFoundError: print(f"Tag {ANSI.bold(self.args.tag)} not found on {ANSI.bold(self.repo_id)}") exit(1) print(f"Tag {ANSI.bold(self.args.tag)} deleted on {ANSI.bold(self.repo_id)}") huggingface_hub-0.31.1/src/huggingface_hub/commands/upload.py000066400000000000000000000341651500667546600242640ustar00rootroot00000000000000# coding=utf-8 # Copyright 2023-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Contains command to upload a repo or file with the CLI. Usage: # Upload file (implicit) huggingface-cli upload my-cool-model ./my-cool-model.safetensors # Upload file (explicit) huggingface-cli upload my-cool-model ./my-cool-model.safetensors model.safetensors # Upload directory (implicit). If `my-cool-model/` is a directory it will be uploaded, otherwise an exception is raised. huggingface-cli upload my-cool-model # Upload directory (explicit) huggingface-cli upload my-cool-model ./models/my-cool-model . # Upload filtered directory (example: tensorboard logs except for the last run) huggingface-cli upload my-cool-model ./model/training /logs --include "*.tfevents.*" --exclude "*20230905*" # Upload with wildcard huggingface-cli upload my-cool-model "./model/training/*.safetensors" # Upload private dataset huggingface-cli upload Wauplin/my-cool-dataset ./data . --repo-type=dataset --private # Upload with token huggingface-cli upload Wauplin/my-cool-model --token=hf_**** # Sync local Space with Hub (upload new files, delete removed files) huggingface-cli upload Wauplin/space-example --repo-type=space --exclude="/logs/*" --delete="*" --commit-message="Sync local Space with Hub" # Schedule commits every 30 minutes huggingface-cli upload Wauplin/my-cool-model --every=30 """ import os import time import warnings from argparse import Namespace, _SubParsersAction from typing import List, Optional from huggingface_hub import logging from huggingface_hub._commit_scheduler import CommitScheduler from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.constants import HF_HUB_ENABLE_HF_TRANSFER from huggingface_hub.errors import RevisionNotFoundError from huggingface_hub.hf_api import HfApi from huggingface_hub.utils import disable_progress_bars, enable_progress_bars from huggingface_hub.utils._runtime import is_xet_available logger = logging.get_logger(__name__) class UploadCommand(BaseHuggingfaceCLICommand): @staticmethod def register_subcommand(parser: _SubParsersAction): upload_parser = parser.add_parser("upload", help="Upload a file or a folder to a repo on the Hub") upload_parser.add_argument( "repo_id", type=str, help="The ID of the repo to upload to (e.g. `username/repo-name`)." ) upload_parser.add_argument( "local_path", nargs="?", help="Local path to the file or folder to upload. Wildcard patterns are supported. Defaults to current directory.", ) upload_parser.add_argument( "path_in_repo", nargs="?", help="Path of the file or folder in the repo. Defaults to the relative path of the file or folder.", ) upload_parser.add_argument( "--repo-type", choices=["model", "dataset", "space"], default="model", help="Type of the repo to upload to (e.g. `dataset`).", ) upload_parser.add_argument( "--revision", type=str, help=( "An optional Git revision to push to. It can be a branch name or a PR reference. If revision does not" " exist and `--create-pr` is not set, a branch will be automatically created." ), ) upload_parser.add_argument( "--private", action="store_true", help=( "Whether to create a private repo if repo doesn't exist on the Hub. Ignored if the repo already" " exists." ), ) upload_parser.add_argument("--include", nargs="*", type=str, help="Glob patterns to match files to upload.") upload_parser.add_argument( "--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to upload." 
) upload_parser.add_argument( "--delete", nargs="*", type=str, help="Glob patterns for file to be deleted from the repo while committing.", ) upload_parser.add_argument( "--commit-message", type=str, help="The summary / title / first line of the generated commit." ) upload_parser.add_argument("--commit-description", type=str, help="The description of the generated commit.") upload_parser.add_argument( "--create-pr", action="store_true", help="Whether to upload content as a new Pull Request." ) upload_parser.add_argument( "--every", type=float, help="If set, a background job is scheduled to create commits every `every` minutes.", ) upload_parser.add_argument( "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" ) upload_parser.add_argument( "--quiet", action="store_true", help="If True, progress bars are disabled and only the path to the uploaded files is printed.", ) upload_parser.set_defaults(func=UploadCommand) def __init__(self, args: Namespace) -> None: self.repo_id: str = args.repo_id self.repo_type: Optional[str] = args.repo_type self.revision: Optional[str] = args.revision self.private: bool = args.private self.include: Optional[List[str]] = args.include self.exclude: Optional[List[str]] = args.exclude self.delete: Optional[List[str]] = args.delete self.commit_message: Optional[str] = args.commit_message self.commit_description: Optional[str] = args.commit_description self.create_pr: bool = args.create_pr self.api: HfApi = HfApi(token=args.token, library_name="huggingface-cli") self.quiet: bool = args.quiet # disable warnings and progress bars # Check `--every` is valid if args.every is not None and args.every <= 0: raise ValueError(f"`every` must be a positive value (got '{args.every}')") self.every: Optional[float] = args.every # Resolve `local_path` and `path_in_repo` repo_name: str = args.repo_id.split("/")[-1] # e.g. "Wauplin/my-cool-model" => "my-cool-model" self.local_path: str self.path_in_repo: str if args.local_path is not None and any(c in args.local_path for c in ["*", "?", "["]): if args.include is not None: raise ValueError("Cannot set `--include` when passing a `local_path` containing a wildcard.") if args.path_in_repo is not None and args.path_in_repo != ".": raise ValueError("Cannot set `path_in_repo` when passing a `local_path` containing a wildcard.") self.local_path = "." self.include = args.local_path self.path_in_repo = "." elif args.local_path is None and os.path.isfile(repo_name): # Implicit case 1: user provided only a repo_id which happen to be a local file as well => upload it with same name self.local_path = repo_name self.path_in_repo = repo_name elif args.local_path is None and os.path.isdir(repo_name): # Implicit case 2: user provided only a repo_id which happen to be a local folder as well => upload it at root self.local_path = repo_name self.path_in_repo = "." elif args.local_path is None: # Implicit case 3: user provided only a repo_id that does not match a local file or folder # => the user must explicitly provide a local_path => raise exception raise ValueError(f"'{repo_name}' is not a local file or folder. 
Please set `local_path` explicitly.") elif args.path_in_repo is None and os.path.isfile(args.local_path): # Explicit local path to file, no path in repo => upload it at root with same name self.local_path = args.local_path self.path_in_repo = os.path.basename(args.local_path) elif args.path_in_repo is None: # Explicit local path to folder, no path in repo => upload at root self.local_path = args.local_path self.path_in_repo = "." else: # Finally, if both paths are explicit self.local_path = args.local_path self.path_in_repo = args.path_in_repo def run(self) -> None: if self.quiet: disable_progress_bars() with warnings.catch_warnings(): warnings.simplefilter("ignore") print(self._upload()) enable_progress_bars() else: logging.set_verbosity_info() print(self._upload()) logging.set_verbosity_warning() def _upload(self) -> str: if os.path.isfile(self.local_path): if self.include is not None and len(self.include) > 0: warnings.warn("Ignoring `--include` since a single file is uploaded.") if self.exclude is not None and len(self.exclude) > 0: warnings.warn("Ignoring `--exclude` since a single file is uploaded.") if self.delete is not None and len(self.delete) > 0: warnings.warn("Ignoring `--delete` since a single file is uploaded.") if not is_xet_available() and not HF_HUB_ENABLE_HF_TRANSFER: logger.info( "Consider using `hf_transfer` for faster uploads. This solution comes with some limitations. See" " https://huggingface.co/docs/huggingface_hub/hf_transfer for more details." ) # Schedule commits if `every` is set if self.every is not None: if os.path.isfile(self.local_path): # If file => watch entire folder + use allow_patterns folder_path = os.path.dirname(self.local_path) path_in_repo = ( self.path_in_repo[: -len(self.local_path)] # remove filename from path_in_repo if self.path_in_repo.endswith(self.local_path) else self.path_in_repo ) allow_patterns = [self.local_path] ignore_patterns = [] else: folder_path = self.local_path path_in_repo = self.path_in_repo allow_patterns = self.include or [] ignore_patterns = self.exclude or [] if self.delete is not None and len(self.delete) > 0: warnings.warn("Ignoring `--delete` when uploading with scheduled commits.") scheduler = CommitScheduler( folder_path=folder_path, repo_id=self.repo_id, repo_type=self.repo_type, revision=self.revision, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, path_in_repo=path_in_repo, private=self.private, every=self.every, hf_api=self.api, ) print(f"Scheduling commits every {self.every} minutes to {scheduler.repo_id}.") try: # Block main thread until KeyboardInterrupt while True: time.sleep(100) except KeyboardInterrupt: scheduler.stop() return "Stopped scheduled commits." # Otherwise, create repo and proceed with the upload if not os.path.isfile(self.local_path) and not os.path.isdir(self.local_path): raise FileNotFoundError(f"No such file or directory: '{self.local_path}'.") repo_id = self.api.create_repo( repo_id=self.repo_id, repo_type=self.repo_type, exist_ok=True, private=self.private, space_sdk="gradio" if self.repo_type == "space" else None, # ^ We don't want it to fail when uploading to a Space => let's set Gradio by default. # ^ I'd rather not add CLI args to set it explicitly as we already have `huggingface-cli repo create` for that. 
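            # ^ Note: `create_repo` returns a `RepoUrl`, so `.repo_id` below gives the canonical
            #   "namespace/name" form (useful when the user passed a bare repo name without a namespace).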
).repo_id # Check if branch already exists and if not, create it if self.revision is not None and not self.create_pr: try: self.api.repo_info(repo_id=repo_id, repo_type=self.repo_type, revision=self.revision) except RevisionNotFoundError: logger.info(f"Branch '{self.revision}' not found. Creating it...") self.api.create_branch(repo_id=repo_id, repo_type=self.repo_type, branch=self.revision, exist_ok=True) # ^ `exist_ok=True` to avoid race concurrency issues # File-based upload if os.path.isfile(self.local_path): return self.api.upload_file( path_or_fileobj=self.local_path, path_in_repo=self.path_in_repo, repo_id=repo_id, repo_type=self.repo_type, revision=self.revision, commit_message=self.commit_message, commit_description=self.commit_description, create_pr=self.create_pr, ) # Folder-based upload else: return self.api.upload_folder( folder_path=self.local_path, path_in_repo=self.path_in_repo, repo_id=repo_id, repo_type=self.repo_type, revision=self.revision, commit_message=self.commit_message, commit_description=self.commit_description, create_pr=self.create_pr, allow_patterns=self.include, ignore_patterns=self.exclude, delete_patterns=self.delete, ) huggingface_hub-0.31.1/src/huggingface_hub/commands/upload_large_folder.py000066400000000000000000000137611500667546600267700ustar00rootroot00000000000000# coding=utf-8 # Copyright 2023-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains command to upload a large folder with the CLI.""" import os from argparse import Namespace, _SubParsersAction from typing import List, Optional from huggingface_hub import logging from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.hf_api import HfApi from huggingface_hub.utils import disable_progress_bars from ._cli_utils import ANSI logger = logging.get_logger(__name__) class UploadLargeFolderCommand(BaseHuggingfaceCLICommand): @staticmethod def register_subcommand(parser: _SubParsersAction): subparser = parser.add_parser("upload-large-folder", help="Upload a large folder to a repo on the Hub") subparser.add_argument( "repo_id", type=str, help="The ID of the repo to upload to (e.g. `username/repo-name`)." ) subparser.add_argument("local_path", type=str, help="Local path to the file or folder to upload.") subparser.add_argument( "--repo-type", choices=["model", "dataset", "space"], help="Type of the repo to upload to (e.g. `dataset`).", ) subparser.add_argument( "--revision", type=str, help=("An optional Git revision to push to. It can be a branch name or a PR reference."), ) subparser.add_argument( "--private", action="store_true", help=( "Whether to create a private repo if repo doesn't exist on the Hub. Ignored if the repo already exists." 
), ) subparser.add_argument("--include", nargs="*", type=str, help="Glob patterns to match files to upload.") subparser.add_argument("--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to upload.") subparser.add_argument( "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" ) subparser.add_argument( "--num-workers", type=int, help="Number of workers to use to hash, upload and commit files." ) subparser.add_argument("--no-report", action="store_true", help="Whether to disable regular status report.") subparser.add_argument("--no-bars", action="store_true", help="Whether to disable progress bars.") subparser.set_defaults(func=UploadLargeFolderCommand) def __init__(self, args: Namespace) -> None: self.repo_id: str = args.repo_id self.local_path: str = args.local_path self.repo_type: str = args.repo_type self.revision: Optional[str] = args.revision self.private: bool = args.private self.include: Optional[List[str]] = args.include self.exclude: Optional[List[str]] = args.exclude self.api: HfApi = HfApi(token=args.token, library_name="huggingface-cli") self.num_workers: Optional[int] = args.num_workers self.no_report: bool = args.no_report self.no_bars: bool = args.no_bars if not os.path.isdir(self.local_path): raise ValueError("Large upload is only supported for folders.") def run(self) -> None: logging.set_verbosity_info() print( ANSI.yellow( "You are about to upload a large folder to the Hub using `huggingface-cli upload-large-folder`. " "This is a new feature so feedback is very welcome!\n" "\n" "A few things to keep in mind:\n" " - Repository limits still apply: https://huggingface.co/docs/hub/repositories-recommendations\n" " - Do not start several processes in parallel.\n" " - You can interrupt and resume the process at any time. " "The script will pick up where it left off except for partially uploaded files that would have to be entirely reuploaded.\n" " - Do not upload the same folder to several repositories. If you need to do so, you must delete the `./.cache/huggingface/` folder first.\n" "\n" f"Some temporary metadata will be stored under `{self.local_path}/.cache/huggingface`.\n" " - You must not modify those files manually.\n" " - You must not delete the `./.cache/huggingface/` folder while a process is running.\n" " - You can delete the `./.cache/huggingface/` folder to reinitialize the upload state when process is not running. Files will have to be hashed and preuploaded again, except for already committed files.\n" "\n" "If the process output is too verbose, you can disable the progress bars with `--no-bars`. " "You can also entirely disable the status report with `--no-report`.\n" "\n" "For more details, run `huggingface-cli upload-large-folder --help` or check the documentation at " "https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-large-folder." ) ) if self.no_bars: disable_progress_bars() self.api.upload_large_folder( repo_id=self.repo_id, folder_path=self.local_path, repo_type=self.repo_type, revision=self.revision, private=self.private, allow_patterns=self.include, ignore_patterns=self.exclude, num_workers=self.num_workers, print_report=not self.no_report, ) huggingface_hub-0.31.1/src/huggingface_hub/commands/user.py000066400000000000000000000256401500667546600237540ustar00rootroot00000000000000# Copyright 2020 The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains commands to authenticate to the Hugging Face Hub and interact with your repositories. Usage: # login and save token locally. huggingface-cli login --token=hf_*** --add-to-git-credential # switch between tokens huggingface-cli auth switch # list all tokens huggingface-cli auth list # logout from a specific token, if no token-name is provided, all tokens will be deleted from your machine. huggingface-cli logout --token-name=your_token_name # find out which huggingface.co account you are logged in as huggingface-cli whoami # create a new dataset repo on the Hub huggingface-cli repo create mydataset --type=dataset """ import subprocess from argparse import _SubParsersAction from typing import List, Optional from requests.exceptions import HTTPError from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.constants import ENDPOINT, REPO_TYPES, REPO_TYPES_URL_PREFIXES, SPACES_SDK_TYPES from huggingface_hub.hf_api import HfApi from .._login import ( # noqa: F401 # for backward compatibility # noqa: F401 # for backward compatibility NOTEBOOK_LOGIN_PASSWORD_HTML, NOTEBOOK_LOGIN_TOKEN_HTML_END, NOTEBOOK_LOGIN_TOKEN_HTML_START, auth_list, auth_switch, login, logout, notebook_login, ) from ..utils import get_stored_tokens, get_token, logging from ._cli_utils import ANSI logger = logging.get_logger(__name__) try: from InquirerPy import inquirer from InquirerPy.base.control import Choice _inquirer_py_available = True except ImportError: _inquirer_py_available = False class UserCommands(BaseHuggingfaceCLICommand): @staticmethod def register_subcommand(parser: _SubParsersAction): login_parser = parser.add_parser("login", help="Log in using a token from huggingface.co/settings/tokens") login_parser.add_argument( "--token", type=str, help="Token generated from https://huggingface.co/settings/tokens", ) login_parser.add_argument( "--add-to-git-credential", action="store_true", help="Optional: Save token to git credential helper.", ) login_parser.set_defaults(func=lambda args: LoginCommand(args)) whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.") whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args)) logout_parser = parser.add_parser("logout", help="Log out") logout_parser.add_argument( "--token-name", type=str, help="Optional: Name of the access token to log out from.", ) logout_parser.set_defaults(func=lambda args: LogoutCommand(args)) auth_parser = parser.add_parser("auth", help="Other authentication related commands") auth_subparsers = auth_parser.add_subparsers(help="Authentication subcommands") auth_switch_parser = auth_subparsers.add_parser("switch", help="Switch between access tokens") auth_switch_parser.add_argument( "--token-name", type=str, help="Optional: Name of the access token to switch to.", ) auth_switch_parser.add_argument( "--add-to-git-credential", action="store_true", help="Optional: Save token to git credential helper.", ) 
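        # Typical invocation (the token name below is purely illustrative):
        #   huggingface-cli auth switch --token-name my-work-token --add-to-git-credential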
auth_switch_parser.set_defaults(func=lambda args: AuthSwitchCommand(args)) auth_list_parser = auth_subparsers.add_parser("list", help="List all stored access tokens") auth_list_parser.set_defaults(func=lambda args: AuthListCommand(args)) # new system: git-based repo system repo_parser = parser.add_parser("repo", help="{create} Commands to interact with your huggingface.co repos.") repo_subparsers = repo_parser.add_subparsers(help="huggingface.co repos related commands") repo_create_parser = repo_subparsers.add_parser("create", help="Create a new repo on huggingface.co") repo_create_parser.add_argument( "name", type=str, help="Name for your repo. Will be namespaced under your username to build the repo id.", ) repo_create_parser.add_argument( "--type", type=str, help='Optional: repo_type: set to "dataset" or "space" if creating a dataset or space, default is model.', ) repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.") repo_create_parser.add_argument( "--space_sdk", type=str, help='Optional: Hugging Face Spaces SDK type. Required when --type is set to "space".', choices=SPACES_SDK_TYPES, ) repo_create_parser.add_argument( "-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt", ) repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args)) class BaseUserCommand: def __init__(self, args): self.args = args self._api = HfApi() class LoginCommand(BaseUserCommand): def run(self): logging.set_verbosity_info() login( token=self.args.token, add_to_git_credential=self.args.add_to_git_credential, ) class LogoutCommand(BaseUserCommand): def run(self): logging.set_verbosity_info() logout(token_name=self.args.token_name) class AuthSwitchCommand(BaseUserCommand): def run(self): logging.set_verbosity_info() token_name = self.args.token_name if token_name is None: token_name = self._select_token_name() if token_name is None: print("No token name provided. Aborting.") exit() auth_switch(token_name, add_to_git_credential=self.args.add_to_git_credential) def _select_token_name(self) -> Optional[str]: token_names = list(get_stored_tokens().keys()) if not token_names: logger.error("No stored tokens found. Please login first.") return None if _inquirer_py_available: return self._select_token_name_tui(token_names) # if inquirer is not available, use a simpler terminal UI print("Available stored tokens:") for i, token_name in enumerate(token_names, 1): print(f"{i}. {token_name}") while True: try: choice = input("Enter the number of the token to switch to (or 'q' to quit): ") if choice.lower() == "q": return None index = int(choice) - 1 if 0 <= index < len(token_names): return token_names[index] else: print("Invalid selection. Please try again.") except ValueError: print("Invalid input. 
Please enter a number or 'q' to quit.") def _select_token_name_tui(self, token_names: List[str]) -> Optional[str]: choices = [Choice(token_name, name=token_name) for token_name in token_names] try: return inquirer.select( message="Select a token to switch to:", choices=choices, default=None, ).execute() except KeyboardInterrupt: logger.info("Token selection cancelled.") return None class AuthListCommand(BaseUserCommand): def run(self): logging.set_verbosity_info() auth_list() class WhoamiCommand(BaseUserCommand): def run(self): token = get_token() if token is None: print("Not logged in") exit() try: info = self._api.whoami(token) print(info["name"]) orgs = [org["name"] for org in info["orgs"]] if orgs: print(ANSI.bold("orgs: "), ",".join(orgs)) if ENDPOINT != "https://huggingface.co": print(f"Authenticated through private endpoint: {ENDPOINT}") except HTTPError as e: print(e) print(ANSI.red(e.response.text)) exit(1) class RepoCreateCommand(BaseUserCommand): def run(self): token = get_token() if token is None: print("Not logged in") exit(1) try: stdout = subprocess.check_output(["git", "--version"]).decode("utf-8") print(ANSI.gray(stdout.strip())) except FileNotFoundError: print("Looks like you do not have git installed, please install.") try: stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8") print(ANSI.gray(stdout.strip())) except FileNotFoundError: print( ANSI.red( "Looks like you do not have git-lfs installed, please install." " You can install from https://git-lfs.github.com/." " Then run `git lfs install` (you only have to do this once)." ) ) print("") user = self._api.whoami(token)["name"] namespace = self.args.organization if self.args.organization is not None else user repo_id = f"{namespace}/{self.args.name}" if self.args.type not in REPO_TYPES: print("Invalid repo --type") exit(1) if self.args.type in REPO_TYPES_URL_PREFIXES: prefixed_repo_id = REPO_TYPES_URL_PREFIXES[self.args.type] + repo_id else: prefixed_repo_id = repo_id print(f"You are about to create {ANSI.bold(prefixed_repo_id)}") if not self.args.yes: choice = input("Proceed? [Y/n] ").lower() if not (choice == "" or choice == "y" or choice == "yes"): print("Abort") exit() try: url = self._api.create_repo( repo_id=repo_id, token=token, repo_type=self.args.type, space_sdk=self.args.space_sdk, ) except HTTPError as e: print(e) print(ANSI.red(e.response.text)) exit(1) print("\nYour repo now lives at:") print(f" {ANSI.bold(url)}") print("\nYou can clone it locally with the command below, and commit/push as usual.") print(f"\n git clone {url}") print("") huggingface_hub-0.31.1/src/huggingface_hub/commands/version.py000066400000000000000000000023621500667546600244570ustar00rootroot00000000000000# Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains command to print information about the version. Usage: huggingface-cli version """ from argparse import _SubParsersAction from huggingface_hub import __version__ from . 
import BaseHuggingfaceCLICommand class VersionCommand(BaseHuggingfaceCLICommand): def __init__(self, args): self.args = args @staticmethod def register_subcommand(parser: _SubParsersAction): version_parser = parser.add_parser("version", help="Print information about the huggingface-cli version.") version_parser.set_defaults(func=VersionCommand) def run(self) -> None: print(f"huggingface_hub version: {__version__}") huggingface_hub-0.31.1/src/huggingface_hub/community.py000066400000000000000000000276461500667546600232310ustar00rootroot00000000000000""" Data structures to interact with Discussions and Pull Requests on the Hub. See [the Discussions and Pull Requests guide](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) for more information on Pull Requests, Discussions, and the community tab. """ from dataclasses import dataclass from datetime import datetime from typing import List, Literal, Optional, Union from . import constants from .utils import parse_datetime DiscussionStatus = Literal["open", "closed", "merged", "draft"] @dataclass class Discussion: """ A Discussion or Pull Request on the Hub. This dataclass is not intended to be instantiated directly. Attributes: title (`str`): The title of the Discussion / Pull Request status (`str`): The status of the Discussion / Pull Request. It must be one of: * `"open"` * `"closed"` * `"merged"` (only for Pull Requests ) * `"draft"` (only for Pull Requests ) num (`int`): The number of the Discussion / Pull Request. repo_id (`str`): The id (`"{namespace}/{repo_name}"`) of the repo on which the Discussion / Pull Request was open. repo_type (`str`): The type of the repo on which the Discussion / Pull Request was open. Possible values are: `"model"`, `"dataset"`, `"space"`. author (`str`): The username of the Discussion / Pull Request author. Can be `"deleted"` if the user has been deleted since. is_pull_request (`bool`): Whether or not this is a Pull Request. created_at (`datetime`): The `datetime` of creation of the Discussion / Pull Request. endpoint (`str`): Endpoint of the Hub. Default is https://huggingface.co. git_reference (`str`, *optional*): (property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise. url (`str`): (property) URL of the discussion on the Hub. """ title: str status: DiscussionStatus num: int repo_id: str repo_type: str author: str is_pull_request: bool created_at: datetime endpoint: str @property def git_reference(self) -> Optional[str]: """ If this is a Pull Request , returns the git reference to which changes can be pushed. Returns `None` otherwise. """ if self.is_pull_request: return f"refs/pr/{self.num}" return None @property def url(self) -> str: """Returns the URL of the discussion on the Hub.""" if self.repo_type is None or self.repo_type == constants.REPO_TYPE_MODEL: return f"{self.endpoint}/{self.repo_id}/discussions/{self.num}" return f"{self.endpoint}/{self.repo_type}s/{self.repo_id}/discussions/{self.num}" @dataclass class DiscussionWithDetails(Discussion): """ Subclass of [`Discussion`]. Attributes: title (`str`): The title of the Discussion / Pull Request status (`str`): The status of the Discussion / Pull Request. It can be one of: * `"open"` * `"closed"` * `"merged"` (only for Pull Requests ) * `"draft"` (only for Pull Requests ) num (`int`): The number of the Discussion / Pull Request. repo_id (`str`): The id (`"{namespace}/{repo_name}"`) of the repo on which the Discussion / Pull Request was open. 
repo_type (`str`): The type of the repo on which the Discussion / Pull Request was open. Possible values are: `"model"`, `"dataset"`, `"space"`. author (`str`): The username of the Discussion / Pull Request author. Can be `"deleted"` if the user has been deleted since. is_pull_request (`bool`): Whether or not this is a Pull Request. created_at (`datetime`): The `datetime` of creation of the Discussion / Pull Request. events (`list` of [`DiscussionEvent`]) The list of [`DiscussionEvents`] in this Discussion or Pull Request. conflicting_files (`Union[List[str], bool, None]`, *optional*): A list of conflicting files if this is a Pull Request. `None` if `self.is_pull_request` is `False`. `True` if there are conflicting files but the list can't be retrieved. target_branch (`str`, *optional*): The branch into which changes are to be merged if this is a Pull Request . `None` if `self.is_pull_request` is `False`. merge_commit_oid (`str`, *optional*): If this is a merged Pull Request , this is set to the OID / SHA of the merge commit, `None` otherwise. diff (`str`, *optional*): The git diff if this is a Pull Request , `None` otherwise. endpoint (`str`): Endpoint of the Hub. Default is https://huggingface.co. git_reference (`str`, *optional*): (property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise. url (`str`): (property) URL of the discussion on the Hub. """ events: List["DiscussionEvent"] conflicting_files: Union[List[str], bool, None] target_branch: Optional[str] merge_commit_oid: Optional[str] diff: Optional[str] @dataclass class DiscussionEvent: """ An event in a Discussion or Pull Request. Use concrete classes: * [`DiscussionComment`] * [`DiscussionStatusChange`] * [`DiscussionCommit`] * [`DiscussionTitleChange`] Attributes: id (`str`): The ID of the event. An hexadecimal string. type (`str`): The type of the event. created_at (`datetime`): A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime) object holding the creation timestamp for the event. author (`str`): The username of the Discussion / Pull Request author. Can be `"deleted"` if the user has been deleted since. """ id: str type: str created_at: datetime author: str _event: dict """Stores the original event data, in case we need to access it later.""" @dataclass class DiscussionComment(DiscussionEvent): """A comment in a Discussion / Pull Request. Subclass of [`DiscussionEvent`]. Attributes: id (`str`): The ID of the event. An hexadecimal string. type (`str`): The type of the event. created_at (`datetime`): A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime) object holding the creation timestamp for the event. author (`str`): The username of the Discussion / Pull Request author. Can be `"deleted"` if the user has been deleted since. content (`str`): The raw markdown content of the comment. Mentions, links and images are not rendered. edited (`bool`): Whether or not this comment has been edited. hidden (`bool`): Whether or not this comment has been hidden. 
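    Example:
        A minimal sketch (the repo id and discussion number below are placeholders):

        ```py
        >>> from huggingface_hub import get_discussion_details
        >>> details = get_discussion_details(repo_id="username/repo-name", discussion_num=1)
        >>> comment = details.events[0]  # the opening post of a discussion is usually a DiscussionComment
        >>> comment.content  # raw markdown; see also `comment.rendered` for the HTML version
        ```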
""" content: str edited: bool hidden: bool @property def rendered(self) -> str: """The rendered comment, as a HTML string""" return self._event["data"]["latest"]["html"] @property def last_edited_at(self) -> datetime: """The last edit time, as a `datetime` object.""" return parse_datetime(self._event["data"]["latest"]["updatedAt"]) @property def last_edited_by(self) -> str: """The last edit time, as a `datetime` object.""" return self._event["data"]["latest"].get("author", {}).get("name", "deleted") @property def edit_history(self) -> List[dict]: """The edit history of the comment""" return self._event["data"]["history"] @property def number_of_edits(self) -> int: return len(self.edit_history) @dataclass class DiscussionStatusChange(DiscussionEvent): """A change of status in a Discussion / Pull Request. Subclass of [`DiscussionEvent`]. Attributes: id (`str`): The ID of the event. An hexadecimal string. type (`str`): The type of the event. created_at (`datetime`): A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime) object holding the creation timestamp for the event. author (`str`): The username of the Discussion / Pull Request author. Can be `"deleted"` if the user has been deleted since. new_status (`str`): The status of the Discussion / Pull Request after the change. It can be one of: * `"open"` * `"closed"` * `"merged"` (only for Pull Requests ) """ new_status: str @dataclass class DiscussionCommit(DiscussionEvent): """A commit in a Pull Request. Subclass of [`DiscussionEvent`]. Attributes: id (`str`): The ID of the event. An hexadecimal string. type (`str`): The type of the event. created_at (`datetime`): A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime) object holding the creation timestamp for the event. author (`str`): The username of the Discussion / Pull Request author. Can be `"deleted"` if the user has been deleted since. summary (`str`): The summary of the commit. oid (`str`): The OID / SHA of the commit, as a hexadecimal string. """ summary: str oid: str @dataclass class DiscussionTitleChange(DiscussionEvent): """A rename event in a Discussion / Pull Request. Subclass of [`DiscussionEvent`]. Attributes: id (`str`): The ID of the event. An hexadecimal string. type (`str`): The type of the event. created_at (`datetime`): A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime) object holding the creation timestamp for the event. author (`str`): The username of the Discussion / Pull Request author. Can be `"deleted"` if the user has been deleted since. old_title (`str`): The previous title for the Discussion / Pull Request. new_title (`str`): The new title. 
""" old_title: str new_title: str def deserialize_event(event: dict) -> DiscussionEvent: """Instantiates a [`DiscussionEvent`] from a dict""" event_id: str = event["id"] event_type: str = event["type"] created_at = parse_datetime(event["createdAt"]) common_args = dict( id=event_id, type=event_type, created_at=created_at, author=event.get("author", {}).get("name", "deleted"), _event=event, ) if event_type == "comment": return DiscussionComment( **common_args, edited=event["data"]["edited"], hidden=event["data"]["hidden"], content=event["data"]["latest"]["raw"], ) if event_type == "status-change": return DiscussionStatusChange( **common_args, new_status=event["data"]["status"], ) if event_type == "commit": return DiscussionCommit( **common_args, summary=event["data"]["subject"], oid=event["data"]["oid"], ) if event_type == "title-change": return DiscussionTitleChange( **common_args, old_title=event["data"]["from"], new_title=event["data"]["to"], ) return DiscussionEvent(**common_args) huggingface_hub-0.31.1/src/huggingface_hub/constants.py000066400000000000000000000225031500667546600232040ustar00rootroot00000000000000import os import re import typing from typing import Literal, Optional, Tuple # Possible values for env variables ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) def _is_true(value: Optional[str]) -> bool: if value is None: return False return value.upper() in ENV_VARS_TRUE_VALUES def _as_int(value: Optional[str]) -> Optional[int]: if value is None: return None return int(value) # Constants for file downloads PYTORCH_WEIGHTS_NAME = "pytorch_model.bin" TF2_WEIGHTS_NAME = "tf_model.h5" TF_WEIGHTS_NAME = "model.ckpt" FLAX_WEIGHTS_NAME = "flax_model.msgpack" CONFIG_NAME = "config.json" REPOCARD_NAME = "README.md" DEFAULT_ETAG_TIMEOUT = 10 DEFAULT_DOWNLOAD_TIMEOUT = 10 DEFAULT_REQUEST_TIMEOUT = 10 DOWNLOAD_CHUNK_SIZE = 10 * 1024 * 1024 HF_TRANSFER_CONCURRENCY = 100 MAX_HTTP_DOWNLOAD_SIZE = 50 * 1000 * 1000 * 1000 # 50 GB # Constants for serialization PYTORCH_WEIGHTS_FILE_PATTERN = "pytorch_model{suffix}.bin" # Unsafe pickle: use safetensors instead SAFETENSORS_WEIGHTS_FILE_PATTERN = "model{suffix}.safetensors" TF2_WEIGHTS_FILE_PATTERN = "tf_model{suffix}.h5" # Constants for safetensors repos SAFETENSORS_SINGLE_FILE = "model.safetensors" SAFETENSORS_INDEX_FILE = "model.safetensors.index.json" SAFETENSORS_MAX_HEADER_LENGTH = 25_000_000 # Timeout of aquiring file lock and logging the attempt FILELOCK_LOG_EVERY_SECONDS = 10 # Git-related constants DEFAULT_REVISION = "main" REGEX_COMMIT_OID = re.compile(r"[A-Fa-f0-9]{5,40}") HUGGINGFACE_CO_URL_HOME = "https://huggingface.co/" _staging_mode = _is_true(os.environ.get("HUGGINGFACE_CO_STAGING")) _HF_DEFAULT_ENDPOINT = "https://huggingface.co" _HF_DEFAULT_STAGING_ENDPOINT = "https://hub-ci.huggingface.co" ENDPOINT = os.getenv("HF_ENDPOINT", _HF_DEFAULT_ENDPOINT).rstrip("/") HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}" if _staging_mode: ENDPOINT = _HF_DEFAULT_STAGING_ENDPOINT HUGGINGFACE_CO_URL_TEMPLATE = _HF_DEFAULT_STAGING_ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}" HUGGINGFACE_HEADER_X_REPO_COMMIT = "X-Repo-Commit" HUGGINGFACE_HEADER_X_LINKED_ETAG = "X-Linked-Etag" HUGGINGFACE_HEADER_X_LINKED_SIZE = "X-Linked-Size" HUGGINGFACE_HEADER_X_BILL_TO = "X-HF-Bill-To" INFERENCE_ENDPOINT = os.environ.get("HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co") # See https://huggingface.co/docs/inference-endpoints/index 
INFERENCE_ENDPOINTS_ENDPOINT = "https://api.endpoints.huggingface.cloud/v2" INFERENCE_CATALOG_ENDPOINT = "https://endpoints.huggingface.co/api/catalog" # Proxy for third-party providers INFERENCE_PROXY_TEMPLATE = "https://router.huggingface.co/{provider}" REPO_ID_SEPARATOR = "--" # ^ this substring is not allowed in repo_ids on hf.co # and is the canonical one we use for serialization of repo ids elsewhere. REPO_TYPE_DATASET = "dataset" REPO_TYPE_SPACE = "space" REPO_TYPE_MODEL = "model" REPO_TYPES = [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE] SPACES_SDK_TYPES = ["gradio", "streamlit", "docker", "static"] REPO_TYPES_URL_PREFIXES = { REPO_TYPE_DATASET: "datasets/", REPO_TYPE_SPACE: "spaces/", } REPO_TYPES_MAPPING = { "datasets": REPO_TYPE_DATASET, "spaces": REPO_TYPE_SPACE, "models": REPO_TYPE_MODEL, } DiscussionTypeFilter = Literal["all", "discussion", "pull_request"] DISCUSSION_TYPES: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter) DiscussionStatusFilter = Literal["all", "open", "closed"] DISCUSSION_STATUS: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionStatusFilter) # Webhook subscription types WEBHOOK_DOMAIN_T = Literal["repo", "discussions"] # default cache default_home = os.path.join(os.path.expanduser("~"), ".cache") HF_HOME = os.path.expandvars( os.path.expanduser( os.getenv( "HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", default_home), "huggingface"), ) ) ) hf_cache_home = HF_HOME # for backward compatibility. TODO: remove this in 1.0.0 default_cache_path = os.path.join(HF_HOME, "hub") default_assets_cache_path = os.path.join(HF_HOME, "assets") # Legacy env variables HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path) HUGGINGFACE_ASSETS_CACHE = os.getenv("HUGGINGFACE_ASSETS_CACHE", default_assets_cache_path) # New env variables HF_HUB_CACHE = os.path.expandvars( os.path.expanduser( os.getenv( "HF_HUB_CACHE", HUGGINGFACE_HUB_CACHE, ) ) ) HF_ASSETS_CACHE = os.path.expandvars( os.path.expanduser( os.getenv( "HF_ASSETS_CACHE", HUGGINGFACE_ASSETS_CACHE, ) ) ) HF_HUB_OFFLINE = _is_true(os.environ.get("HF_HUB_OFFLINE") or os.environ.get("TRANSFORMERS_OFFLINE")) # If set, log level will be set to DEBUG and all requests made to the Hub will be logged # as curl commands for reproducibility. HF_DEBUG = _is_true(os.environ.get("HF_DEBUG")) # Opt-out from telemetry requests HF_HUB_DISABLE_TELEMETRY = ( _is_true(os.environ.get("HF_HUB_DISABLE_TELEMETRY")) # HF-specific env variable or _is_true(os.environ.get("DISABLE_TELEMETRY")) or _is_true(os.environ.get("DO_NOT_TRACK")) # https://consoledonottrack.com/ ) HF_TOKEN_PATH = os.path.expandvars( os.path.expanduser( os.getenv( "HF_TOKEN_PATH", os.path.join(HF_HOME, "token"), ) ) ) HF_STORED_TOKENS_PATH = os.path.join(os.path.dirname(HF_TOKEN_PATH), "stored_tokens") if _staging_mode: # In staging mode, we use a different cache to ensure we don't mix up production and staging data or tokens # In practice in `huggingface_hub` tests, we monkeypatch these values with temporary directories. The following # lines are only used in third-party libraries tests (e.g. `transformers`, `diffusers`, etc.). _staging_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface_staging") HUGGINGFACE_HUB_CACHE = os.path.join(_staging_home, "hub") HF_TOKEN_PATH = os.path.join(_staging_home, "token") # Here, `True` will disable progress bars globally without possibility of enabling it # programmatically. `False` will enable them without possibility of disabling them. 
# If environment variable is not set (None), then the user is free to enable/disable # them programmatically. # TL;DR: env variable has priority over code __HF_HUB_DISABLE_PROGRESS_BARS = os.environ.get("HF_HUB_DISABLE_PROGRESS_BARS") HF_HUB_DISABLE_PROGRESS_BARS: Optional[bool] = ( _is_true(__HF_HUB_DISABLE_PROGRESS_BARS) if __HF_HUB_DISABLE_PROGRESS_BARS is not None else None ) # Disable warning on machines that do not support symlinks (e.g. Windows non-developer) HF_HUB_DISABLE_SYMLINKS_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_SYMLINKS_WARNING")) # Disable warning when using experimental features HF_HUB_DISABLE_EXPERIMENTAL_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_EXPERIMENTAL_WARNING")) # Disable sending the cached token by default is all HTTP requests to the Hub HF_HUB_DISABLE_IMPLICIT_TOKEN: bool = _is_true(os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN")) # Enable fast-download using external dependency "hf_transfer" # See: # - https://pypi.org/project/hf-transfer/ # - https://github.com/huggingface/hf_transfer (private) HF_HUB_ENABLE_HF_TRANSFER: bool = _is_true(os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")) # UNUSED # We don't use symlinks in local dir anymore. HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD: int = ( _as_int(os.environ.get("HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD")) or 5 * 1024 * 1024 ) # Used to override the etag timeout on a system level HF_HUB_ETAG_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_ETAG_TIMEOUT")) or DEFAULT_ETAG_TIMEOUT # Used to override the get request timeout on a system level HF_HUB_DOWNLOAD_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT")) or DEFAULT_DOWNLOAD_TIMEOUT # Allows to add information about the requester in the user-agent (eg. partner name) HF_HUB_USER_AGENT_ORIGIN: Optional[str] = os.environ.get("HF_HUB_USER_AGENT_ORIGIN") # List frameworks that are handled by the InferenceAPI service. Useful to scan endpoints and check which models are # deployed and running. Since 95% of the models are using the top 4 frameworks listed below, we scan only those by # default. We still keep the full list of supported frameworks in case we want to scan all of them. 
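# Illustrative only (hypothetical variable names, not part of this module): downstream code can
# pick between the two lists defined below depending on how exhaustive a scan should be, e.g.
#   frameworks = ALL_INFERENCE_API_FRAMEWORKS if scan_all else MAIN_INFERENCE_API_FRAMEWORKS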
MAIN_INFERENCE_API_FRAMEWORKS = [ "diffusers", "sentence-transformers", "text-generation-inference", "transformers", ] ALL_INFERENCE_API_FRAMEWORKS = MAIN_INFERENCE_API_FRAMEWORKS + [ "adapter-transformers", "allennlp", "asteroid", "bertopic", "doctr", "espnet", "fairseq", "fastai", "fasttext", "flair", "k2", "keras", "mindspore", "nemo", "open_clip", "paddlenlp", "peft", "pyannote-audio", "sklearn", "spacy", "span-marker", "speechbrain", "stanza", "timm", ] # Xet constants HUGGINGFACE_HEADER_X_XET_ENDPOINT = "X-Xet-Cas-Url" HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN = "X-Xet-Access-Token" HUGGINGFACE_HEADER_X_XET_EXPIRATION = "X-Xet-Token-Expiration" HUGGINGFACE_HEADER_X_XET_HASH = "X-Xet-Hash" HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE = "X-Xet-Refresh-Route" HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY = "xet-auth" default_xet_cache_path = os.path.join(HF_HOME, "xet") HF_XET_CACHE = os.getenv("HF_XET_CACHE", default_xet_cache_path) huggingface_hub-0.31.1/src/huggingface_hub/errors.py000066400000000000000000000236531500667546600225130ustar00rootroot00000000000000"""Contains all custom errors.""" from pathlib import Path from typing import Optional, Union from requests import HTTPError, Response # CACHE ERRORS class CacheNotFound(Exception): """Exception thrown when the Huggingface cache is not found.""" cache_dir: Union[str, Path] def __init__(self, msg: str, cache_dir: Union[str, Path], *args, **kwargs): super().__init__(msg, *args, **kwargs) self.cache_dir = cache_dir class CorruptedCacheException(Exception): """Exception for any unexpected structure in the Huggingface cache-system.""" # HEADERS ERRORS class LocalTokenNotFoundError(EnvironmentError): """Raised if local token is required but not found.""" # HTTP ERRORS class OfflineModeIsEnabled(ConnectionError): """Raised when a request is made but `HF_HUB_OFFLINE=1` is set as environment variable.""" class HfHubHTTPError(HTTPError): """ HTTPError to inherit from for any custom HTTP Error raised in HF Hub. Any HTTPError is converted at least into a `HfHubHTTPError`. If some information is sent back by the server, it will be added to the error message. Added details: - Request id from "X-Request-Id" header if exists. If not, fallback to "X-Amzn-Trace-Id" header if exists. - Server error message from the header "X-Error-Message". - Server error message if we can found one in the response body. Example: ```py import requests from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError response = get_session().post(...) 
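    # hf_raise_for_status raises HfHubHTTPError (or one of its subclasses) for any error response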
try: hf_raise_for_status(response) except HfHubHTTPError as e: print(str(e)) # formatted message e.request_id, e.server_message # details returned by server # Complete the error message with additional information once it's raised e.append_to_message("\n`create_commit` expects the repository to exist.") raise ``` """ def __init__(self, message: str, response: Optional[Response] = None, *, server_message: Optional[str] = None): self.request_id = ( response.headers.get("x-request-id") or response.headers.get("X-Amzn-Trace-Id") if response is not None else None ) self.server_message = server_message super().__init__( message, response=response, # type: ignore [arg-type] request=response.request if response is not None else None, # type: ignore [arg-type] ) def append_to_message(self, additional_message: str) -> None: """Append additional information to the `HfHubHTTPError` initial message.""" self.args = (self.args[0] + additional_message,) + self.args[1:] # INFERENCE CLIENT ERRORS class InferenceTimeoutError(HTTPError, TimeoutError): """Error raised when a model is unavailable or the request times out.""" # INFERENCE ENDPOINT ERRORS class InferenceEndpointError(Exception): """Generic exception when dealing with Inference Endpoints.""" class InferenceEndpointTimeoutError(InferenceEndpointError, TimeoutError): """Exception for timeouts while waiting for Inference Endpoint.""" # SAFETENSORS ERRORS class SafetensorsParsingError(Exception): """Raised when failing to parse a safetensors file metadata. This can be the case if the file is not a safetensors file or does not respect the specification. """ class NotASafetensorsRepoError(Exception): """Raised when a repo is not a Safetensors repo i.e. doesn't have either a `model.safetensors` or a `model.safetensors.index.json` file. """ # TEXT GENERATION ERRORS class TextGenerationError(HTTPError): """Generic error raised if text-generation went wrong.""" # Text Generation Inference Errors class ValidationError(TextGenerationError): """Server-side validation error.""" class GenerationError(TextGenerationError): pass class OverloadedError(TextGenerationError): pass class IncompleteGenerationError(TextGenerationError): pass class UnknownError(TextGenerationError): pass # VALIDATION ERRORS class HFValidationError(ValueError): """Generic exception thrown by `huggingface_hub` validators. Inherits from [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError). """ # FILE METADATA ERRORS class FileMetadataError(OSError): """Error triggered when the metadata of a file on the Hub cannot be retrieved (missing ETag or commit_hash). Inherits from `OSError` for backward compatibility. """ # REPOSITORY ERRORS class RepositoryNotFoundError(HfHubHTTPError): """ Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does not have access to. Example: ```py >>> from huggingface_hub import model_info >>> model_info("") (...) huggingface_hub.utils._errors.RepositoryNotFoundError: 401 Client Error. (Request ID: PvMw_VjBMjVdMz53WKIzP) Repository Not Found for url: https://huggingface.co/api/models/%3Cnon_existent_repository%3E. Please make sure you specified the correct `repo_id` and `repo_type`. If the repo is private, make sure you are authenticated. Invalid username or password. ``` """ class GatedRepoError(RepositoryNotFoundError): """ Raised when trying to access a gated repository for which the user is not on the authorized list. 
Note: derives from `RepositoryNotFoundError` to ensure backward compatibility. Example: ```py >>> from huggingface_hub import model_info >>> model_info("") (...) huggingface_hub.utils._errors.GatedRepoError: 403 Client Error. (Request ID: ViT1Bf7O_026LGSQuVqfa) Cannot access gated repo for url https://huggingface.co/api/models/ardent-figment/gated-model. Access to model ardent-figment/gated-model is restricted and you are not in the authorized list. Visit https://huggingface.co/ardent-figment/gated-model to ask for access. ``` """ class DisabledRepoError(HfHubHTTPError): """ Raised when trying to access a repository that has been disabled by its author. Example: ```py >>> from huggingface_hub import dataset_info >>> dataset_info("laion/laion-art") (...) huggingface_hub.utils._errors.DisabledRepoError: 403 Client Error. (Request ID: Root=1-659fc3fa-3031673e0f92c71a2260dbe2;bc6f4dfb-b30a-4862-af0a-5cfe827610d8) Cannot access repository for url https://huggingface.co/api/datasets/laion/laion-art. Access to this resource is disabled. ``` """ # REVISION ERROR class RevisionNotFoundError(HfHubHTTPError): """ Raised when trying to access a hf.co URL with a valid repository but an invalid revision. Example: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download('bert-base-cased', 'config.json', revision='') (...) huggingface_hub.utils._errors.RevisionNotFoundError: 404 Client Error. (Request ID: Mwhe_c3Kt650GcdKEFomX) Revision Not Found for url: https://huggingface.co/bert-base-cased/resolve/%3Cnon-existent-revision%3E/config.json. ``` """ # ENTRY ERRORS class EntryNotFoundError(HfHubHTTPError): """ Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename. Example: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download('bert-base-cased', '') (...) huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error. (Request ID: 53pNl6M0MxsnG5Sw8JA6x) Entry Not Found for url: https://huggingface.co/bert-base-cased/resolve/main/%3Cnon-existent-file%3E. ``` """ class LocalEntryNotFoundError(EntryNotFoundError, FileNotFoundError, ValueError): """ Raised when trying to access a file or snapshot that is not on the disk when network is disabled or unavailable (connection issue). The entry may exist on the Hub. Note: `ValueError` type is to ensure backward compatibility. Note: `LocalEntryNotFoundError` derives from `HTTPError` because of `EntryNotFoundError` even when it is not a network issue. Example: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download('bert-base-cased', '', local_files_only=True) (...) huggingface_hub.utils._errors.LocalEntryNotFoundError: Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co look-ups and downloads online, set 'local_files_only' to False. ``` """ def __init__(self, message: str): super().__init__(message, response=None) # REQUEST ERROR class BadRequestError(HfHubHTTPError, ValueError): """ Raised by `hf_raise_for_status` when the server returns a HTTP 400 error. Example: ```py >>> resp = requests.post("hf.co/api/check", ...) 
>>> hf_raise_for_status(resp, endpoint_name="check") huggingface_hub.utils._errors.BadRequestError: Bad request for check endpoint: {details} (Request ID: XXX) ``` """ # DDUF file format ERROR class DDUFError(Exception): """Base exception for errors related to the DDUF format.""" class DDUFCorruptedFileError(DDUFError): """Exception thrown when the DDUF file is corrupted.""" class DDUFExportError(DDUFError): """Base exception for errors during DDUF export.""" class DDUFInvalidEntryNameError(DDUFExportError): """Exception thrown when the entry name is invalid.""" # XET ERRORS class XetError(Exception): """Base exception for errors related to Xet Storage.""" class XetAuthorizationError(XetError): """Exception thrown when the user does not have the right authorization to use Xet Storage.""" class XetRefreshTokenError(XetError): """Exception thrown when the refresh token is invalid.""" class XetDownloadError(Exception): """Exception thrown when the download from Xet Storage fails.""" huggingface_hub-0.31.1/src/huggingface_hub/fastai_utils.py000066400000000000000000000405511500667546600236620ustar00rootroot00000000000000import json import os from pathlib import Path from pickle import DEFAULT_PROTOCOL, PicklingError from typing import Any, Dict, List, Optional, Union from packaging import version from huggingface_hub import constants, snapshot_download from huggingface_hub.hf_api import HfApi from huggingface_hub.utils import ( SoftTemporaryDirectory, get_fastai_version, get_fastcore_version, get_python_version, ) from .utils import logging, validate_hf_hub_args from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility... logger = logging.get_logger(__name__) def _check_fastai_fastcore_versions( fastai_min_version: str = "2.4", fastcore_min_version: str = "1.3.27", ): """ Checks that the installed fastai and fastcore versions are compatible for pickle serialization. Args: fastai_min_version (`str`, *optional*): The minimum fastai version supported. fastcore_min_version (`str`, *optional*): The minimum fastcore version supported. Raises the following error: - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError) if the fastai or fastcore libraries are not available or are of an invalid version. """ if (get_fastcore_version() or get_fastai_version()) == "N/A": raise ImportError( f"fastai>={fastai_min_version} and fastcore>={fastcore_min_version} are" f" required. Currently using fastai=={get_fastai_version()} and" f" fastcore=={get_fastcore_version()}." ) current_fastai_version = version.Version(get_fastai_version()) current_fastcore_version = version.Version(get_fastcore_version()) if current_fastai_version < version.Version(fastai_min_version): raise ImportError( "`push_to_hub_fastai` and `from_pretrained_fastai` require a" f" fastai>={fastai_min_version} version, but you are using fastai version" f" {get_fastai_version()} which is incompatible. Upgrade with `pip install" " fastai==2.5.6`." ) if current_fastcore_version < version.Version(fastcore_min_version): raise ImportError( "`push_to_hub_fastai` and `from_pretrained_fastai` require a" f" fastcore>={fastcore_min_version} version, but you are using fastcore" f" version {get_fastcore_version()} which is incompatible. Upgrade with" " `pip install fastcore==1.3.27`." 
) def _check_fastai_fastcore_pyproject_versions( storage_folder: str, fastai_min_version: str = "2.4", fastcore_min_version: str = "1.3.27", ): """ Checks that the `pyproject.toml` file in the directory `storage_folder` has fastai and fastcore versions that are compatible with `from_pretrained_fastai` and `push_to_hub_fastai`. If `pyproject.toml` does not exist or does not contain versions for fastai and fastcore, then it logs a warning. Args: storage_folder (`str`): Folder to look for the `pyproject.toml` file. fastai_min_version (`str`, *optional*): The minimum fastai version supported. fastcore_min_version (`str`, *optional*): The minimum fastcore version supported. Raises the following errors: - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError) if the `toml` module is not installed. - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError) if the `pyproject.toml` indicates a lower than minimum supported version of fastai or fastcore. """ try: import toml except ModuleNotFoundError: raise ImportError( "`push_to_hub_fastai` and `from_pretrained_fastai` require the toml module." " Install it with `pip install toml`." ) # Checks that a `pyproject.toml`, with `build-system` and `requires` sections, exists in the repository. If so, get a list of required packages. if not os.path.isfile(f"{storage_folder}/pyproject.toml"): logger.warning( "There is no `pyproject.toml` in the repository that contains the fastai" " `Learner`. The `pyproject.toml` would allow us to verify that your fastai" " and fastcore versions are compatible with those of the model you want to" " load." ) return pyproject_toml = toml.load(f"{storage_folder}/pyproject.toml") if "build-system" not in pyproject_toml.keys(): logger.warning( "There is no `build-system` section in the pyproject.toml of the repository" " that contains the fastai `Learner`. The `build-system` would allow us to" " verify that your fastai and fastcore versions are compatible with those" " of the model you want to load." ) return build_system_toml = pyproject_toml["build-system"] if "requires" not in build_system_toml.keys(): logger.warning( "There is no `requires` section in the pyproject.toml of the repository" " that contains the fastai `Learner`. The `requires` would allow us to" " verify that your fastai and fastcore versions are compatible with those" " of the model you want to load." ) return package_versions = build_system_toml["requires"] # Extracts contains fastai and fastcore versions from `pyproject.toml` if available. # If the package is specified but not the version (e.g. "fastai" instead of "fastai=2.4"), the default versions are the highest. fastai_packages = [pck for pck in package_versions if pck.startswith("fastai")] if len(fastai_packages) == 0: logger.warning("The repository does not have a fastai version specified in the `pyproject.toml`.") # fastai_version is an empty string if not specified else: fastai_version = str(fastai_packages[0]).partition("=")[2] if fastai_version != "" and version.Version(fastai_version) < version.Version(fastai_min_version): raise ImportError( "`from_pretrained_fastai` requires" f" fastai>={fastai_min_version} version but the model to load uses" f" {fastai_version} which is incompatible." 
) fastcore_packages = [pck for pck in package_versions if pck.startswith("fastcore")] if len(fastcore_packages) == 0: logger.warning("The repository does not have a fastcore version specified in the `pyproject.toml`.") # fastcore_version is an empty string if not specified else: fastcore_version = str(fastcore_packages[0]).partition("=")[2] if fastcore_version != "" and version.Version(fastcore_version) < version.Version(fastcore_min_version): raise ImportError( "`from_pretrained_fastai` requires" f" fastcore>={fastcore_min_version} version, but you are using fastcore" f" version {fastcore_version} which is incompatible." ) README_TEMPLATE = """--- tags: - fastai --- # Amazing! 🥳 Congratulations on hosting your fastai model on the Hugging Face Hub! # Some next steps 1. Fill out this model card with more information (see the template below and the [documentation here](https://huggingface.co/docs/hub/model-repos))! 2. Create a demo in Gradio or Streamlit using 🤗 Spaces ([documentation here](https://huggingface.co/docs/hub/spaces)). 3. Join the fastai community on the [Fastai Discord](https://discord.com/invite/YKrxeNn)! Greetings fellow fastlearner 🤝! Don't forget to delete this content from your model card. --- # Model card ## Model description More information needed ## Intended uses & limitations More information needed ## Training and evaluation data More information needed """ PYPROJECT_TEMPLATE = f"""[build-system] requires = ["setuptools>=40.8.0", "wheel", "python={get_python_version()}", "fastai={get_fastai_version()}", "fastcore={get_fastcore_version()}"] build-backend = "setuptools.build_meta:__legacy__" """ def _create_model_card(repo_dir: Path): """ Creates a model card for the repository. Args: repo_dir (`Path`): Directory where model card is created. """ readme_path = repo_dir / "README.md" if not readme_path.exists(): with readme_path.open("w", encoding="utf-8") as f: f.write(README_TEMPLATE) def _create_model_pyproject(repo_dir: Path): """ Creates a `pyproject.toml` for the repository. Args: repo_dir (`Path`): Directory where `pyproject.toml` is created. """ pyproject_path = repo_dir / "pyproject.toml" if not pyproject_path.exists(): with pyproject_path.open("w", encoding="utf-8") as f: f.write(PYPROJECT_TEMPLATE) def _save_pretrained_fastai( learner, save_directory: Union[str, Path], config: Optional[Dict[str, Any]] = None, ): """ Saves a fastai learner to `save_directory` in pickle format using the default pickle protocol for the version of python used. Args: learner (`Learner`): The `fastai.Learner` you'd like to save. save_directory (`str` or `Path`): Specific directory in which you want to save the fastai learner. config (`dict`, *optional*): Configuration object. Will be uploaded as a .json file. Example: 'https://huggingface.co/espejelomar/fastai-pet-breeds-classification/blob/main/config.json'. Raises the following error: - [`RuntimeError`](https://docs.python.org/3/library/exceptions.html#RuntimeError) if the config file provided is not a dictionary. """ _check_fastai_fastcore_versions() os.makedirs(save_directory, exist_ok=True) # if the user provides config then we update it with the fastai and fastcore versions in CONFIG_TEMPLATE. if config is not None: if not isinstance(config, dict): raise RuntimeError(f"Provided config should be a dict. 
Got: '{type(config)}'") path = os.path.join(save_directory, constants.CONFIG_NAME) with open(path, "w") as f: json.dump(config, f) _create_model_card(Path(save_directory)) _create_model_pyproject(Path(save_directory)) # learner.export saves the model in `self.path`. learner.path = Path(save_directory) os.makedirs(save_directory, exist_ok=True) try: learner.export( fname="model.pkl", pickle_protocol=DEFAULT_PROTOCOL, ) except PicklingError: raise PicklingError( "You are using a lambda function, i.e., an anonymous function. `pickle`" " cannot pickle function objects and requires that all functions have" " names. One possible solution is to name the function." ) @validate_hf_hub_args def from_pretrained_fastai( repo_id: str, revision: Optional[str] = None, ): """ Load pretrained fastai model from the Hub or from a local directory. Args: repo_id (`str`): The location where the pickled fastai.Learner is. It can be either of the two: - Hosted on the Hugging Face Hub. E.g.: 'espejelomar/fatai-pet-breeds-classification' or 'distilgpt2'. You can add a `revision` by appending `@` at the end of `repo_id`. E.g.: `dbmdz/bert-base-german-cased@main`. Revision is the specific model version to use. Since we use a git-based system for storing models and other artifacts on the Hugging Face Hub, it can be a branch name, a tag name, or a commit id. - Hosted locally. `repo_id` would be a directory containing the pickle and a pyproject.toml indicating the fastai and fastcore versions used to build the `fastai.Learner`. E.g.: `./my_model_directory/`. revision (`str`, *optional*): Revision at which the repo's files are downloaded. See documentation of `snapshot_download`. Returns: The `fastai.Learner` model in the `repo_id` repo. """ _check_fastai_fastcore_versions() # Load the `repo_id` repo. # `snapshot_download` returns the folder where the model was stored. # `cache_dir` will be the default '/root/.cache/huggingface/hub' if not os.path.isdir(repo_id): storage_folder = snapshot_download( repo_id=repo_id, revision=revision, library_name="fastai", library_version=get_fastai_version(), ) else: storage_folder = repo_id _check_fastai_fastcore_pyproject_versions(storage_folder) from fastai.learner import load_learner # type: ignore return load_learner(os.path.join(storage_folder, "model.pkl")) @validate_hf_hub_args def push_to_hub_fastai( learner, *, repo_id: str, commit_message: str = "Push FastAI model using huggingface_hub.", private: Optional[bool] = None, token: Optional[str] = None, config: Optional[dict] = None, branch: Optional[str] = None, create_pr: Optional[bool] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, delete_patterns: Optional[Union[List[str], str]] = None, api_endpoint: Optional[str] = None, ): """ Upload learner checkpoint files to the Hub. Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more details. Args: learner (`Learner`): The `fastai.Learner' you'd like to push to the Hub. repo_id (`str`): The repository id for your model in Hub in the format of "namespace/repo_name". The namespace can be your individual account or an organization to which you have write access (for example, 'stanfordnlp/stanza-de'). commit_message (`str`, *optional*): Message to commit while pushing. Will default to :obj:`"add model"`. 
private (`bool`, *optional*): Whether or not the repository created should be private. If `None` (default), will default to been public except if the organization's default is private. token (`str`, *optional*): The Hugging Face account token to use as HTTP bearer authorization for remote files. If :obj:`None`, the token will be asked by a prompt. config (`dict`, *optional*): Configuration object to be saved alongside the model weights. branch (`str`, *optional*): The git branch on which to push the model. This defaults to the default branch as specified in your repository, which defaults to `"main"`. create_pr (`boolean`, *optional*): Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`. api_endpoint (`str`, *optional*): The API endpoint to use when pushing the model to the hub. allow_patterns (`List[str]` or `str`, *optional*): If provided, only files matching at least one pattern are pushed. ignore_patterns (`List[str]` or `str`, *optional*): If provided, files matching any of the patterns are not pushed. delete_patterns (`List[str]` or `str`, *optional*): If provided, remote files matching any of the patterns will be deleted from the repo. Returns: The url of the commit of your model in the given repository. Raises the following error: - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if the user is not log on to the Hugging Face Hub. """ _check_fastai_fastcore_versions() api = HfApi(endpoint=api_endpoint) repo_id = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True).repo_id # Push the files to the repo in a single commit with SoftTemporaryDirectory() as tmp: saved_path = Path(tmp) / repo_id _save_pretrained_fastai(learner, saved_path, config=config) return api.upload_folder( repo_id=repo_id, token=token, folder_path=saved_path, commit_message=commit_message, revision=branch, create_pr=create_pr, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, delete_patterns=delete_patterns, ) huggingface_hub-0.31.1/src/huggingface_hub/file_download.py000066400000000000000000002311651500667546600240040ustar00rootroot00000000000000import copy import errno import inspect import os import re import shutil import stat import time import uuid import warnings from dataclasses import dataclass from pathlib import Path from typing import Any, BinaryIO, Dict, Literal, NoReturn, Optional, Tuple, Union from urllib.parse import quote, urlparse import requests from . 
import ( __version__, # noqa: F401 # for backward compatibility constants, ) from ._local_folder import get_local_download_paths, read_download_metadata, write_download_metadata from .constants import ( HUGGINGFACE_CO_URL_TEMPLATE, # noqa: F401 # for backward compatibility HUGGINGFACE_HUB_CACHE, # noqa: F401 # for backward compatibility ) from .errors import ( EntryNotFoundError, FileMetadataError, GatedRepoError, HfHubHTTPError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError, ) from .utils import ( OfflineModeIsEnabled, SoftTemporaryDirectory, WeakFileLock, XetFileData, build_hf_headers, get_fastai_version, # noqa: F401 # for backward compatibility get_fastcore_version, # noqa: F401 # for backward compatibility get_graphviz_version, # noqa: F401 # for backward compatibility get_jinja_version, # noqa: F401 # for backward compatibility get_pydot_version, # noqa: F401 # for backward compatibility get_tf_version, # noqa: F401 # for backward compatibility get_torch_version, # noqa: F401 # for backward compatibility hf_raise_for_status, is_fastai_available, # noqa: F401 # for backward compatibility is_fastcore_available, # noqa: F401 # for backward compatibility is_graphviz_available, # noqa: F401 # for backward compatibility is_jinja_available, # noqa: F401 # for backward compatibility is_pydot_available, # noqa: F401 # for backward compatibility is_tf_available, # noqa: F401 # for backward compatibility is_torch_available, # noqa: F401 # for backward compatibility logging, parse_xet_file_data_from_response, refresh_xet_connection_info, reset_sessions, tqdm, validate_hf_hub_args, ) from .utils._http import _adjust_range_header, http_backoff from .utils._runtime import _PY_VERSION, is_xet_available # noqa: F401 # for backward compatibility from .utils._typing import HTTP_METHOD_T from .utils.sha import sha_fileobj from .utils.tqdm import _get_progress_bar_context logger = logging.get_logger(__name__) # Return value when trying to load a file from cache but the file does not exist in the distant repo. _CACHED_NO_EXIST = object() _CACHED_NO_EXIST_T = Any # Regex to get filename from a "Content-Disposition" header for CDN-served files HEADER_FILENAME_PATTERN = re.compile(r'filename="(?P.*?)";') # Regex to check if the revision IS directly a commit_hash REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$") # Regex to check if the file etag IS a valid sha256 REGEX_SHA256 = re.compile(r"^[0-9a-f]{64}$") _are_symlinks_supported_in_dir: Dict[str, bool] = {} def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool: """Return whether the symlinks are supported on the machine. Since symlinks support can change depending on the mounted disk, we need to check on the precise cache folder. By default, the default HF cache directory is checked. Args: cache_dir (`str`, `Path`, *optional*): Path to the folder where cached files are stored. Returns: [bool] Whether symlinks are supported in the directory. 
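    Example (illustrative; the actual result depends on the platform and on the filesystem backing the cache):

    ```python
    >>> from huggingface_hub.file_download import are_symlinks_supported
    >>> are_symlinks_supported()  # checks the default HF cache directory
    True
    ```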
""" # Defaults to HF cache if cache_dir is None: cache_dir = constants.HF_HUB_CACHE cache_dir = str(Path(cache_dir).expanduser().resolve()) # make it unique # Check symlink compatibility only once (per cache directory) at first time use if cache_dir not in _are_symlinks_supported_in_dir: _are_symlinks_supported_in_dir[cache_dir] = True os.makedirs(cache_dir, exist_ok=True) with SoftTemporaryDirectory(dir=cache_dir) as tmpdir: src_path = Path(tmpdir) / "dummy_file_src" src_path.touch() dst_path = Path(tmpdir) / "dummy_file_dst" # Relative source path as in `_create_symlink`` relative_src = os.path.relpath(src_path, start=os.path.dirname(dst_path)) try: os.symlink(relative_src, dst_path) except OSError: # Likely running on Windows _are_symlinks_supported_in_dir[cache_dir] = False if not constants.HF_HUB_DISABLE_SYMLINKS_WARNING: message = ( "`huggingface_hub` cache-system uses symlinks by default to" " efficiently store duplicated files but your machine does not" f" support them in {cache_dir}. Caching files will still work" " but in a degraded version that might require more space on" " your disk. This warning can be disabled by setting the" " `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For" " more details, see" " https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations." ) if os.name == "nt": message += ( "\nTo support symlinks on Windows, you either need to" " activate Developer Mode or to run Python as an" " administrator. In order to activate developer mode," " see this article:" " https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development" ) warnings.warn(message) return _are_symlinks_supported_in_dir[cache_dir] @dataclass(frozen=True) class HfFileMetadata: """Data structure containing information about a file versioned on the Hub. Returned by [`get_hf_file_metadata`] based on a URL. Args: commit_hash (`str`, *optional*): The commit_hash related to the file. etag (`str`, *optional*): Etag of the file on the server. location (`str`): Location where to download the file. Can be a Hub url or not (CDN). size (`size`): Size of the file. In case of an LFS file, contains the size of the actual LFS file, not the pointer. xet_file_data (`XetFileData`, *optional*): Xet information for the file. This is only set if the file is stored using Xet storage. """ commit_hash: Optional[str] etag: Optional[str] location: str size: Optional[int] xet_file_data: Optional[XetFileData] @validate_hf_hub_args def hf_hub_url( repo_id: str, filename: str, *, subfolder: Optional[str] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, endpoint: Optional[str] = None, ) -> str: """Construct the URL of a file from the given information. The resolved address can either be a huggingface.co-hosted url, or a link to Cloudfront (a Content Delivery Network, or CDN) for large files which are more than a few MBs. Args: repo_id (`str`): A namespace (user or an organization) name and a repo name separated by a `/`. filename (`str`): The name of the file in the repo. subfolder (`str`, *optional*): An optional value corresponding to a folder inside the repo. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if downloading from a model. Default is `None`. revision (`str`, *optional*): An optional Git revision id which can be a branch name, a tag, or a commit hash. Example: ```python >>> from huggingface_hub import hf_hub_url >>> hf_hub_url( ... 
repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin" ... ) 'https://huggingface.co/julien-c/EsperBERTo-small/resolve/main/pytorch_model.bin' ``` Notes: Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our bandwidth costs). Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here because we implement a git-based versioning system on huggingface.co, which means that we store the files on S3/Cloudfront in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache can't ever be stale. In terms of client-side caching from this library, we base our caching on the objects' entity tag (`ETag`), which is an identifier of a specific version of a resource [1]_. An object's ETag is: its git-sha1 if stored in git, or its sha256 if stored in git-lfs. References: - [1] https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag """ if subfolder == "": subfolder = None if subfolder is not None: filename = f"{subfolder}/{filename}" if repo_type not in constants.REPO_TYPES: raise ValueError("Invalid repo type") if repo_type in constants.REPO_TYPES_URL_PREFIXES: repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id if revision is None: revision = constants.DEFAULT_REVISION url = HUGGINGFACE_CO_URL_TEMPLATE.format( repo_id=repo_id, revision=quote(revision, safe=""), filename=quote(filename) ) # Update endpoint if provided if endpoint is not None and url.startswith(constants.ENDPOINT): url = endpoint + url[len(constants.ENDPOINT) :] return url def _request_wrapper( method: HTTP_METHOD_T, url: str, *, follow_relative_redirects: bool = False, **params ) -> requests.Response: """Wrapper around requests methods to follow relative redirects if `follow_relative_redirects=True` even when `allow_redirection=False`. A backoff mechanism retries the HTTP call on 429, 503 and 504 errors. Args: method (`str`): HTTP method, such as 'GET' or 'HEAD'. url (`str`): The URL of the resource to fetch. follow_relative_redirects (`bool`, *optional*, defaults to `False`) If True, relative redirection (redirection to the same site) will be resolved even when `allow_redirection` kwarg is set to False. Useful when we want to follow a redirection to a renamed repository without following redirection to a CDN. **params (`dict`, *optional*): Params to pass to `requests.request`. """ # Recursively follow relative redirects if follow_relative_redirects: response = _request_wrapper( method=method, url=url, follow_relative_redirects=False, **params, ) # If redirection, we redirect only relative paths. # This is useful in case of a renamed repository. if 300 <= response.status_code <= 399: parsed_target = urlparse(response.headers["Location"]) if parsed_target.netloc == "": # This means it is a relative 'location' headers, as allowed by RFC 7231. # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') # We want to follow this relative redirect ! # # Highly inspired by `resolve_redirects` from requests library. # See https://github.com/psf/requests/blob/main/requests/sessions.py#L159 next_url = urlparse(url)._replace(path=parsed_target.path).geturl() return _request_wrapper(method=method, url=next_url, follow_relative_redirects=True, **params) return response # Perform request and return if status_code is not in the retry list. 
response = http_backoff(method=method, url=url, **params, retry_on_exceptions=(), retry_on_status_codes=(429,)) hf_raise_for_status(response) return response def _get_file_length_from_http_response(response: requests.Response) -> Optional[int]: """ Get the length of the file from the HTTP response headers. This function extracts the file size from the HTTP response headers, either from the `Content-Range` or `Content-Length` header, if available (in that order). The HTTP response object containing the headers. `int` or `None`: The length of the file in bytes if the information is available, otherwise `None`. Args: response (`requests.Response`): The HTTP response object. Returns: `int` or `None`: The length of the file in bytes, or None if not available. """ content_range = response.headers.get("Content-Range") if content_range is not None: return int(content_range.rsplit("/")[-1]) content_length = response.headers.get("Content-Length") if content_length is not None: return int(content_length) return None def http_get( url: str, temp_file: BinaryIO, *, proxies: Optional[Dict] = None, resume_size: int = 0, headers: Optional[Dict[str, Any]] = None, expected_size: Optional[int] = None, displayed_filename: Optional[str] = None, _nb_retries: int = 5, _tqdm_bar: Optional[tqdm] = None, ) -> None: """ Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub. If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely a transient error (network outage?). We log a warning message and try to resume the download a few times before giving up. The method gives up after 5 attempts if no new data has being received from the server. Args: url (`str`): The URL of the file to download. temp_file (`BinaryIO`): The file-like object where to save the file. proxies (`dict`, *optional*): Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. resume_size (`int`, *optional*): The number of bytes already downloaded. If set to 0 (default), the whole file is download. If set to a positive number, the download will resume at the given position. headers (`dict`, *optional*): Dictionary of HTTP Headers to send with the request. expected_size (`int`, *optional*): The expected size of the file to download. If set, the download will raise an error if the size of the received content is different from the expected one. displayed_filename (`str`, *optional*): The filename of the file that is being downloaded. Value is used only to display a nice progress bar. If not set, the filename is guessed from the URL or the `Content-Disposition` header. """ if expected_size is not None and resume_size == expected_size: # If the file is already fully downloaded, we don't need to download it again. 
return has_custom_range_header = headers is not None and any(h.lower() == "range" for h in headers) hf_transfer = None if constants.HF_HUB_ENABLE_HF_TRANSFER: if resume_size != 0: warnings.warn("'hf_transfer' does not support `resume_size`: falling back to regular download method") elif proxies is not None: warnings.warn("'hf_transfer' does not support `proxies`: falling back to regular download method") elif has_custom_range_header: warnings.warn("'hf_transfer' ignores custom 'Range' headers; falling back to regular download method") else: try: import hf_transfer # type: ignore[no-redef] except ImportError: raise ValueError( "Fast download using 'hf_transfer' is enabled" " (HF_HUB_ENABLE_HF_TRANSFER=1) but 'hf_transfer' package is not" " available in your environment. Try `pip install hf_transfer`." ) initial_headers = headers headers = copy.deepcopy(headers) or {} if resume_size > 0: headers["Range"] = _adjust_range_header(headers.get("Range"), resume_size) elif expected_size and expected_size > constants.MAX_HTTP_DOWNLOAD_SIZE: # Any files over 50GB will not be available through basic http request. # Setting the range header to 0-0 will force the server to return the file size in the Content-Range header. # Since hf_transfer splits the download into chunks, the process will succeed afterwards. if hf_transfer: headers["Range"] = "bytes=0-0" else: raise ValueError( "The file is too large to be downloaded using the regular download method. Use `hf_transfer` or `hf_xet` instead." " Try `pip install hf_transfer` or `pip install hf_xet`." ) r = _request_wrapper( method="GET", url=url, stream=True, proxies=proxies, headers=headers, timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT ) hf_raise_for_status(r) content_length = _get_file_length_from_http_response(r) # NOTE: 'total' is the total number of bytes to download, not the number of bytes in the file. # If the file is compressed, the number of bytes in the saved file will be higher than 'total'. total = resume_size + int(content_length) if content_length is not None else None if displayed_filename is None: displayed_filename = url content_disposition = r.headers.get("Content-Disposition") if content_disposition is not None: match = HEADER_FILENAME_PATTERN.search(content_disposition) if match is not None: # Means file is on CDN displayed_filename = match.groupdict()["filename"] # Truncate filename if too long to display if len(displayed_filename) > 40: displayed_filename = f"(…){displayed_filename[-40:]}" consistency_error_message = ( f"Consistency check failed: file should be of size {expected_size} but has size" f" {{actual_size}} ({displayed_filename}).\nThis is usually due to network issues while downloading the file." " Please retry with `force_download=True`." ) progress_cm = _get_progress_bar_context( desc=displayed_filename, log_level=logger.getEffectiveLevel(), total=total, initial=resume_size, name="huggingface_hub.http_get", _tqdm_bar=_tqdm_bar, ) with progress_cm as progress: if hf_transfer and total is not None and total > 5 * constants.DOWNLOAD_CHUNK_SIZE: supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters if not supports_callback: warnings.warn( "You are using an outdated version of `hf_transfer`. " "Consider upgrading to latest version to enable progress bars " "using `pip install -U hf_transfer`." 
) try: hf_transfer.download( url=url, filename=temp_file.name, max_files=constants.HF_TRANSFER_CONCURRENCY, chunk_size=constants.DOWNLOAD_CHUNK_SIZE, headers=initial_headers, parallel_failures=3, max_retries=5, **({"callback": progress.update} if supports_callback else {}), ) except Exception as e: raise RuntimeError( "An error occurred while downloading using `hf_transfer`. Consider" " disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling." ) from e if not supports_callback: progress.update(total) if expected_size is not None and expected_size != os.path.getsize(temp_file.name): raise EnvironmentError( consistency_error_message.format( actual_size=os.path.getsize(temp_file.name), ) ) return new_resume_size = resume_size try: for chunk in r.iter_content(chunk_size=constants.DOWNLOAD_CHUNK_SIZE): if chunk: # filter out keep-alive new chunks progress.update(len(chunk)) temp_file.write(chunk) new_resume_size += len(chunk) # Some data has been downloaded from the server so we reset the number of retries. _nb_retries = 5 except (requests.ConnectionError, requests.ReadTimeout) as e: # If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely # a transient error (network outage?). We log a warning message and try to resume the download a few times # before giving up. Tre retry mechanism is basic but should be enough in most cases. if _nb_retries <= 0: logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e)) raise logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e)) time.sleep(1) reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects return http_get( url=url, temp_file=temp_file, proxies=proxies, resume_size=new_resume_size, headers=initial_headers, expected_size=expected_size, _nb_retries=_nb_retries - 1, _tqdm_bar=_tqdm_bar, ) if expected_size is not None and expected_size != temp_file.tell(): raise EnvironmentError( consistency_error_message.format( actual_size=temp_file.tell(), ) ) def xet_get( *, incomplete_path: Path, xet_file_data: XetFileData, headers: Dict[str, str], expected_size: Optional[int] = None, displayed_filename: Optional[str] = None, _tqdm_bar: Optional[tqdm] = None, ) -> None: """ Download a file using Xet storage service. Args: incomplete_path (`Path`): The path to the file to download. xet_file_data (`XetFileData`): The file metadata needed to make the request to the xet storage service. headers (`Dict[str, str]`): The headers to send to the xet storage service. expected_size (`int`, *optional*): The expected size of the file to download. If set, the download will raise an error if the size of the received content is different from the expected one. displayed_filename (`str`, *optional*): The filename of the file that is being downloaded. Value is used only to display a nice progress bar. If not set, the filename is guessed from the URL or the `Content-Disposition` header. **How it works:** The file download system uses Xet storage, which is a content-addressable storage system that breaks files into chunks for efficient storage and transfer. 
`hf_xet.download_files` manages downloading files by: - Taking a list of files to download (each with its unique content hash) - Connecting to a storage server (CAS server) that knows how files are chunked - Using authentication to ensure secure access - Providing progress updates during download Authentication works by regularly refreshing access tokens through `refresh_xet_connection_info` to maintain a valid connection to the storage server. The download process works like this: 1. Create a local cache folder at `~/.cache/huggingface/xet/chunk-cache` to store reusable file chunks 2. Download files in parallel: 2.1. Prepare to write the file to disk 2.2. Ask the server "how is this file split into chunks?" using the file's unique hash The server responds with: - Which chunks make up the complete file - Where each chunk can be downloaded from 2.3. For each needed chunk: - Checks if we already have it in our local cache - If not, download it from cloud storage (S3) - Save it to cache for future use - Assemble the chunks in order to recreate the original file """ try: from hf_xet import PyXetDownloadInfo, download_files # type: ignore[no-redef] except ImportError: raise ValueError( "To use optimized download using Xet storage, you need to install the hf_xet package. " 'Try `pip install "huggingface_hub[hf_xet]"` or `pip install hf_xet`.' ) connection_info = refresh_xet_connection_info(file_data=xet_file_data, headers=headers) def token_refresher() -> Tuple[str, int]: connection_info = refresh_xet_connection_info(file_data=xet_file_data, headers=headers) if connection_info is None: raise ValueError("Failed to refresh token using xet metadata.") return connection_info.access_token, connection_info.expiration_unix_epoch xet_download_info = [ PyXetDownloadInfo( destination_path=str(incomplete_path.absolute()), hash=xet_file_data.file_hash, file_size=expected_size ) ] if not displayed_filename: displayed_filename = incomplete_path.name # Truncate filename if too long to display if len(displayed_filename) > 40: displayed_filename = f"{displayed_filename[:40]}(…)" progress_cm = _get_progress_bar_context( desc=displayed_filename, log_level=logger.getEffectiveLevel(), total=expected_size, initial=0, name="huggingface_hub.xet_get", _tqdm_bar=_tqdm_bar, ) with progress_cm as progress: def progress_updater(progress_bytes: float): progress.update(progress_bytes) download_files( xet_download_info, endpoint=connection_info.endpoint, token_info=(connection_info.access_token, connection_info.expiration_unix_epoch), token_refresher=token_refresher, progress_updater=[progress_updater], ) def _normalize_etag(etag: Optional[str]) -> Optional[str]: """Normalize ETag HTTP header, so it can be used to create nice filepaths. The HTTP spec allows two forms of ETag: ETag: W/"" ETag: "" For now, we only expect the second form from the server, but we want to be future-proof so we support both. For more context, see `TestNormalizeEtag` tests and https://github.com/huggingface/huggingface_hub/pull/1428. Args: etag (`str`, *optional*): HTTP header Returns: `str` or `None`: string that can be used as a nice directory name. Returns `None` if input is None. 
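    Example (illustrative etag values):

    ```python
    >>> _normalize_etag('"a16a55fda99d2f2e7b69cce5cf93ff4ad3049930"')
    'a16a55fda99d2f2e7b69cce5cf93ff4ad3049930'
    >>> _normalize_etag('W/"a16a55fda99d2f2e7b69cce5cf93ff4ad3049930"')
    'a16a55fda99d2f2e7b69cce5cf93ff4ad3049930'
    >>> _normalize_etag(None) is None
    True
    ```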
""" if etag is None: return None return etag.lstrip("W/").strip('"') def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None: """Alias method used in `transformers` conversion script.""" return _create_symlink(src=src, dst=dst, new_blob=new_blob) def _create_symlink(src: str, dst: str, new_blob: bool = False) -> None: """Create a symbolic link named dst pointing to src. By default, it will try to create a symlink using a relative path. Relative paths have 2 advantages: - If the cache_folder is moved (example: back-up on a shared drive), relative paths within the cache folder will not break. - Relative paths seems to be better handled on Windows. Issue was reported 3 times in less than a week when changing from relative to absolute paths. See https://github.com/huggingface/huggingface_hub/issues/1398, https://github.com/huggingface/diffusers/issues/2729 and https://github.com/huggingface/transformers/pull/22228. NOTE: The issue with absolute paths doesn't happen on admin mode. When creating a symlink from the cache to a local folder, it is possible that a relative path cannot be created. This happens when paths are not on the same volume. In that case, we use absolute paths. The result layout looks something like └── [ 128] snapshots ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f │ ├── [ 52] README.md -> ../../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 │ └── [ 76] pytorch_model.bin -> ../../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd If symlinks cannot be created on this platform (most likely to be Windows), the workaround is to avoid symlinks by having the actual file in `dst`. If it is a new file (`new_blob=True`), we move it to `dst`. If it is not a new file (`new_blob=False`), we don't know if the blob file is already referenced elsewhere. To avoid breaking existing cache, the file is duplicated on the disk. In case symlinks are not supported, a warning message is displayed to the user once when loading `huggingface_hub`. The warning message can be disabled with the `DISABLE_SYMLINKS_WARNING` environment variable. """ try: os.remove(dst) except OSError: pass abs_src = os.path.abspath(os.path.expanduser(src)) abs_dst = os.path.abspath(os.path.expanduser(dst)) abs_dst_folder = os.path.dirname(abs_dst) # Use relative_dst in priority try: relative_src = os.path.relpath(abs_src, abs_dst_folder) except ValueError: # Raised on Windows if src and dst are not on the same volume. This is the case when creating a symlink to a # local_dir instead of within the cache directory. # See https://docs.python.org/3/library/os.path.html#os.path.relpath relative_src = None try: commonpath = os.path.commonpath([abs_src, abs_dst]) _support_symlinks = are_symlinks_supported(commonpath) except ValueError: # Raised if src and dst are not on the same volume. Symlinks will still work on Linux/Macos. # See https://docs.python.org/3/library/os.path.html#os.path.commonpath _support_symlinks = os.name != "nt" except PermissionError: # Permission error means src and dst are not in the same volume (e.g. destination path has been provided # by the user via `local_dir`. Let's test symlink support there) _support_symlinks = are_symlinks_supported(abs_dst_folder) except OSError as e: # OS error (errno=30) means that the commonpath is readonly on Linux/MacOS. if e.errno == errno.EROFS: _support_symlinks = are_symlinks_supported(abs_dst_folder) else: raise # Symlinks are supported => let's create a symlink. 
if _support_symlinks: src_rel_or_abs = relative_src or abs_src logger.debug(f"Creating pointer from {src_rel_or_abs} to {abs_dst}") try: os.symlink(src_rel_or_abs, abs_dst) return except FileExistsError: if os.path.islink(abs_dst) and os.path.realpath(abs_dst) == os.path.realpath(abs_src): # `abs_dst` already exists and is a symlink to the `abs_src` blob. It is most likely that the file has # been cached twice concurrently (exactly between `os.remove` and `os.symlink`). Do nothing. return else: # Very unlikely to happen. Means a file `dst` has been created exactly between `os.remove` and # `os.symlink` and is not a symlink to the `abs_src` blob file. Raise exception. raise except PermissionError: # Permission error means src and dst are not in the same volume (e.g. download to local dir) and symlink # is supported on both volumes but not between them. Let's just make a hard copy in that case. pass # Symlinks are not supported => let's move or copy the file. if new_blob: logger.info(f"Symlink not supported. Moving file from {abs_src} to {abs_dst}") shutil.move(abs_src, abs_dst, copy_function=_copy_no_matter_what) else: logger.info(f"Symlink not supported. Copying file from {abs_src} to {abs_dst}") shutil.copyfile(abs_src, abs_dst) def _cache_commit_hash_for_specific_revision(storage_folder: str, revision: str, commit_hash: str) -> None: """Cache reference between a revision (tag, branch or truncated commit hash) and the corresponding commit hash. Does nothing if `revision` is already a proper `commit_hash` or reference is already cached. """ if revision != commit_hash: ref_path = Path(storage_folder) / "refs" / revision ref_path.parent.mkdir(parents=True, exist_ok=True) if not ref_path.exists() or commit_hash != ref_path.read_text(): # Update ref only if has been updated. Could cause useless error in case # repo is already cached and user doesn't have write access to cache folder. # See https://github.com/huggingface/huggingface_hub/issues/1216. ref_path.write_text(commit_hash) @validate_hf_hub_args def repo_folder_name(*, repo_id: str, repo_type: str) -> str: """Return a serialized version of a hf.co repo name and type, safe for disk storage as a single non-nested folder. Example: models--julien-c--EsperBERTo-small """ # remove all `/` occurrences to correctly convert repo to directory name parts = [f"{repo_type}s", *repo_id.split("/")] return constants.REPO_ID_SEPARATOR.join(parts) def _check_disk_space(expected_size: int, target_dir: Union[str, Path]) -> None: """Check disk usage and log a warning if there is not enough disk space to download the file. Args: expected_size (`int`): The expected size of the file in bytes. target_dir (`str`): The directory where the file will be stored after downloading. """ target_dir = Path(target_dir) # format as `Path` for path in [target_dir] + list(target_dir.parents): # first check target_dir, then each parents one by one try: target_dir_free = shutil.disk_usage(path).free if target_dir_free < expected_size: warnings.warn( "Not enough free disk space to download the file. " f"The expected file size is: {expected_size / 1e6:.2f} MB. " f"The target location {target_dir} only has {target_dir_free / 1e6:.2f} MB free disk space." 
) return except OSError: # raise on anything: file does not exist or space disk cannot be checked pass @validate_hf_hub_args def hf_hub_download( repo_id: str, filename: str, *, subfolder: Optional[str] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, library_name: Optional[str] = None, library_version: Optional[str] = None, cache_dir: Union[str, Path, None] = None, local_dir: Union[str, Path, None] = None, user_agent: Union[Dict, str, None] = None, force_download: bool = False, proxies: Optional[Dict] = None, etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, token: Union[bool, str, None] = None, local_files_only: bool = False, headers: Optional[Dict[str, str]] = None, endpoint: Optional[str] = None, resume_download: Optional[bool] = None, force_filename: Optional[str] = None, local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", ) -> str: """Download a given file if it's not already present in the local cache. The new cache file layout looks like this: - The cache directory contains one subfolder per repo_id (namespaced by repo type) - inside each repo folder: - refs is a list of the latest known revision => commit_hash pairs - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on whether they're LFS files or not) - snapshots contains one subfolder per commit, each "commit" contains the subset of the files that have been resolved at that particular commit. Each filename is a symlink to the blob at that particular commit. ``` [ 96] . └── [ 160] models--julien-c--EsperBERTo-small ├── [ 160] blobs │ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd │ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e │ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812 ├── [ 96] refs │ └── [ 40] main └── [ 128] snapshots ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd └── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48 ├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd ``` If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` to store some metadata related to the downloaded files. While this mechanism is not as robust as the main cache-system, it's optimized for regularly pulling the latest version of a repository. Args: repo_id (`str`): A user or an organization name and a repo name separated by a `/`. filename (`str`): The name of the file in the repo. subfolder (`str`, *optional*): An optional value corresponding to a folder inside the model repo. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if downloading from a model. Default is `None`. revision (`str`, *optional*): An optional Git revision id which can be a branch name, a tag, or a commit hash. library_name (`str`, *optional*): The name of the library to which the object corresponds. library_version (`str`, *optional*): The version of the library. cache_dir (`str`, `Path`, *optional*): Path to the folder where cached files are stored. 
local_dir (`str` or `Path`, *optional*): If provided, the downloaded file will be placed under this directory. user_agent (`dict`, `str`, *optional*): The user-agent info in the form of a dictionary or a string. force_download (`bool`, *optional*, defaults to `False`): Whether the file should be downloaded even if it already exists in the local cache. proxies (`dict`, *optional*): Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. etag_timeout (`float`, *optional*, defaults to `10`): When fetching ETag, how many seconds to wait for the server to send data before giving up which is passed to `requests.request`. token (`str`, `bool`, *optional*): A token to be used for the download. - If `True`, the token is read from the HuggingFace config folder. - If a string, it's used as the authentication token. local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the local cached file if it exists. headers (`dict`, *optional*): Additional headers to be sent with the request. Returns: `str`: Local path of file or if networking is off, last version of file cached on disk. Raises: [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. [`~utils.EntryNotFoundError`] If the file to download cannot be found. [`~utils.LocalEntryNotFoundError`] If network is disabled or unavailable and file is not found in cache. [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) If `token=True` but the token cannot be found. [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) If ETag cannot be determined. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If some parameter value is invalid. """ if constants.HF_HUB_ETAG_TIMEOUT != constants.DEFAULT_ETAG_TIMEOUT: # Respect environment variable above user value etag_timeout = constants.HF_HUB_ETAG_TIMEOUT if force_filename is not None: warnings.warn( "The `force_filename` parameter is deprecated as a new caching system, " "which keeps the filenames as they are on the Hub, is now in place.", FutureWarning, ) if resume_download is not None: warnings.warn( "`resume_download` is deprecated and will be removed in version 1.0.0. " "Downloads always resume when possible. " "If you want to force a new download, use `force_download=True`.", FutureWarning, ) if cache_dir is None: cache_dir = constants.HF_HUB_CACHE if revision is None: revision = constants.DEFAULT_REVISION if isinstance(cache_dir, Path): cache_dir = str(cache_dir) if isinstance(local_dir, Path): local_dir = str(local_dir) if subfolder == "": subfolder = None if subfolder is not None: # This is used to create a URL, and not a local path, hence the forward slash. filename = f"{subfolder}/{filename}" if repo_type is None: repo_type = "model" if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}") hf_headers = build_hf_headers( token=token, library_name=library_name, library_version=library_version, user_agent=user_agent, headers=headers, ) if local_dir is not None: if local_dir_use_symlinks != "auto": warnings.warn( "`local_dir_use_symlinks` parameter is deprecated and will be ignored. 
" "The process to download files to a local folder has been updated and do " "not rely on symlinks anymore. You only need to pass a destination folder " "as`local_dir`.\n" "For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder." ) return _hf_hub_download_to_local_dir( # Destination local_dir=local_dir, # File info repo_id=repo_id, repo_type=repo_type, filename=filename, revision=revision, # HTTP info endpoint=endpoint, etag_timeout=etag_timeout, headers=hf_headers, proxies=proxies, token=token, # Additional options cache_dir=cache_dir, force_download=force_download, local_files_only=local_files_only, ) else: return _hf_hub_download_to_cache_dir( # Destination cache_dir=cache_dir, # File info repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision, # HTTP info endpoint=endpoint, etag_timeout=etag_timeout, headers=hf_headers, proxies=proxies, token=token, # Additional options local_files_only=local_files_only, force_download=force_download, ) def _hf_hub_download_to_cache_dir( *, # Destination cache_dir: str, # File info repo_id: str, filename: str, repo_type: str, revision: str, # HTTP info endpoint: Optional[str], etag_timeout: float, headers: Dict[str, str], proxies: Optional[Dict], token: Optional[Union[bool, str]], # Additional options local_files_only: bool, force_download: bool, ) -> str: """Download a given file to a cache folder, if not already present. Method should not be called directly. Please use `hf_hub_download` instead. """ locks_dir = os.path.join(cache_dir, ".locks") storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type)) # cross platform transcription of filename, to be used as a local file path. relative_filename = os.path.join(*filename.split("/")) if os.name == "nt": if relative_filename.startswith("..\\") or "\\..\\" in relative_filename: raise ValueError( f"Invalid filename: cannot handle filename '{relative_filename}' on Windows. Please ask the repository" " owner to rename this file." ) # if user provides a commit_hash and they already have the file on disk, shortcut everything. if REGEX_COMMIT_HASH.match(revision): pointer_path = _get_pointer_path(storage_folder, revision, relative_filename) if os.path.exists(pointer_path) and not force_download: return pointer_path # Try to get metadata (etag, commit_hash, url, size) from the server. # If we can't, a HEAD request error is returned. (url_to_download, etag, commit_hash, expected_size, xet_file_data, head_call_error) = _get_metadata_or_catch_error( repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision, endpoint=endpoint, proxies=proxies, etag_timeout=etag_timeout, headers=headers, token=token, local_files_only=local_files_only, storage_folder=storage_folder, relative_filename=relative_filename, ) # etag can be None for several reasons: # 1. we passed local_files_only. # 2. we don't have a connection # 3. Hub is down (HTTP 500, 503, 504) # 4. repo is not found -for example private or gated- and invalid/missing token sent # 5. Hub is blocked by a firewall or proxy is not set correctly. # => Try to get the last downloaded one from the specified revision. # # If the specified revision is a commit hash, look inside "snapshots". # If the specified revision is a branch or tag, look inside "refs". 
if head_call_error is not None: # Couldn't make a HEAD call => let's try to find a local file if not force_download: commit_hash = None if REGEX_COMMIT_HASH.match(revision): commit_hash = revision else: ref_path = os.path.join(storage_folder, "refs", revision) if os.path.isfile(ref_path): with open(ref_path) as f: commit_hash = f.read() # Return pointer file if exists if commit_hash is not None: pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename) if os.path.exists(pointer_path) and not force_download: return pointer_path # Otherwise, raise appropriate error _raise_on_head_call_error(head_call_error, force_download, local_files_only) # From now on, etag, commit_hash, url and size are not None. assert etag is not None, "etag must have been retrieved from server" assert commit_hash is not None, "commit_hash must have been retrieved from server" assert url_to_download is not None, "file location must have been retrieved from server" assert expected_size is not None, "expected_size must have been retrieved from server" blob_path = os.path.join(storage_folder, "blobs", etag) pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename) os.makedirs(os.path.dirname(blob_path), exist_ok=True) os.makedirs(os.path.dirname(pointer_path), exist_ok=True) # if passed revision is not identical to commit_hash # then revision has to be a branch name or tag name. # In that case store a ref. _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash) # If file already exists, return it (except if force_download=True) if not force_download: if os.path.exists(pointer_path): return pointer_path if os.path.exists(blob_path): # we have the blob already, but not the pointer _create_symlink(blob_path, pointer_path, new_blob=False) return pointer_path # Prevent parallel downloads of the same file with a lock. # etag could be duplicated across repos, lock_path = os.path.join(locks_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type), f"{etag}.lock") # Some Windows versions do not allow for paths longer than 255 characters. # In this case, we must specify it as an extended path by using the "\\?\" prefix. if os.name == "nt" and len(os.path.abspath(lock_path)) > 255: lock_path = "\\\\?\\" + os.path.abspath(lock_path) if os.name == "nt" and len(os.path.abspath(blob_path)) > 255: blob_path = "\\\\?\\" + os.path.abspath(blob_path) # Local file doesn't exist or etag isn't a match => retrieve file from remote (or cache) Path(lock_path).parent.mkdir(parents=True, exist_ok=True) with WeakFileLock(lock_path): _download_to_tmp_and_move( incomplete_path=Path(blob_path + ".incomplete"), destination_path=Path(blob_path), url_to_download=url_to_download, proxies=proxies, headers=headers, expected_size=expected_size, filename=filename, force_download=force_download, etag=etag, xet_file_data=xet_file_data, ) if not os.path.exists(pointer_path): _create_symlink(blob_path, pointer_path, new_blob=True) return pointer_path def _hf_hub_download_to_local_dir( *, # Destination local_dir: Union[str, Path], # File info repo_id: str, repo_type: str, filename: str, revision: str, # HTTP info endpoint: Optional[str], etag_timeout: float, headers: Dict[str, str], proxies: Optional[Dict], token: Union[bool, str, None], # Additional options cache_dir: str, force_download: bool, local_files_only: bool, ) -> str: """Download a given file to a local folder, if not already present. Method should not be called directly. Please use `hf_hub_download` instead. 
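For reference, a minimal sketch (illustrative repo and filename) of the public entry point that dispatches to this helper when `local_dir` is set:

```python
from huggingface_hub import hf_hub_download

# The repo's file structure is replicated under `local_dir`; download metadata is
# stored in a `.cache/huggingface/` folder at the root of `local_dir`.
path = hf_hub_download(repo_id="gpt2", filename="config.json", local_dir="./my-model")
```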
""" # Some Windows versions do not allow for paths longer than 255 characters. # In this case, we must specify it as an extended path by using the "\\?\" prefix. if os.name == "nt" and len(os.path.abspath(local_dir)) > 255: local_dir = "\\\\?\\" + os.path.abspath(local_dir) local_dir = Path(local_dir) paths = get_local_download_paths(local_dir=local_dir, filename=filename) local_metadata = read_download_metadata(local_dir=local_dir, filename=filename) # Local file exists + metadata exists + commit_hash matches => return file if ( not force_download and REGEX_COMMIT_HASH.match(revision) and paths.file_path.is_file() and local_metadata is not None and local_metadata.commit_hash == revision ): return str(paths.file_path) # Local file doesn't exist or commit_hash doesn't match => we need the etag (url_to_download, etag, commit_hash, expected_size, xet_file_data, head_call_error) = _get_metadata_or_catch_error( repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision, endpoint=endpoint, proxies=proxies, etag_timeout=etag_timeout, headers=headers, token=token, local_files_only=local_files_only, ) if head_call_error is not None: # No HEAD call but local file exists => default to local file if not force_download and paths.file_path.is_file(): logger.warning( f"Couldn't access the Hub to check for update but local file already exists. Defaulting to existing file. (error: {head_call_error})" ) return str(paths.file_path) # Otherwise => raise _raise_on_head_call_error(head_call_error, force_download, local_files_only) # From now on, etag, commit_hash, url and size are not None. assert etag is not None, "etag must have been retrieved from server" assert commit_hash is not None, "commit_hash must have been retrieved from server" assert url_to_download is not None, "file location must have been retrieved from server" assert expected_size is not None, "expected_size must have been retrieved from server" # Local file exists => check if it's up-to-date if not force_download and paths.file_path.is_file(): # etag matches => update metadata and return file if local_metadata is not None and local_metadata.etag == etag: write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag) return str(paths.file_path) # metadata is outdated + etag is a sha256 # => means it's an LFS file (large) # => let's compute local hash and compare # => if match, update metadata and return file if local_metadata is None and REGEX_SHA256.match(etag) is not None: with open(paths.file_path, "rb") as f: file_hash = sha_fileobj(f).hex() if file_hash == etag: write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag) return str(paths.file_path) # Local file doesn't exist or etag isn't a match => retrieve file from remote (or cache) # If we are lucky enough, the file is already in the cache => copy it if not force_download: cached_path = try_to_load_from_cache( repo_id=repo_id, filename=filename, cache_dir=cache_dir, revision=commit_hash, repo_type=repo_type, ) if isinstance(cached_path, str): with WeakFileLock(paths.lock_path): paths.file_path.parent.mkdir(parents=True, exist_ok=True) shutil.copyfile(cached_path, paths.file_path) write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag) return str(paths.file_path) # Otherwise, let's download the file! 
with WeakFileLock(paths.lock_path): paths.file_path.unlink(missing_ok=True) # delete outdated file first _download_to_tmp_and_move( incomplete_path=paths.incomplete_path(etag), destination_path=paths.file_path, url_to_download=url_to_download, proxies=proxies, headers=headers, expected_size=expected_size, filename=filename, force_download=force_download, etag=etag, xet_file_data=xet_file_data, ) write_download_metadata(local_dir=local_dir, filename=filename, commit_hash=commit_hash, etag=etag) return str(paths.file_path) @validate_hf_hub_args def try_to_load_from_cache( repo_id: str, filename: str, cache_dir: Union[str, Path, None] = None, revision: Optional[str] = None, repo_type: Optional[str] = None, ) -> Union[str, _CACHED_NO_EXIST_T, None]: """ Explores the cache to return the latest cached file for a given revision if found. This function will not raise any exception if the file is not cached. Args: cache_dir (`str` or `os.PathLike`): The folder where the cached files lie. repo_id (`str`): The ID of the repo on huggingface.co. filename (`str`): The filename to look for inside `repo_id`. revision (`str`, *optional*): The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is provided either. repo_type (`str`, *optional*): The type of the repository. Will default to `"model"`. Returns: `Optional[str]` or `_CACHED_NO_EXIST`: Will return `None` if the file was not cached. Otherwise: - The exact path to the cached file if it's found in the cache - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was cached. Example: ```python from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST filepath = try_to_load_from_cache("bert-base-uncased", "config.json") if isinstance(filepath, str): # file exists and is cached ... elif filepath is _CACHED_NO_EXIST: # non-existence of file is cached ... else: # file is not cached ... ``` """ if revision is None: revision = "main" if repo_type is None: repo_type = "model" if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type: {repo_type}.
Accepted repo types are: {str(constants.REPO_TYPES)}") if cache_dir is None: cache_dir = constants.HF_HUB_CACHE object_id = repo_id.replace("/", "--") repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}") if not os.path.isdir(repo_cache): # No cache for this model return None refs_dir = os.path.join(repo_cache, "refs") snapshots_dir = os.path.join(repo_cache, "snapshots") no_exist_dir = os.path.join(repo_cache, ".no_exist") # Resolve refs (for instance to convert main to the associated commit sha) if os.path.isdir(refs_dir): revision_file = os.path.join(refs_dir, revision) if os.path.isfile(revision_file): with open(revision_file) as f: revision = f.read() # Check if file is cached as "no_exist" if os.path.isfile(os.path.join(no_exist_dir, revision, filename)): return _CACHED_NO_EXIST # Check if revision folder exists if not os.path.exists(snapshots_dir): return None cached_shas = os.listdir(snapshots_dir) if revision not in cached_shas: # No cache for this revision and we won't try to return a random revision return None # Check if file exists in cache cached_file = os.path.join(snapshots_dir, revision, filename) return cached_file if os.path.isfile(cached_file) else None @validate_hf_hub_args def get_hf_file_metadata( url: str, token: Union[bool, str, None] = None, proxies: Optional[Dict] = None, timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT, library_name: Optional[str] = None, library_version: Optional[str] = None, user_agent: Union[Dict, str, None] = None, headers: Optional[Dict[str, str]] = None, ) -> HfFileMetadata: """Fetch metadata of a file versioned on the Hub for a given url. Args: url (`str`): File url, for example returned by [`hf_hub_url`]. token (`str` or `bool`, *optional*): A token to be used for the download. - If `True`, the token is read from the HuggingFace config folder. - If `False` or `None`, no token is provided. - If a string, it's used as the authentication token. proxies (`dict`, *optional*): Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. timeout (`float`, *optional*, defaults to 10): How many seconds to wait for the server to send metadata before giving up. library_name (`str`, *optional*): The name of the library to which the object corresponds. library_version (`str`, *optional*): The version of the library. user_agent (`dict`, `str`, *optional*): The user-agent info in the form of a dictionary or a string. headers (`dict`, *optional*): Additional headers to be sent with the request. Returns: A [`HfFileMetadata`] object containing metadata such as location, etag, size and commit_hash. """ hf_headers = build_hf_headers( token=token, library_name=library_name, library_version=library_version, user_agent=user_agent, headers=headers, ) hf_headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file # Retrieve metadata r = _request_wrapper( method="HEAD", url=url, headers=hf_headers, allow_redirects=False, follow_relative_redirects=True, proxies=proxies, timeout=timeout, ) hf_raise_for_status(r) # Return return HfFileMetadata( commit_hash=r.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT), # We favor a custom header indicating the etag of the linked resource, and # we fallback to the regular etag header. 
etag=_normalize_etag(r.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")), # Either from response headers (if redirected) or defaults to request url # Do not use directly `url`, as `_request_wrapper` might have followed relative # redirects. location=r.headers.get("Location") or r.request.url, # type: ignore size=_int_or_none( r.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_SIZE) or r.headers.get("Content-Length") ), xet_file_data=parse_xet_file_data_from_response(r), # type: ignore ) def _get_metadata_or_catch_error( *, repo_id: str, filename: str, repo_type: str, revision: str, endpoint: Optional[str], proxies: Optional[Dict], etag_timeout: Optional[float], headers: Dict[str, str], # mutated inplace! token: Union[bool, str, None], local_files_only: bool, relative_filename: Optional[str] = None, # only used to store `.no_exists` in cache storage_folder: Optional[str] = None, # only used to store `.no_exists` in cache ) -> Union[ # Either an exception is caught and returned Tuple[None, None, None, None, None, Exception], # Or the metadata is returned as # `(url_to_download, etag, commit_hash, expected_size, xet_file_data, None)` Tuple[str, str, str, int, Optional[XetFileData], None], ]: """Get metadata for a file on the Hub, safely handling network issues. Returns either the etag, commit_hash and expected size of the file, or the error raised while fetching the metadata. NOTE: This function mutates `headers` inplace! It removes the `authorization` header if the file is a LFS blob and the domain of the url is different from the domain of the location (typically an S3 bucket). """ if local_files_only: return ( None, None, None, None, None, OfflineModeIsEnabled( f"Cannot access file since 'local_files_only=True' as been set. (repo_id: {repo_id}, repo_type: {repo_type}, revision: {revision}, filename: {filename})" ), ) url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision, endpoint=endpoint) url_to_download: str = url etag: Optional[str] = None commit_hash: Optional[str] = None expected_size: Optional[int] = None head_error_call: Optional[Exception] = None xet_file_data: Optional[XetFileData] = None # Try to get metadata from the server. # Do not raise yet if the file is not found or not accessible. if not local_files_only: try: try: metadata = get_hf_file_metadata( url=url, proxies=proxies, timeout=etag_timeout, headers=headers, token=token ) except EntryNotFoundError as http_error: if storage_folder is not None and relative_filename is not None: # Cache the non-existence of the file commit_hash = http_error.response.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT) if commit_hash is not None: no_exist_file_path = Path(storage_folder) / ".no_exist" / commit_hash / relative_filename try: no_exist_file_path.parent.mkdir(parents=True, exist_ok=True) no_exist_file_path.touch() except OSError as e: logger.error( f"Could not cache non-existence of file. Will ignore error and continue. Error: {e}" ) _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash) raise # Commit hash must exist commit_hash = metadata.commit_hash if commit_hash is None: raise FileMetadataError( "Distant resource does not seem to be on huggingface.co. It is possible that a configuration issue" " prevents you from downloading resources from https://huggingface.co. Please check your firewall" " and proxy settings and make sure your SSL certificates are updated." ) # Etag must exist # If we don't have any of those, raise an error. 
etag = metadata.etag if etag is None: raise FileMetadataError( "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility." ) # Size must exist expected_size = metadata.size if expected_size is None: raise FileMetadataError("Distant resource does not have a Content-Length.") xet_file_data = metadata.xet_file_data # In case of a redirect, save an extra redirect on the request.get call, # and ensure we download the exact atomic version even if it changed # between the HEAD and the GET (unlikely, but hey). # # If url domain is different => we are downloading from a CDN => url is signed => don't send auth # If url domain is the same => redirect due to repo rename AND downloading a regular file => keep auth if xet_file_data is None and url != metadata.location: url_to_download = metadata.location if urlparse(url).netloc != urlparse(metadata.location).netloc: # Remove authorization header when downloading a LFS blob headers.pop("authorization", None) except (requests.exceptions.SSLError, requests.exceptions.ProxyError): # Actually raise for those subclasses of ConnectionError raise except ( requests.exceptions.ConnectionError, requests.exceptions.Timeout, OfflineModeIsEnabled, ) as error: # Otherwise, our Internet connection is down. # etag is None head_error_call = error except (RevisionNotFoundError, EntryNotFoundError): # The repo was found but the revision or entry doesn't exist on the Hub (never existed or got deleted) raise except requests.HTTPError as error: # Multiple reasons for an http error: # - Repository is private and invalid/missing token sent # - Repository is gated and invalid/missing token sent # - Hub is down (error 500 or 504) # => let's switch to 'local_files_only=True' to check if the files are already cached. # (if it's not the case, the error will be re-raised) head_error_call = error except FileMetadataError as error: # Multiple reasons for a FileMetadataError: # - Wrong network configuration (proxy, firewall, SSL certificates) # - Inconsistency on the Hub # => let's switch to 'local_files_only=True' to check if the files are already cached. # (if it's not the case, the error will be re-raised) head_error_call = error if not (local_files_only or etag is not None or head_error_call is not None): raise RuntimeError("etag is empty due to uncovered problems") return (url_to_download, etag, commit_hash, expected_size, xet_file_data, head_error_call) # type: ignore [return-value] def _raise_on_head_call_error(head_call_error: Exception, force_download: bool, local_files_only: bool) -> NoReturn: """Raise an appropriate error when the HEAD call failed and we cannot locate a local file.""" # No head call => we cannot force download. if force_download: if local_files_only: raise ValueError("Cannot pass 'force_download=True' and 'local_files_only=True' at the same time.") elif isinstance(head_call_error, OfflineModeIsEnabled): raise ValueError("Cannot pass 'force_download=True' when offline mode is enabled.") from head_call_error else: raise ValueError("Force download failed due to the above error.") from head_call_error # No head call + couldn't find an appropriate file on disk => raise an error. if local_files_only: raise LocalEntryNotFoundError( "Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable" " hf.co look-ups and downloads online, set 'local_files_only' to False." 
) elif isinstance(head_call_error, (RepositoryNotFoundError, GatedRepoError)) or ( isinstance(head_call_error, HfHubHTTPError) and head_call_error.response.status_code == 401 ): # Repo not found or gated => let's raise the actual error # Unauthorized => likely a token issue => let's raise the actual error raise head_call_error else: # Otherwise: most likely a connection issue or Hub downtime => let's warn the user raise LocalEntryNotFoundError( "An error happened while trying to locate the file on the Hub and we cannot find the requested files" " in the local cache. Please check your connection and try again or make sure your Internet connection" " is on." ) from head_call_error def _download_to_tmp_and_move( incomplete_path: Path, destination_path: Path, url_to_download: str, proxies: Optional[Dict], headers: Dict[str, str], expected_size: Optional[int], filename: str, force_download: bool, etag: Optional[str], xet_file_data: Optional[XetFileData], ) -> None: """Download content from a URL to a destination path. Internal logic: - return early if file is already downloaded - resume download if possible (from incomplete file) - do not resume download if `force_download=True` or `HF_HUB_ENABLE_HF_TRANSFER=True` - check disk space before downloading - download content to a temporary file - set correct permissions on temporary file - move the temporary file to the destination path Both `incomplete_path` and `destination_path` must be on the same volume to avoid a local copy. """ if destination_path.exists() and not force_download: # Do nothing if already exists (except if force_download=True) return if incomplete_path.exists() and (force_download or (constants.HF_HUB_ENABLE_HF_TRANSFER and not proxies)): # By default, we will try to resume the download if possible. # However, if the user has set `force_download=True` or if `hf_transfer` is enabled, then we should # not resume the download => delete the incomplete file. message = f"Removing incomplete file '{incomplete_path}'" if force_download: message += " (force_download=True)" elif constants.HF_HUB_ENABLE_HF_TRANSFER and not proxies: message += " (hf_transfer=True)" logger.info(message) incomplete_path.unlink(missing_ok=True) with incomplete_path.open("ab") as f: resume_size = f.tell() message = f"Downloading '{filename}' to '{incomplete_path}'" if resume_size > 0 and expected_size is not None: message += f" (resume from {resume_size}/{expected_size})" logger.info(message) if expected_size is not None: # might be None if HTTP header not set correctly # Check disk space in both tmp and destination path _check_disk_space(expected_size, incomplete_path.parent) _check_disk_space(expected_size, destination_path.parent) if xet_file_data is not None and is_xet_available(): logger.info("Xet Storage is enabled for this repo. Downloading file from Xet Storage..") xet_get( incomplete_path=incomplete_path, xet_file_data=xet_file_data, headers=headers, expected_size=expected_size, displayed_filename=filename, ) else: if xet_file_data is not None: logger.warning( "Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. " "Falling back to regular HTTP download. " "For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`" ) http_get( url_to_download, f, proxies=proxies, resume_size=resume_size, headers=headers, expected_size=expected_size, ) logger.info(f"Download complete. 
Moving file to {destination_path}") _chmod_and_move(incomplete_path, destination_path) def _int_or_none(value: Optional[str]) -> Optional[int]: try: return int(value) # type: ignore except (TypeError, ValueError): return None def _chmod_and_move(src: Path, dst: Path) -> None: """Set correct permission before moving a blob from tmp directory to cache dir. Do not take into account the `umask` from the process as there is no convenient way to get it that is thread-safe. See: - About umask: https://docs.python.org/3/library/os.html#os.umask - Thread-safety: https://stackoverflow.com/a/70343066 - About solution: https://github.com/huggingface/huggingface_hub/pull/1220#issuecomment-1326211591 - Fix issue: https://github.com/huggingface/huggingface_hub/issues/1141 - Fix issue: https://github.com/huggingface/huggingface_hub/issues/1215 """ # Get umask by creating a temporary file in the cached repo folder. tmp_file = dst.parent.parent / f"tmp_{uuid.uuid4()}" try: tmp_file.touch() cache_dir_mode = Path(tmp_file).stat().st_mode os.chmod(str(src), stat.S_IMODE(cache_dir_mode)) except OSError as e: logger.warning( f"Could not set the permissions on the file '{src}'. Error: {e}.\nContinuing without setting permissions." ) finally: try: tmp_file.unlink() except OSError: # fails if `tmp_file.touch()` failed => do nothing # See https://github.com/huggingface/huggingface_hub/issues/2359 pass shutil.move(str(src), str(dst), copy_function=_copy_no_matter_what) def _copy_no_matter_what(src: str, dst: str) -> None: """Copy file from src to dst. If `shutil.copy2` fails, fallback to `shutil.copyfile`. """ try: # Copy file with metadata and permission # Can fail e.g. if dst is an S3 mount shutil.copy2(src, dst) except OSError: # Copy only file content shutil.copyfile(src, dst) def _get_pointer_path(storage_folder: str, revision: str, relative_filename: str) -> str: # Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks snapshot_path = os.path.join(storage_folder, "snapshots") pointer_path = os.path.join(snapshot_path, revision, relative_filename) if Path(os.path.abspath(snapshot_path)) not in Path(os.path.abspath(pointer_path)).parents: raise ValueError( "Invalid pointer path: cannot create pointer path in snapshot folder if" f" `storage_folder='{storage_folder}'`, `revision='{revision}'` and" f" `relative_filename='{relative_filename}'`." ) return pointer_path huggingface_hub-0.31.1/src/huggingface_hub/hf_api.py000066400000000000000000015356271500667546600224370ustar00rootroot00000000000000# coding=utf-8 # Copyright 2019-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations import inspect import io import json import re import struct import warnings from collections import defaultdict from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import asdict, dataclass, field from datetime import datetime from functools import wraps from itertools import islice from pathlib import Path from typing import ( Any, BinaryIO, Callable, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, TypeVar, Union, overload, ) from urllib.parse import quote, unquote import requests from requests.exceptions import HTTPError from tqdm.auto import tqdm as base_tqdm from tqdm.contrib.concurrent import thread_map from . import constants from ._commit_api import ( CommitOperation, CommitOperationAdd, CommitOperationCopy, CommitOperationDelete, _fetch_files_to_copy, _fetch_upload_modes, _prepare_commit_payload, _upload_lfs_files, _upload_xet_files, _warn_on_overwriting_operations, ) from ._inference_endpoints import InferenceEndpoint, InferenceEndpointType from ._space_api import SpaceHardware, SpaceRuntime, SpaceStorage, SpaceVariable from ._upload_large_folder import upload_large_folder_internal from .community import ( Discussion, DiscussionComment, DiscussionStatusChange, DiscussionTitleChange, DiscussionWithDetails, deserialize_event, ) from .constants import ( DEFAULT_ETAG_TIMEOUT, # noqa: F401 # kept for backward compatibility DEFAULT_REQUEST_TIMEOUT, # noqa: F401 # kept for backward compatibility DEFAULT_REVISION, # noqa: F401 # kept for backward compatibility DISCUSSION_STATUS, # noqa: F401 # kept for backward compatibility DISCUSSION_TYPES, # noqa: F401 # kept for backward compatibility ENDPOINT, # noqa: F401 # kept for backward compatibility INFERENCE_ENDPOINTS_ENDPOINT, # noqa: F401 # kept for backward compatibility REGEX_COMMIT_OID, # noqa: F401 # kept for backward compatibility REPO_TYPE_MODEL, # noqa: F401 # kept for backward compatibility REPO_TYPES, # noqa: F401 # kept for backward compatibility REPO_TYPES_MAPPING, # noqa: F401 # kept for backward compatibility REPO_TYPES_URL_PREFIXES, # noqa: F401 # kept for backward compatibility SAFETENSORS_INDEX_FILE, # noqa: F401 # kept for backward compatibility SAFETENSORS_MAX_HEADER_LENGTH, # noqa: F401 # kept for backward compatibility SAFETENSORS_SINGLE_FILE, # noqa: F401 # kept for backward compatibility SPACES_SDK_TYPES, # noqa: F401 # kept for backward compatibility WEBHOOK_DOMAIN_T, # noqa: F401 # kept for backward compatibility DiscussionStatusFilter, # noqa: F401 # kept for backward compatibility DiscussionTypeFilter, # noqa: F401 # kept for backward compatibility ) from .errors import ( BadRequestError, EntryNotFoundError, GatedRepoError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError, ) from .file_download import HfFileMetadata, get_hf_file_metadata, hf_hub_url from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData from .utils import ( DEFAULT_IGNORE_PATTERNS, HfFolder, # noqa: F401 # kept for backward compatibility LocalTokenNotFoundError, NotASafetensorsRepoError, SafetensorsFileMetadata, SafetensorsParsingError, SafetensorsRepoMetadata, TensorInfo, build_hf_headers, chunk_iterable, experimental, filter_repo_objects, fix_hf_endpoint_in_url, get_session, get_token, hf_raise_for_status, logging, paginate, parse_datetime, validate_hf_hub_args, ) from .utils import tqdm as hf_tqdm from .utils._auth import _get_token_from_environment, _get_token_from_file, _get_token_from_google_colab from .utils._deprecation import 
_deprecate_method from .utils._runtime import is_xet_available from .utils._typing import CallableT from .utils.endpoint_helpers import _is_emission_within_threshold R = TypeVar("R") # Return type CollectionItemType_T = Literal["model", "dataset", "space", "paper"] ExpandModelProperty_T = Literal[ "author", "baseModels", "cardData", "childrenModelCount", "config", "createdAt", "disabled", "downloads", "downloadsAllTime", "gated", "gguf", "inference", "inferenceProviderMapping", "lastModified", "library_name", "likes", "mask_token", "model-index", "pipeline_tag", "private", "resourceGroup", "safetensors", "sha", "siblings", "spaces", "tags", "transformersInfo", "trendingScore", "usedStorage", "widgetData", "xetEnabled", ] ExpandDatasetProperty_T = Literal[ "author", "cardData", "citation", "createdAt", "description", "disabled", "downloads", "downloadsAllTime", "gated", "lastModified", "likes", "paperswithcode_id", "private", "resourceGroup", "sha", "siblings", "tags", "trendingScore", "usedStorage", "xetEnabled", ] ExpandSpaceProperty_T = Literal[ "author", "cardData", "createdAt", "datasets", "disabled", "lastModified", "likes", "models", "private", "resourceGroup", "runtime", "sdk", "sha", "siblings", "subdomain", "tags", "trendingScore", "usedStorage", "xetEnabled", ] USERNAME_PLACEHOLDER = "hf_user" _REGEX_DISCUSSION_URL = re.compile(r".*/discussions/(\d+)$") _CREATE_COMMIT_NO_REPO_ERROR_MESSAGE = ( "\nNote: Creating a commit assumes that the repo already exists on the" " Huggingface Hub. Please use `create_repo` if it's not the case." ) _AUTH_CHECK_NO_REPO_ERROR_MESSAGE = ( "\nNote: The repository either does not exist or you do not have access rights." " Please check the repository ID and your access permissions." " If this is a private repository, ensure that your token is correct." ) logger = logging.get_logger(__name__) def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tuple[Optional[str], Optional[str], str]: """ Returns the repo type and ID from a huggingface.co URL linking to a repository Args: hf_id (`str`): An URL or ID of a repository on the HF hub. Accepted values are: - https://huggingface.co/// - https://huggingface.co// - hf://// - hf:/// - // - / - hub_url (`str`, *optional*): The URL of the HuggingFace Hub, defaults to https://huggingface.co Returns: A tuple with three items: repo_type (`str` or `None`), namespace (`str` or `None`) and repo_id (`str`). Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If URL cannot be parsed. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If `repo_type` is unknown. 
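Example (illustrative IDs, assuming the default `hub_url`, i.e. https://huggingface.co):

```python
>>> repo_type_and_id_from_hf_id("https://huggingface.co/datasets/user/dataset")
('dataset', 'user', 'dataset')
>>> repo_type_and_id_from_hf_id("gpt2")
(None, None, 'gpt2')
```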
""" input_hf_id = hf_id hub_url = re.sub(r"https?://", "", hub_url if hub_url is not None else constants.ENDPOINT) is_hf_url = hub_url in hf_id and "@" not in hf_id HFFS_PREFIX = "hf://" if hf_id.startswith(HFFS_PREFIX): # Remove "hf://" prefix if exists hf_id = hf_id[len(HFFS_PREFIX) :] url_segments = hf_id.split("/") is_hf_id = len(url_segments) <= 3 namespace: Optional[str] if is_hf_url: namespace, repo_id = url_segments[-2:] if namespace == hub_url: namespace = None if len(url_segments) > 2 and hub_url not in url_segments[-3]: repo_type = url_segments[-3] elif namespace in constants.REPO_TYPES_MAPPING: # Mean canonical dataset or model repo_type = constants.REPO_TYPES_MAPPING[namespace] namespace = None else: repo_type = None elif is_hf_id: if len(url_segments) == 3: # Passed // or // repo_type, namespace, repo_id = url_segments[-3:] elif len(url_segments) == 2: if url_segments[0] in constants.REPO_TYPES_MAPPING: # Passed '' or 'datasets/' for a canonical model or dataset repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]] namespace = None repo_id = hf_id.split("/")[-1] else: # Passed / or / namespace, repo_id = hf_id.split("/")[-2:] repo_type = None else: # Passed repo_id = url_segments[0] namespace, repo_type = None, None else: raise ValueError(f"Unable to retrieve user and repo ID from the passed HF ID: {hf_id}") # Check if repo type is known (mapping "spaces" => "space" + empty value => `None`) if repo_type in constants.REPO_TYPES_MAPPING: repo_type = constants.REPO_TYPES_MAPPING[repo_type] if repo_type == "": repo_type = None if repo_type not in constants.REPO_TYPES: raise ValueError(f"Unknown `repo_type`: '{repo_type}' ('{input_hf_id}')") return repo_type, namespace, repo_id @dataclass class LastCommitInfo(dict): oid: str title: str date: datetime def __post_init__(self): # hack to make LastCommitInfo backward compatible self.update(asdict(self)) @dataclass class BlobLfsInfo(dict): size: int sha256: str pointer_size: int def __post_init__(self): # hack to make BlobLfsInfo backward compatible self.update(asdict(self)) @dataclass class BlobSecurityInfo(dict): safe: bool # duplicate information with "status" field, keeping it for backward compatibility status: str av_scan: Optional[Dict] pickle_import_scan: Optional[Dict] def __post_init__(self): # hack to make BlogSecurityInfo backward compatible self.update(asdict(self)) @dataclass class TransformersInfo(dict): auto_model: str custom_class: Optional[str] = None # possible `pipeline_tag` values: https://github.com/huggingface/huggingface.js/blob/3ee32554b8620644a6287e786b2a83bf5caf559c/packages/tasks/src/pipelines.ts#L72 pipeline_tag: Optional[str] = None processor: Optional[str] = None def __post_init__(self): # hack to make TransformersInfo backward compatible self.update(asdict(self)) @dataclass class SafeTensorsInfo(dict): parameters: Dict[str, int] total: int def __post_init__(self): # hack to make SafeTensorsInfo backward compatible self.update(asdict(self)) @dataclass class CommitInfo(str): """Data structure containing information about a newly created commit. Returned by any method that creates a commit on the Hub: [`create_commit`], [`upload_file`], [`upload_folder`], [`delete_file`], [`delete_folder`]. It inherits from `str` for backward compatibility but using methods specific to `str` is deprecated. Attributes: commit_url (`str`): Url where to find the commit. commit_message (`str`): The summary (first line) of the commit that has been created. 
commit_description (`str`): Description of the commit that has been created. Can be empty. oid (`str`): Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`. pr_url (`str`, *optional*): Url to the PR that has been created, if any. Populated when `create_pr=True` is passed. pr_revision (`str`, *optional*): Revision of the PR that has been created, if any. Populated when `create_pr=True` is passed. Example: `"refs/pr/1"`. pr_num (`int`, *optional*): Number of the PR discussion that has been created, if any. Populated when `create_pr=True` is passed. Can be passed as `discussion_num` in [`get_discussion_details`]. Example: `1`. repo_url (`RepoUrl`): Repo URL of the commit containing info like repo_id, repo_type, etc. _url (`str`, *optional*): Legacy url for `str` compatibility. Can be the url to the uploaded file on the Hub (if returned by [`upload_file`]), to the uploaded folder on the Hub (if returned by [`upload_folder`]) or to the commit on the Hub (if returned by [`create_commit`]). Defaults to `commit_url`. It is deprecated to use this attribute. Please use `commit_url` instead. """ commit_url: str commit_message: str commit_description: str oid: str pr_url: Optional[str] = None # Computed from `commit_url` in `__post_init__` repo_url: RepoUrl = field(init=False) # Computed from `pr_url` in `__post_init__` pr_revision: Optional[str] = field(init=False) pr_num: Optional[str] = field(init=False) # legacy url for `str` compatibility (ex: url to uploaded file, url to uploaded folder, url to PR, etc.) _url: str = field(repr=False, default=None) # type: ignore # defaults to `commit_url` def __new__(cls, *args, commit_url: str, _url: Optional[str] = None, **kwargs): return str.__new__(cls, _url or commit_url) def __post_init__(self): """Populate pr-related fields after initialization. See https://docs.python.org/3.10/library/dataclasses.html#post-init-processing. """ # Repo info self.repo_url = RepoUrl(self.commit_url.split("/commit/")[0]) # PR info if self.pr_url is not None: self.pr_revision = _parse_revision_from_pr_url(self.pr_url) self.pr_num = int(self.pr_revision.split("/")[-1]) else: self.pr_revision = None self.pr_num = None @dataclass class AccessRequest: """Data structure containing information about a user access request. Attributes: username (`str`): Username of the user who requested access. fullname (`str`): Fullname of the user who requested access. email (`Optional[str]`): Email of the user who requested access. Can only be `None` in the /accepted list if the user was granted access manually. timestamp (`datetime`): Timestamp of the request. status (`Literal["pending", "accepted", "rejected"]`): Status of the request. Can be one of `["pending", "accepted", "rejected"]`. fields (`Dict[str, Any]`, *optional*): Additional fields filled by the user in the gate form. """ username: str fullname: str email: Optional[str] timestamp: datetime status: Literal["pending", "accepted", "rejected"] # Additional fields filled by the user in the gate form fields: Optional[Dict[str, Any]] = None @dataclass class WebhookWatchedItem: """Data structure containing information about the items watched by a webhook. Attributes: type (`Literal["dataset", "model", "org", "space", "user"]`): Type of the item to be watched. Can be one of `["dataset", "model", "org", "space", "user"]`. name (`str`): Name of the item to be watched. Can be the username, organization name, model name, dataset name or space name. 
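Example (illustrative values):

```python
>>> WebhookWatchedItem(type="org", name="huggingface")
WebhookWatchedItem(type='org', name='huggingface')
```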
""" type: Literal["dataset", "model", "org", "space", "user"] name: str @dataclass class WebhookInfo: """Data structure containing information about a webhook. Attributes: id (`str`): ID of the webhook. url (`str`): URL of the webhook. watched (`List[WebhookWatchedItem]`): List of items watched by the webhook, see [`WebhookWatchedItem`]. domains (`List[WEBHOOK_DOMAIN_T]`): List of domains the webhook is watching. Can be one of `["repo", "discussions"]`. secret (`str`, *optional*): Secret of the webhook. disabled (`bool`): Whether the webhook is disabled or not. """ id: str url: str watched: List[WebhookWatchedItem] domains: List[constants.WEBHOOK_DOMAIN_T] secret: Optional[str] disabled: bool class RepoUrl(str): """Subclass of `str` describing a repo URL on the Hub. `RepoUrl` is returned by `HfApi.create_repo`. It inherits from `str` for backward compatibility. At initialization, the URL is parsed to populate properties: - endpoint (`str`) - namespace (`Optional[str]`) - repo_name (`str`) - repo_id (`str`) - repo_type (`Literal["model", "dataset", "space"]`) - url (`str`) Args: url (`Any`): String value of the repo url. endpoint (`str`, *optional*): Endpoint of the Hub. Defaults to . Example: ```py >>> RepoUrl('https://huggingface.co/gpt2') RepoUrl('https://huggingface.co/gpt2', endpoint='https://huggingface.co', repo_type='model', repo_id='gpt2') >>> RepoUrl('https://hub-ci.huggingface.co/datasets/dummy_user/dummy_dataset', endpoint='https://hub-ci.huggingface.co') RepoUrl('https://hub-ci.huggingface.co/datasets/dummy_user/dummy_dataset', endpoint='https://hub-ci.huggingface.co', repo_type='dataset', repo_id='dummy_user/dummy_dataset') >>> RepoUrl('hf://datasets/my-user/my-dataset') RepoUrl('hf://datasets/my-user/my-dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='user/dataset') >>> HfApi.create_repo("dummy_model") RepoUrl('https://huggingface.co/Wauplin/dummy_model', endpoint='https://huggingface.co', repo_type='model', repo_id='Wauplin/dummy_model') ``` Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If URL cannot be parsed. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If `repo_type` is unknown. """ def __new__(cls, url: Any, endpoint: Optional[str] = None): url = fix_hf_endpoint_in_url(url, endpoint=endpoint) return super(RepoUrl, cls).__new__(cls, url) def __init__(self, url: Any, endpoint: Optional[str] = None) -> None: super().__init__() # Parse URL self.endpoint = endpoint or constants.ENDPOINT repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(self, hub_url=self.endpoint) # Populate fields self.namespace = namespace self.repo_name = repo_name self.repo_id = repo_name if namespace is None else f"{namespace}/{repo_name}" self.repo_type = repo_type or constants.REPO_TYPE_MODEL self.url = str(self) # just in case it's needed def __repr__(self) -> str: return f"RepoUrl('{self}', endpoint='{self.endpoint}', repo_type='{self.repo_type}', repo_id='{self.repo_id}')" @dataclass class RepoSibling: """ Contains basic information about a repo file inside a repo on the Hub. All attributes of this class are optional except `rfilename`. This is because only the file names are returned when listing repositories on the Hub (with [`list_models`], [`list_datasets`] or [`list_spaces`]). 
If you need more information like file size, blob id or lfs details, you must request them specifically from one repo at a time (using [`model_info`], [`dataset_info`] or [`space_info`]) as it adds more constraints on the backend server to retrieve these. Attributes: rfilename (str): file name, relative to the repo root. size (`int`, *optional*): The file's size, in bytes. This attribute is defined when `files_metadata` argument of [`repo_info`] is set to `True`. It's `None` otherwise. blob_id (`str`, *optional*): The file's git OID. This attribute is defined when `files_metadata` argument of [`repo_info`] is set to `True`. It's `None` otherwise. lfs (`BlobLfsInfo`, *optional*): The file's LFS metadata. This attribute is defined when`files_metadata` argument of [`repo_info`] is set to `True` and the file is stored with Git LFS. It's `None` otherwise. """ rfilename: str size: Optional[int] = None blob_id: Optional[str] = None lfs: Optional[BlobLfsInfo] = None @dataclass class RepoFile: """ Contains information about a file on the Hub. Attributes: path (str): file path relative to the repo root. size (`int`): The file's size, in bytes. blob_id (`str`): The file's git OID. lfs (`BlobLfsInfo`): The file's LFS metadata. last_commit (`LastCommitInfo`, *optional*): The file's last commit metadata. Only defined if [`list_repo_tree`] and [`get_paths_info`] are called with `expand=True`. security (`BlobSecurityInfo`, *optional*): The file's security scan metadata. Only defined if [`list_repo_tree`] and [`get_paths_info`] are called with `expand=True`. """ path: str size: int blob_id: str lfs: Optional[BlobLfsInfo] = None last_commit: Optional[LastCommitInfo] = None security: Optional[BlobSecurityInfo] = None def __init__(self, **kwargs): self.path = kwargs.pop("path") self.size = kwargs.pop("size") self.blob_id = kwargs.pop("oid") lfs = kwargs.pop("lfs", None) if lfs is not None: lfs = BlobLfsInfo(size=lfs["size"], sha256=lfs["oid"], pointer_size=lfs["pointerSize"]) self.lfs = lfs last_commit = kwargs.pop("lastCommit", None) or kwargs.pop("last_commit", None) if last_commit is not None: last_commit = LastCommitInfo( oid=last_commit["id"], title=last_commit["title"], date=parse_datetime(last_commit["date"]) ) self.last_commit = last_commit security = kwargs.pop("securityFileStatus", None) if security is not None: safe = security["status"] == "safe" security = BlobSecurityInfo( safe=safe, status=security["status"], av_scan=security["avScan"], pickle_import_scan=security["pickleImportScan"], ) self.security = security # backwards compatibility self.rfilename = self.path self.lastCommit = self.last_commit @dataclass class RepoFolder: """ Contains information about a folder on the Hub. Attributes: path (str): folder path relative to the repo root. tree_id (`str`): The folder's git OID. last_commit (`LastCommitInfo`, *optional*): The folder's last commit metadata. Only defined if [`list_repo_tree`] and [`get_paths_info`] are called with `expand=True`. 
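Example (illustrative payload; `RepoFolder` objects are normally built from the raw Hub API response, e.g. by [`list_repo_tree`], where the folder's tree OID is returned under the `oid` key):

```python
>>> folder = RepoFolder(path="data", oid="f1d2d2f924e986ac86fdf7b36c94bcdf32beec15")
>>> folder.tree_id
'f1d2d2f924e986ac86fdf7b36c94bcdf32beec15'
```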
""" path: str tree_id: str last_commit: Optional[LastCommitInfo] = None def __init__(self, **kwargs): self.path = kwargs.pop("path") self.tree_id = kwargs.pop("oid") last_commit = kwargs.pop("lastCommit", None) or kwargs.pop("last_commit", None) if last_commit is not None: last_commit = LastCommitInfo( oid=last_commit["id"], title=last_commit["title"], date=parse_datetime(last_commit["date"]) ) self.last_commit = last_commit @dataclass class InferenceProviderMapping: hf_model_id: str status: Literal["live", "staging"] provider_id: str task: str adapter: Optional[str] = None adapter_weights_path: Optional[str] = None def __init__(self, **kwargs): self.hf_model_id = kwargs.pop("hf_model_id") self.status = kwargs.pop("status") self.provider_id = kwargs.pop("providerId") self.task = kwargs.pop("task") self.adapter = kwargs.pop("adapter", None) self.adapter_weights_path = kwargs.pop("adapterWeightsPath", None) self.__dict__.update(**kwargs) @dataclass class ModelInfo: """ Contains information about a model on the Hub. Most attributes of this class are optional. This is because the data returned by the Hub depends on the query made. In general, the more specific the query, the more information is returned. On the contrary, when listing models using [`list_models`] only a subset of the attributes are returned. Attributes: id (`str`): ID of model. author (`str`, *optional*): Author of the model. sha (`str`, *optional*): Repo SHA at this particular revision. created_at (`datetime`, *optional*): Date of creation of the repo on the Hub. Note that the lowest value is `2022-03-02T23:29:04.000Z`, corresponding to the date when we began to store creation dates. last_modified (`datetime`, *optional*): Date of last commit to the repo. private (`bool`): Is the repo private. disabled (`bool`, *optional*): Is the repo disabled. downloads (`int`): Number of downloads of the model over the last 30 days. downloads_all_time (`int`): Cumulated number of downloads of the model since its creation. gated (`Literal["auto", "manual", False]`, *optional*): Is the repo gated. If so, whether there is manual or automatic approval. gguf (`Dict`, *optional*): GGUF information of the model. inference (`Literal["cold", "frozen", "warm"]`, *optional*): Status of the model on the inference API. Warm models are available for immediate use. Cold models will be loaded on first inference call. Frozen models are not available in Inference API. inference_provider_mapping (`Dict`, *optional*): Model's inference provider mapping. likes (`int`): Number of likes of the model. library_name (`str`, *optional*): Library associated with the model. tags (`List[str]`): List of tags of the model. Compared to `card_data.tags`, contains extra tags computed by the Hub (e.g. supported libraries, model's arXiv). pipeline_tag (`str`, *optional*): Pipeline tag associated with the model. mask_token (`str`, *optional*): Mask token used by the model. widget_data (`Any`, *optional*): Widget data associated with the model. model_index (`Dict`, *optional*): Model index for evaluation. config (`Dict`, *optional*): Model configuration. transformers_info (`TransformersInfo`, *optional*): Transformers-specific info (auto class, processor, etc.) associated with the model. trending_score (`int`, *optional*): Trending score of the model. card_data (`ModelCardData`, *optional*): Model Card Metadata as a [`huggingface_hub.repocard_data.ModelCardData`] object. 
siblings (`List[RepoSibling]`): List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the model. spaces (`List[str]`, *optional*): List of spaces using the model. safetensors (`SafeTensorsInfo`, *optional*): Model's safetensors information. security_repo_status (`Dict`, *optional*): Model's security scan status. """ id: str author: Optional[str] sha: Optional[str] created_at: Optional[datetime] last_modified: Optional[datetime] private: Optional[bool] disabled: Optional[bool] downloads: Optional[int] downloads_all_time: Optional[int] gated: Optional[Literal["auto", "manual", False]] gguf: Optional[Dict] inference: Optional[Literal["warm", "cold", "frozen"]] inference_provider_mapping: Optional[Dict[str, InferenceProviderMapping]] likes: Optional[int] library_name: Optional[str] tags: Optional[List[str]] pipeline_tag: Optional[str] mask_token: Optional[str] card_data: Optional[ModelCardData] widget_data: Optional[Any] model_index: Optional[Dict] config: Optional[Dict] transformers_info: Optional[TransformersInfo] trending_score: Optional[int] siblings: Optional[List[RepoSibling]] spaces: Optional[List[str]] safetensors: Optional[SafeTensorsInfo] security_repo_status: Optional[Dict] xet_enabled: Optional[bool] def __init__(self, **kwargs): self.id = kwargs.pop("id") self.author = kwargs.pop("author", None) self.sha = kwargs.pop("sha", None) last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None) self.last_modified = parse_datetime(last_modified) if last_modified else None created_at = kwargs.pop("createdAt", None) or kwargs.pop("created_at", None) self.created_at = parse_datetime(created_at) if created_at else None self.private = kwargs.pop("private", None) self.gated = kwargs.pop("gated", None) self.disabled = kwargs.pop("disabled", None) self.downloads = kwargs.pop("downloads", None) self.downloads_all_time = kwargs.pop("downloadsAllTime", None) self.likes = kwargs.pop("likes", None) self.library_name = kwargs.pop("library_name", None) self.gguf = kwargs.pop("gguf", None) self.inference = kwargs.pop("inference", None) self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None) if self.inference_provider_mapping: self.inference_provider_mapping = { provider: InferenceProviderMapping( **{**value, "hf_model_id": self.id} ) # little hack to simplify Inference Providers logic for provider, value in self.inference_provider_mapping.items() } self.tags = kwargs.pop("tags", None) self.pipeline_tag = kwargs.pop("pipeline_tag", None) self.mask_token = kwargs.pop("mask_token", None) self.trending_score = kwargs.pop("trendingScore", None) card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None) self.card_data = ( ModelCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data ) self.widget_data = kwargs.pop("widgetData", None) self.model_index = kwargs.pop("model-index", None) or kwargs.pop("model_index", None) self.config = kwargs.pop("config", None) transformers_info = kwargs.pop("transformersInfo", None) or kwargs.pop("transformers_info", None) self.transformers_info = TransformersInfo(**transformers_info) if transformers_info else None siblings = kwargs.pop("siblings", None) self.siblings = ( [ RepoSibling( rfilename=sibling["rfilename"], size=sibling.get("size"), blob_id=sibling.get("blobId"), lfs=( BlobLfsInfo( size=sibling["lfs"]["size"], sha256=sibling["lfs"]["sha256"], pointer_size=sibling["lfs"]["pointerSize"], ) if sibling.get("lfs") else None ), ) for sibling in siblings 
] if siblings is not None else None ) self.spaces = kwargs.pop("spaces", None) safetensors = kwargs.pop("safetensors", None) self.safetensors = ( SafeTensorsInfo( parameters=safetensors["parameters"], total=safetensors["total"], ) if safetensors else None ) self.security_repo_status = kwargs.pop("securityRepoStatus", None) self.xet_enabled = kwargs.pop("xetEnabled", None) # backwards compatibility self.lastModified = self.last_modified self.cardData = self.card_data self.transformersInfo = self.transformers_info self.__dict__.update(**kwargs) @dataclass class DatasetInfo: """ Contains information about a dataset on the Hub. Most attributes of this class are optional. This is because the data returned by the Hub depends on the query made. In general, the more specific the query, the more information is returned. On the contrary, when listing datasets using [`list_datasets`] only a subset of the attributes are returned. Attributes: id (`str`): ID of dataset. author (`str`): Author of the dataset. sha (`str`): Repo SHA at this particular revision. created_at (`datetime`, *optional*): Date of creation of the repo on the Hub. Note that the lowest value is `2022-03-02T23:29:04.000Z`, corresponding to the date when we began to store creation dates. last_modified (`datetime`, *optional*): Date of last commit to the repo. private (`bool`): Is the repo private. disabled (`bool`, *optional*): Is the repo disabled. gated (`Literal["auto", "manual", False]`, *optional*): Is the repo gated. If so, whether there is manual or automatic approval. downloads (`int`): Number of downloads of the dataset over the last 30 days. downloads_all_time (`int`): Cumulated number of downloads of the model since its creation. likes (`int`): Number of likes of the dataset. tags (`List[str]`): List of tags of the dataset. card_data (`DatasetCardData`, *optional*): Model Card Metadata as a [`huggingface_hub.repocard_data.DatasetCardData`] object. siblings (`List[RepoSibling]`): List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the dataset. paperswithcode_id (`str`, *optional*): Papers with code ID of the dataset. trending_score (`int`, *optional*): Trending score of the dataset. 
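    Example (illustrative; attribute values depend on the dataset and change over time):

    ```py
    >>> from huggingface_hub import HfApi
    >>> api = HfApi()
    >>> dataset = api.dataset_info("squad")
    >>> dataset.tags       # e.g. ['task_categories:question-answering', ...]
    >>> dataset.downloads  # number of downloads over the last 30 days
    ```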
""" id: str author: Optional[str] sha: Optional[str] created_at: Optional[datetime] last_modified: Optional[datetime] private: Optional[bool] gated: Optional[Literal["auto", "manual", False]] disabled: Optional[bool] downloads: Optional[int] downloads_all_time: Optional[int] likes: Optional[int] paperswithcode_id: Optional[str] tags: Optional[List[str]] trending_score: Optional[int] card_data: Optional[DatasetCardData] siblings: Optional[List[RepoSibling]] xet_enabled: Optional[bool] def __init__(self, **kwargs): self.id = kwargs.pop("id") self.author = kwargs.pop("author", None) self.sha = kwargs.pop("sha", None) created_at = kwargs.pop("createdAt", None) or kwargs.pop("created_at", None) self.created_at = parse_datetime(created_at) if created_at else None last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None) self.last_modified = parse_datetime(last_modified) if last_modified else None self.private = kwargs.pop("private", None) self.gated = kwargs.pop("gated", None) self.disabled = kwargs.pop("disabled", None) self.downloads = kwargs.pop("downloads", None) self.downloads_all_time = kwargs.pop("downloadsAllTime", None) self.likes = kwargs.pop("likes", None) self.paperswithcode_id = kwargs.pop("paperswithcode_id", None) self.tags = kwargs.pop("tags", None) self.trending_score = kwargs.pop("trendingScore", None) card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None) self.card_data = ( DatasetCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data ) siblings = kwargs.pop("siblings", None) self.siblings = ( [ RepoSibling( rfilename=sibling["rfilename"], size=sibling.get("size"), blob_id=sibling.get("blobId"), lfs=( BlobLfsInfo( size=sibling["lfs"]["size"], sha256=sibling["lfs"]["sha256"], pointer_size=sibling["lfs"]["pointerSize"], ) if sibling.get("lfs") else None ), ) for sibling in siblings ] if siblings is not None else None ) self.xet_enabled = kwargs.pop("xetEnabled", None) # backwards compatibility self.lastModified = self.last_modified self.cardData = self.card_data self.__dict__.update(**kwargs) @dataclass class SpaceInfo: """ Contains information about a Space on the Hub. Most attributes of this class are optional. This is because the data returned by the Hub depends on the query made. In general, the more specific the query, the more information is returned. On the contrary, when listing spaces using [`list_spaces`] only a subset of the attributes are returned. Attributes: id (`str`): ID of the Space. author (`str`, *optional*): Author of the Space. sha (`str`, *optional*): Repo SHA at this particular revision. created_at (`datetime`, *optional*): Date of creation of the repo on the Hub. Note that the lowest value is `2022-03-02T23:29:04.000Z`, corresponding to the date when we began to store creation dates. last_modified (`datetime`, *optional*): Date of last commit to the repo. private (`bool`): Is the repo private. gated (`Literal["auto", "manual", False]`, *optional*): Is the repo gated. If so, whether there is manual or automatic approval. disabled (`bool`, *optional*): Is the Space disabled. host (`str`, *optional*): Host URL of the Space. subdomain (`str`, *optional*): Subdomain of the Space. likes (`int`): Number of likes of the Space. tags (`List[str]`): List of tags of the Space. siblings (`List[RepoSibling]`): List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the Space. 
card_data (`SpaceCardData`, *optional*): Space Card Metadata as a [`huggingface_hub.repocard_data.SpaceCardData`] object. runtime (`SpaceRuntime`, *optional*): Space runtime information as a [`huggingface_hub.hf_api.SpaceRuntime`] object. sdk (`str`, *optional*): SDK used by the Space. models (`List[str]`, *optional*): List of models used by the Space. datasets (`List[str]`, *optional*): List of datasets used by the Space. trending_score (`int`, *optional*): Trending score of the Space. """ id: str author: Optional[str] sha: Optional[str] created_at: Optional[datetime] last_modified: Optional[datetime] private: Optional[bool] gated: Optional[Literal["auto", "manual", False]] disabled: Optional[bool] host: Optional[str] subdomain: Optional[str] likes: Optional[int] sdk: Optional[str] tags: Optional[List[str]] siblings: Optional[List[RepoSibling]] trending_score: Optional[int] card_data: Optional[SpaceCardData] runtime: Optional[SpaceRuntime] models: Optional[List[str]] datasets: Optional[List[str]] xet_enabled: Optional[bool] def __init__(self, **kwargs): self.id = kwargs.pop("id") self.author = kwargs.pop("author", None) self.sha = kwargs.pop("sha", None) created_at = kwargs.pop("createdAt", None) or kwargs.pop("created_at", None) self.created_at = parse_datetime(created_at) if created_at else None last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None) self.last_modified = parse_datetime(last_modified) if last_modified else None self.private = kwargs.pop("private", None) self.gated = kwargs.pop("gated", None) self.disabled = kwargs.pop("disabled", None) self.host = kwargs.pop("host", None) self.subdomain = kwargs.pop("subdomain", None) self.likes = kwargs.pop("likes", None) self.sdk = kwargs.pop("sdk", None) self.tags = kwargs.pop("tags", None) self.trending_score = kwargs.pop("trendingScore", None) card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None) self.card_data = ( SpaceCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data ) siblings = kwargs.pop("siblings", None) self.siblings = ( [ RepoSibling( rfilename=sibling["rfilename"], size=sibling.get("size"), blob_id=sibling.get("blobId"), lfs=( BlobLfsInfo( size=sibling["lfs"]["size"], sha256=sibling["lfs"]["sha256"], pointer_size=sibling["lfs"]["pointerSize"], ) if sibling.get("lfs") else None ), ) for sibling in siblings ] if siblings is not None else None ) runtime = kwargs.pop("runtime", None) self.runtime = SpaceRuntime(runtime) if runtime else None self.models = kwargs.pop("models", None) self.datasets = kwargs.pop("datasets", None) self.xet_enabled = kwargs.pop("xetEnabled", None) # backwards compatibility self.lastModified = self.last_modified self.cardData = self.card_data self.__dict__.update(**kwargs) @dataclass class CollectionItem: """ Contains information about an item of a Collection (model, dataset, Space or paper). Attributes: item_object_id (`str`): Unique ID of the item in the collection. item_id (`str`): ID of the underlying object on the Hub. Can be either a repo_id or a paper id e.g. `"jbilcke-hf/ai-comic-factory"`, `"2307.09288"`. item_type (`str`): Type of the underlying object. Can be one of `"model"`, `"dataset"`, `"space"` or `"paper"`. position (`int`): Position of the item in the collection. note (`str`, *optional*): Note associated with the item, as plain text. 
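    Example (illustrative; the collection slug below is a placeholder):

    ```py
    >>> from huggingface_hub import get_collection
    >>> collection = get_collection("username/recent-models-64f9a55bb3115b4f513ec026")
    >>> for item in collection.items:
    ...     print(item.item_type, item.item_id, item.note)
    ```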
""" item_object_id: str # id in database item_id: str # repo_id or paper id item_type: str position: int note: Optional[str] = None def __init__( self, _id: str, id: str, type: CollectionItemType_T, position: int, note: Optional[Dict] = None, **kwargs ) -> None: self.item_object_id: str = _id # id in database self.item_id: str = id # repo_id or paper id self.item_type: CollectionItemType_T = type self.position: int = position self.note: str = note["text"] if note is not None else None @dataclass class Collection: """ Contains information about a Collection on the Hub. Attributes: slug (`str`): Slug of the collection. E.g. `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. title (`str`): Title of the collection. E.g. `"Recent models"`. owner (`str`): Owner of the collection. E.g. `"TheBloke"`. items (`List[CollectionItem]`): List of items in the collection. last_updated (`datetime`): Date of the last update of the collection. position (`int`): Position of the collection in the list of collections of the owner. private (`bool`): Whether the collection is private or not. theme (`str`): Theme of the collection. E.g. `"green"`. upvotes (`int`): Number of upvotes of the collection. description (`str`, *optional*): Description of the collection, as plain text. url (`str`): (property) URL of the collection on the Hub. """ slug: str title: str owner: str items: List[CollectionItem] last_updated: datetime position: int private: bool theme: str upvotes: int description: Optional[str] = None def __init__(self, **kwargs) -> None: self.slug = kwargs.pop("slug") self.title = kwargs.pop("title") self.owner = kwargs.pop("owner") self.items = [CollectionItem(**item) for item in kwargs.pop("items")] self.last_updated = parse_datetime(kwargs.pop("lastUpdated")) self.position = kwargs.pop("position") self.private = kwargs.pop("private") self.theme = kwargs.pop("theme") self.upvotes = kwargs.pop("upvotes") self.description = kwargs.pop("description", None) endpoint = kwargs.pop("endpoint", None) if endpoint is None: endpoint = constants.ENDPOINT self._url = f"{endpoint}/collections/{self.slug}" @property def url(self) -> str: """Returns the URL of the collection on the Hub.""" return self._url @dataclass class GitRefInfo: """ Contains information about a git reference for a repo on the Hub. Attributes: name (`str`): Name of the reference (e.g. tag name or branch name). ref (`str`): Full git ref on the Hub (e.g. `"refs/heads/main"` or `"refs/tags/v1.0"`). target_commit (`str`): OID of the target commit for the ref (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`) """ name: str ref: str target_commit: str @dataclass class GitRefs: """ Contains information about all git references for a repo on the Hub. Object is returned by [`list_repo_refs`]. Attributes: branches (`List[GitRefInfo]`): A list of [`GitRefInfo`] containing information about branches on the repo. converts (`List[GitRefInfo]`): A list of [`GitRefInfo`] containing information about "convert" refs on the repo. Converts are refs used (internally) to push preprocessed data in Dataset repos. tags (`List[GitRefInfo]`): A list of [`GitRefInfo`] containing information about tags on the repo. pull_requests (`List[GitRefInfo]`, *optional*): A list of [`GitRefInfo`] containing information about pull requests on the repo. Only returned if `include_prs=True` is set. 
""" branches: List[GitRefInfo] converts: List[GitRefInfo] tags: List[GitRefInfo] pull_requests: Optional[List[GitRefInfo]] = None @dataclass class GitCommitInfo: """ Contains information about a git commit for a repo on the Hub. Check out [`list_repo_commits`] for more details. Attributes: commit_id (`str`): OID of the commit (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`) authors (`List[str]`): List of authors of the commit. created_at (`datetime`): Datetime when the commit was created. title (`str`): Title of the commit. This is a free-text value entered by the authors. message (`str`): Description of the commit. This is a free-text value entered by the authors. formatted_title (`str`): Title of the commit formatted as HTML. Only returned if `formatted=True` is set. formatted_message (`str`): Description of the commit formatted as HTML. Only returned if `formatted=True` is set. """ commit_id: str authors: List[str] created_at: datetime title: str message: str formatted_title: Optional[str] formatted_message: Optional[str] @dataclass class UserLikes: """ Contains information about a user likes on the Hub. Attributes: user (`str`): Name of the user for which we fetched the likes. total (`int`): Total number of likes. datasets (`List[str]`): List of datasets liked by the user (as repo_ids). models (`List[str]`): List of models liked by the user (as repo_ids). spaces (`List[str]`): List of spaces liked by the user (as repo_ids). """ # Metadata user: str total: int # User likes datasets: List[str] models: List[str] spaces: List[str] @dataclass class Organization: """ Contains information about an organization on the Hub. Attributes: avatar_url (`str`): URL of the organization's avatar. name (`str`): Name of the organization on the Hub (unique). fullname (`str`): Organization's full name. """ avatar_url: str name: str fullname: str def __init__(self, **kwargs) -> None: self.avatar_url = kwargs.pop("avatarUrl", "") self.name = kwargs.pop("name", "") self.fullname = kwargs.pop("fullname", "") # forward compatibility self.__dict__.update(**kwargs) @dataclass class User: """ Contains information about a user on the Hub. Attributes: username (`str`): Name of the user on the Hub (unique). fullname (`str`): User's full name. avatar_url (`str`): URL of the user's avatar. details (`str`, *optional*): User's details. is_following (`bool`, *optional*): Whether the authenticated user is following this user. is_pro (`bool`, *optional*): Whether the user is a pro user. num_models (`int`, *optional*): Number of models created by the user. num_datasets (`int`, *optional*): Number of datasets created by the user. num_spaces (`int`, *optional*): Number of spaces created by the user. num_discussions (`int`, *optional*): Number of discussions initiated by the user. num_papers (`int`, *optional*): Number of papers authored by the user. num_upvotes (`int`, *optional*): Number of upvotes received by the user. num_likes (`int`, *optional*): Number of likes given by the user. num_following (`int`, *optional*): Number of users this user is following. num_followers (`int`, *optional*): Number of users following this user. orgs (list of [`Organization`]): List of organizations the user is part of. 
""" # Metadata username: str fullname: str avatar_url: str details: Optional[str] = None is_following: Optional[bool] = None is_pro: Optional[bool] = None num_models: Optional[int] = None num_datasets: Optional[int] = None num_spaces: Optional[int] = None num_discussions: Optional[int] = None num_papers: Optional[int] = None num_upvotes: Optional[int] = None num_likes: Optional[int] = None num_following: Optional[int] = None num_followers: Optional[int] = None orgs: List[Organization] = field(default_factory=list) def __init__(self, **kwargs) -> None: self.username = kwargs.pop("user", "") self.fullname = kwargs.pop("fullname", "") self.avatar_url = kwargs.pop("avatarUrl", "") self.is_following = kwargs.pop("isFollowing", None) self.is_pro = kwargs.pop("isPro", None) self.details = kwargs.pop("details", None) self.num_models = kwargs.pop("numModels", None) self.num_datasets = kwargs.pop("numDatasets", None) self.num_spaces = kwargs.pop("numSpaces", None) self.num_discussions = kwargs.pop("numDiscussions", None) self.num_papers = kwargs.pop("numPapers", None) self.num_upvotes = kwargs.pop("numUpvotes", None) self.num_likes = kwargs.pop("numLikes", None) self.num_following = kwargs.pop("numFollowing", None) self.num_followers = kwargs.pop("numFollowers", None) self.user_type = kwargs.pop("type", None) self.orgs = [Organization(**org) for org in kwargs.pop("orgs", [])] # forward compatibility self.__dict__.update(**kwargs) @dataclass class PaperInfo: """ Contains information about a paper on the Hub. Attributes: id (`str`): arXiv paper ID. authors (`List[str]`, **optional**): Names of paper authors published_at (`datetime`, **optional**): Date paper published. title (`str`, **optional**): Title of the paper. summary (`str`, **optional**): Summary of the paper. upvotes (`int`, **optional**): Number of upvotes for the paper on the Hub. discussion_id (`str`, **optional**): Discussion ID for the paper on the Hub. source (`str`, **optional**): Source of the paper. comments (`int`, **optional**): Number of comments for the paper on the Hub. submitted_at (`datetime`, **optional**): Date paper appeared in daily papers on the Hub. submitted_by (`User`, **optional**): Information about who submitted the daily paper. 
""" id: str authors: Optional[List[str]] published_at: Optional[datetime] title: Optional[str] summary: Optional[str] upvotes: Optional[int] discussion_id: Optional[str] source: Optional[str] comments: Optional[int] submitted_at: Optional[datetime] submitted_by: Optional[User] def __init__(self, **kwargs) -> None: paper = kwargs.pop("paper", {}) self.id = kwargs.pop("id", None) or paper.pop("id", None) authors = paper.pop("authors", None) or kwargs.pop("authors", None) self.authors = [author.pop("name", None) for author in authors] if authors else None published_at = paper.pop("publishedAt", None) or kwargs.pop("publishedAt", None) self.published_at = parse_datetime(published_at) if published_at else None self.title = kwargs.pop("title", None) self.source = kwargs.pop("source", None) self.summary = paper.pop("summary", None) or kwargs.pop("summary", None) self.upvotes = paper.pop("upvotes", None) or kwargs.pop("upvotes", None) self.discussion_id = paper.pop("discussionId", None) or kwargs.pop("discussionId", None) self.comments = kwargs.pop("numComments", 0) submitted_at = kwargs.pop("publishedAt", None) or kwargs.pop("submittedOnDailyAt", None) self.submitted_at = parse_datetime(submitted_at) if submitted_at else None submitted_by = kwargs.pop("submittedBy", None) or kwargs.pop("submittedOnDailyBy", None) self.submitted_by = User(**submitted_by) if submitted_by else None # forward compatibility self.__dict__.update(**kwargs) @dataclass class LFSFileInfo: """ Contains information about a file stored as LFS on a repo on the Hub. Used in the context of listing and permanently deleting LFS files from a repo to free-up space. See [`list_lfs_files`] and [`permanently_delete_lfs_files`] for more details. Git LFS files are tracked using SHA-256 object IDs, rather than file paths, to optimize performance This approach is necessary because a single object can be referenced by multiple paths across different commits, making it impractical to search and resolve these connections. Check out [our documentation](https://huggingface.co/docs/hub/storage-limits#advanced-track-lfs-file-references) to learn how to know which filename(s) is(are) associated with each SHA. Attributes: file_oid (`str`): SHA-256 object ID of the file. This is the identifier to pass when permanently deleting the file. filename (`str`): Possible filename for the LFS object. See the note above for more information. oid (`str`): OID of the LFS object. pushed_at (`datetime`): Date the LFS object was pushed to the repo. ref (`str`, *optional*): Ref where the LFS object has been pushed (if any). size (`int`): Size of the LFS object. Example: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> lfs_files = api.list_lfs_files("username/my-cool-repo") # Filter files files to delete based on a combination of `filename`, `pushed_at`, `ref` or `size`. # e.g. 
select only LFS files in the "checkpoints" folder >>> lfs_files_to_delete = (lfs_file for lfs_file in lfs_files if lfs_file.filename.startswith("checkpoints/")) # Permanently delete LFS files >>> api.permanently_delete_lfs_files("username/my-cool-repo", lfs_files_to_delete) ``` """ file_oid: str filename: str oid: str pushed_at: datetime ref: Optional[str] size: int def __init__(self, **kwargs) -> None: self.file_oid = kwargs.pop("fileOid") self.filename = kwargs.pop("filename") self.oid = kwargs.pop("oid") self.pushed_at = parse_datetime(kwargs.pop("pushedAt")) self.ref = kwargs.pop("ref", None) self.size = kwargs.pop("size") # forward compatibility self.__dict__.update(**kwargs) def future_compatible(fn: CallableT) -> CallableT: """Wrap a method of `HfApi` to handle `run_as_future=True`. A method flagged as "future_compatible" will be called in a thread if `run_as_future=True` and return a `concurrent.futures.Future` instance. Otherwise, it will be called normally and return the result. """ sig = inspect.signature(fn) args_params = list(sig.parameters)[1:] # remove "self" from list @wraps(fn) def _inner(self, *args, **kwargs): # Get `run_as_future` value if provided (default to False) if "run_as_future" in kwargs: run_as_future = kwargs["run_as_future"] kwargs["run_as_future"] = False # avoid recursion error else: run_as_future = False for param, value in zip(args_params, args): if param == "run_as_future": run_as_future = value break # Call the function in a thread if `run_as_future=True` if run_as_future: return self.run_as_future(fn, self, *args, **kwargs) # Otherwise, call the function normally return fn(self, *args, **kwargs) _inner.is_future_compatible = True # type: ignore return _inner # type: ignore class HfApi: """ Client to interact with the Hugging Face Hub via HTTP. The client is initialized with some high-level settings used in all requests made to the Hub (HF endpoint, authentication, user agents...). Using the `HfApi` client is preferred but not mandatory as all of its public methods are exposed directly at the root of `huggingface_hub`. Args: endpoint (`str`, *optional*): Endpoint of the Hub. Defaults to . token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. library_name (`str`, *optional*): The name of the library that is making the HTTP request. Will be added to the user-agent header. Example: `"transformers"`. library_version (`str`, *optional*): The version of the library that is making the HTTP request. Will be added to the user-agent header. Example: `"4.24.0"`. user_agent (`str`, `dict`, *optional*): The user agent info in the form of a dictionary or a single string. It will be completed with information about the installed packages. headers (`dict`, *optional*): Additional headers to be sent with each request. Example: `{"X-My-Header": "value"}`. Headers passed here are taking precedence over the default headers. 
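    Example (minimal sketch; the `library_name`/`library_version` values below are
    placeholders used only to illustrate the user-agent settings):

    ```py
    >>> from huggingface_hub import HfApi
    >>> api = HfApi()  # uses the default endpoint and the locally saved token
    >>> api = HfApi(
    ...     endpoint="https://huggingface.co",
    ...     token=None,                 # defaults to the locally saved token
    ...     library_name="my-library",  # placeholder
    ...     library_version="0.0.1",    # placeholder
    ... )
    >>> api.whoami()  # requires a valid token
    ```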
""" def __init__( self, endpoint: Optional[str] = None, token: Union[str, bool, None] = None, library_name: Optional[str] = None, library_version: Optional[str] = None, user_agent: Union[Dict, str, None] = None, headers: Optional[Dict[str, str]] = None, ) -> None: self.endpoint = endpoint if endpoint is not None else constants.ENDPOINT self.token = token self.library_name = library_name self.library_version = library_version self.user_agent = user_agent self.headers = headers self._thread_pool: Optional[ThreadPoolExecutor] = None def run_as_future(self, fn: Callable[..., R], *args, **kwargs) -> Future[R]: """ Run a method in the background and return a Future instance. The main goal is to run methods without blocking the main thread (e.g. to push data during a training). Background jobs are queued to preserve order but are not ran in parallel. If you need to speed-up your scripts by parallelizing lots of call to the API, you must setup and use your own [ThreadPoolExecutor](https://docs.python.org/3/library/concurrent.futures.html#threadpoolexecutor). Note: Most-used methods like [`upload_file`], [`upload_folder`] and [`create_commit`] have a `run_as_future: bool` argument to directly call them in the background. This is equivalent to calling `api.run_as_future(...)` on them but less verbose. Args: fn (`Callable`): The method to run in the background. *args, **kwargs: Arguments with which the method will be called. Return: `Future`: a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) instance to get the result of the task. Example: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> future = api.run_as_future(api.whoami) # instant >>> future.done() False >>> future.result() # wait until complete and return result (...) >>> future.done() True ``` """ if self._thread_pool is None: self._thread_pool = ThreadPoolExecutor(max_workers=1) self._thread_pool return self._thread_pool.submit(fn, *args, **kwargs) @validate_hf_hub_args def whoami(self, token: Union[bool, str, None] = None) -> Dict: """ Call HF API to know "whoami". Args: token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. """ # Get the effective token using the helper function get_token effective_token = token or self.token or get_token() or True r = get_session().get( f"{self.endpoint}/api/whoami-v2", headers=self._build_hf_headers(token=effective_token), ) try: hf_raise_for_status(r) except HTTPError as e: error_message = "Invalid user token." # Check which token is the effective one and generate the error message accordingly if effective_token == _get_token_from_google_colab(): error_message += " The token from Google Colab vault is invalid. Please update it from the UI." elif effective_token == _get_token_from_environment(): error_message += ( " The token from HF_TOKEN environment variable is invalid. " "Note that HF_TOKEN takes precedence over `huggingface-cli login`." ) elif effective_token == _get_token_from_file(): error_message += " The token stored is invalid. Please run `huggingface-cli login` to update it." raise HTTPError(error_message, request=e.request, response=e.response) from e return r.json() @_deprecate_method( version="1.0", message=( "Permissions are more complex than when `get_token_permission` was first introduced. 
" "OAuth and fine-grain tokens allows for more detailed permissions. " "If you need to know the permissions associated with a token, please use `whoami` and check the `'auth'` key." ), ) def get_token_permission( self, token: Union[bool, str, None] = None ) -> Literal["read", "write", "fineGrained", None]: """ Check if a given `token` is valid and return its permissions. This method is deprecated and will be removed in version 1.0. Permissions are more complex than when `get_token_permission` was first introduced. OAuth and fine-grain tokens allows for more detailed permissions. If you need to know the permissions associated with a token, please use `whoami` and check the `'auth'` key. For more details about tokens, please refer to https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens. Args: token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `Literal["read", "write", "fineGrained", None]`: Permission granted by the token ("read" or "write"). Returns `None` if no token passed, if token is invalid or if role is not returned by the server. This typically happens when the token is an OAuth token. """ try: return self.whoami(token=token)["auth"]["accessToken"]["role"] except (LocalTokenNotFoundError, HTTPError, KeyError): return None def get_model_tags(self) -> Dict: """ List all valid model tags as a nested namespace object """ path = f"{self.endpoint}/api/models-tags-by-type" r = get_session().get(path) hf_raise_for_status(r) return r.json() def get_dataset_tags(self) -> Dict: """ List all valid dataset tags as a nested namespace object. """ path = f"{self.endpoint}/api/datasets-tags-by-type" r = get_session().get(path) hf_raise_for_status(r) return r.json() @validate_hf_hub_args def list_models( self, *, # Search-query parameter filter: Union[str, Iterable[str], None] = None, author: Optional[str] = None, gated: Optional[bool] = None, inference: Optional[Literal["cold", "frozen", "warm"]] = None, library: Optional[Union[str, List[str]]] = None, language: Optional[Union[str, List[str]]] = None, model_name: Optional[str] = None, task: Optional[Union[str, List[str]]] = None, trained_dataset: Optional[Union[str, List[str]]] = None, tags: Optional[Union[str, List[str]]] = None, search: Optional[str] = None, pipeline_tag: Optional[str] = None, emissions_thresholds: Optional[Tuple[float, float]] = None, # Sorting and pagination parameters sort: Union[Literal["last_modified"], str, None] = None, direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, # Additional data to fetch expand: Optional[List[ExpandModelProperty_T]] = None, full: Optional[bool] = None, cardData: bool = False, fetch_config: bool = False, token: Union[bool, str, None] = None, ) -> Iterable[ModelInfo]: """ List models hosted on the Huggingface Hub, given some filters. Args: filter (`str` or `Iterable[str]`, *optional*): A string or list of string to filter models on the Hub. author (`str`, *optional*): A string which identify the author (user or organization) of the returned models. gated (`bool`, *optional*): A boolean to filter models on the Hub that are gated or not. By default, all models are returned. If `gated=True` is passed, only gated models are returned. If `gated=False` is passed, only non-gated models are returned. 
inference (`Literal["cold", "frozen", "warm"]`, *optional*): A string to filter models on the Hub by their state on the Inference API. Warm models are available for immediate use. Cold models will be loaded on first inference call. Frozen models are not available in Inference API. library (`str` or `List`, *optional*): A string or list of strings of foundational libraries models were originally trained from, such as pytorch, tensorflow, or allennlp. language (`str` or `List`, *optional*): A string or list of strings of languages, both by name and country code, such as "en" or "English" model_name (`str`, *optional*): A string that contain complete or partial names for models on the Hub, such as "bert" or "bert-base-cased" task (`str` or `List`, *optional*): A string or list of strings of tasks models were designed for, such as: "fill-mask" or "automatic-speech-recognition" trained_dataset (`str` or `List`, *optional*): A string tag or a list of string tags of the trained dataset for a model on the Hub. tags (`str` or `List`, *optional*): A string tag or a list of tags to filter models on the Hub by, such as `text-generation` or `spacy`. search (`str`, *optional*): A string that will be contained in the returned model ids. pipeline_tag (`str`, *optional*): A string pipeline tag to filter models on the Hub by, such as `summarization`. emissions_thresholds (`Tuple`, *optional*): A tuple of two ints or floats representing a minimum and maximum carbon footprint to filter the resulting models with in grams. sort (`Literal["last_modified"]` or `str`, *optional*): The key with which to sort the resulting models. Possible values are "last_modified", "trending_score", "created_at", "downloads" and "likes". direction (`Literal[-1]` or `int`, *optional*): Direction in which to sort. The value `-1` sorts by descending order while all other values sort by ascending order. limit (`int`, *optional*): The limit on the number of models fetched. Leaving this option to `None` fetches all models. expand (`List[ExpandModelProperty_T]`, *optional*): List properties to return in the response. When used, only the properties in the list will be returned. This parameter cannot be used if `full`, `cardData` or `fetch_config` are passed. Possible values are `"author"`, `"baseModels"`, `"cardData"`, `"childrenModelCount"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"gguf"`, `"inference"`, `"inferenceProviderMapping"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"`, `"trendingScore"`, `"widgetData"`, `"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. full (`bool`, *optional*): Whether to fetch all model data, including the `last_modified`, the `sha`, the files and the `tags`. This is set to `True` by default when using a filter. cardData (`bool`, *optional*): Whether to grab the metadata for the model as well. Can contain useful information such as carbon emissions, metrics, and datasets trained on. fetch_config (`bool`, *optional*): Whether to fetch the model configs as well. This is not included in `full` due to its size. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. 
Returns: `Iterable[ModelInfo]`: an iterable of [`huggingface_hub.hf_api.ModelInfo`] objects. Example usage with the `filter` argument: ```python >>> from huggingface_hub import HfApi >>> api = HfApi() # List all models >>> api.list_models() # List only the text classification models >>> api.list_models(filter="text-classification") # List only models from the AllenNLP library >>> api.list_models(filter="allennlp") ``` Example usage with the `search` argument: ```python >>> from huggingface_hub import HfApi >>> api = HfApi() # List all models with "bert" in their name >>> api.list_models(search="bert") # List all models with "bert" in their name made by google >>> api.list_models(search="bert", author="google") ``` """ if expand and (full or cardData or fetch_config): raise ValueError("`expand` cannot be used if `full`, `cardData` or `fetch_config` are passed.") if emissions_thresholds is not None and cardData is None: raise ValueError("`emissions_thresholds` were passed without setting `cardData=True`.") path = f"{self.endpoint}/api/models" headers = self._build_hf_headers(token=token) params: Dict[str, Any] = {} # Build the filter list filter_list: List[str] = [] if filter: filter_list.extend([filter] if isinstance(filter, str) else filter) if library: filter_list.extend([library] if isinstance(library, str) else library) if task: filter_list.extend([task] if isinstance(task, str) else task) if trained_dataset: if isinstance(trained_dataset, str): trained_dataset = [trained_dataset] for dataset in trained_dataset: if not dataset.startswith("dataset:"): dataset = f"dataset:{dataset}" filter_list.append(dataset) if language: filter_list.extend([language] if isinstance(language, str) else language) if tags: filter_list.extend([tags] if isinstance(tags, str) else tags) if len(filter_list) > 0: params["filter"] = filter_list # Handle other query params if author: params["author"] = author if gated is not None: params["gated"] = gated if inference is not None: params["inference"] = inference if pipeline_tag: params["pipeline_tag"] = pipeline_tag search_list = [] if model_name: search_list.append(model_name) if search: search_list.append(search) if len(search_list) > 0: params["search"] = search_list if sort is not None: params["sort"] = ( "lastModified" if sort == "last_modified" else "trendingScore" if sort == "trending_score" else "createdAt" if sort == "created_at" else sort ) if direction is not None: params["direction"] = direction if limit is not None: params["limit"] = limit # Request additional data if full: params["full"] = True if fetch_config: params["config"] = True if cardData: params["cardData"] = True if expand: params["expand"] = expand # `items` is a generator items = paginate(path, params=params, headers=headers) if limit is not None: items = islice(items, limit) # Do not iterate over all pages for item in items: if "siblings" not in item: item["siblings"] = None model_info = ModelInfo(**item) if emissions_thresholds is None or _is_emission_within_threshold(model_info, *emissions_thresholds): yield model_info @validate_hf_hub_args def list_datasets( self, *, # Search-query parameter filter: Union[str, Iterable[str], None] = None, author: Optional[str] = None, benchmark: Optional[Union[str, List[str]]] = None, dataset_name: Optional[str] = None, gated: Optional[bool] = None, language_creators: Optional[Union[str, List[str]]] = None, language: Optional[Union[str, List[str]]] = None, multilinguality: Optional[Union[str, List[str]]] = None, size_categories: Optional[Union[str, 
List[str]]] = None, tags: Optional[Union[str, List[str]]] = None, task_categories: Optional[Union[str, List[str]]] = None, task_ids: Optional[Union[str, List[str]]] = None, search: Optional[str] = None, # Sorting and pagination parameters sort: Optional[Union[Literal["last_modified"], str]] = None, direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, # Additional data to fetch expand: Optional[List[ExpandDatasetProperty_T]] = None, full: Optional[bool] = None, token: Union[bool, str, None] = None, ) -> Iterable[DatasetInfo]: """ List datasets hosted on the Huggingface Hub, given some filters. Args: filter (`str` or `Iterable[str]`, *optional*): A string or list of string to filter datasets on the hub. author (`str`, *optional*): A string which identify the author of the returned datasets. benchmark (`str` or `List`, *optional*): A string or list of strings that can be used to identify datasets on the Hub by their official benchmark. dataset_name (`str`, *optional*): A string or list of strings that can be used to identify datasets on the Hub by its name, such as `SQAC` or `wikineural` gated (`bool`, *optional*): A boolean to filter datasets on the Hub that are gated or not. By default, all datasets are returned. If `gated=True` is passed, only gated datasets are returned. If `gated=False` is passed, only non-gated datasets are returned. language_creators (`str` or `List`, *optional*): A string or list of strings that can be used to identify datasets on the Hub with how the data was curated, such as `crowdsourced` or `machine_generated`. language (`str` or `List`, *optional*): A string or list of strings representing a two-character language to filter datasets by on the Hub. multilinguality (`str` or `List`, *optional*): A string or list of strings representing a filter for datasets that contain multiple languages. size_categories (`str` or `List`, *optional*): A string or list of strings that can be used to identify datasets on the Hub by the size of the dataset such as `100K>> from huggingface_hub import HfApi >>> api = HfApi() # List all datasets >>> api.list_datasets() # List only the text classification datasets >>> api.list_datasets(filter="task_categories:text-classification") # List only the datasets in russian for language modeling >>> api.list_datasets( ... filter=("language:ru", "task_ids:language-modeling") ... 
) # List FiftyOne datasets (identified by the tag "fiftyone" in dataset card) >>> api.list_datasets(tags="fiftyone") ``` Example usage with the `search` argument: ```python >>> from huggingface_hub import HfApi >>> api = HfApi() # List all datasets with "text" in their name >>> api.list_datasets(search="text") # List all datasets with "text" in their name made by google >>> api.list_datasets(search="text", author="google") ``` """ if expand and full: raise ValueError("`expand` cannot be used if `full` is passed.") path = f"{self.endpoint}/api/datasets" headers = self._build_hf_headers(token=token) params: Dict[str, Any] = {} # Build `filter` list filter_list = [] if filter is not None: if isinstance(filter, str): filter_list.append(filter) else: filter_list.extend(filter) for key, value in ( ("benchmark", benchmark), ("language_creators", language_creators), ("language", language), ("multilinguality", multilinguality), ("size_categories", size_categories), ("task_categories", task_categories), ("task_ids", task_ids), ): if value: if isinstance(value, str): value = [value] for value_item in value: if not value_item.startswith(f"{key}:"): data = f"{key}:{value_item}" filter_list.append(data) if tags is not None: filter_list.extend([tags] if isinstance(tags, str) else tags) if len(filter_list) > 0: params["filter"] = filter_list # Handle other query params if author: params["author"] = author if gated is not None: params["gated"] = gated search_list = [] if dataset_name: search_list.append(dataset_name) if search: search_list.append(search) if len(search_list) > 0: params["search"] = search_list if sort is not None: params["sort"] = ( "lastModified" if sort == "last_modified" else "trendingScore" if sort == "trending_score" else "createdAt" if sort == "created_at" else sort ) if direction is not None: params["direction"] = direction if limit is not None: params["limit"] = limit # Request additional data if expand: params["expand"] = expand if full: params["full"] = True items = paginate(path, params=params, headers=headers) if limit is not None: items = islice(items, limit) # Do not iterate over all pages for item in items: if "siblings" not in item: item["siblings"] = None yield DatasetInfo(**item) @validate_hf_hub_args def list_spaces( self, *, # Search-query parameter filter: Union[str, Iterable[str], None] = None, author: Optional[str] = None, search: Optional[str] = None, datasets: Union[str, Iterable[str], None] = None, models: Union[str, Iterable[str], None] = None, linked: bool = False, # Sorting and pagination parameters sort: Union[Literal["last_modified"], str, None] = None, direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, # Additional data to fetch expand: Optional[List[ExpandSpaceProperty_T]] = None, full: Optional[bool] = None, token: Union[bool, str, None] = None, ) -> Iterable[SpaceInfo]: """ List spaces hosted on the Huggingface Hub, given some filters. Args: filter (`str` or `Iterable`, *optional*): A string tag or list of tags that can be used to identify Spaces on the Hub. author (`str`, *optional*): A string which identify the author of the returned Spaces. search (`str`, *optional*): A string that will be contained in the returned Spaces. datasets (`str` or `Iterable`, *optional*): Whether to return Spaces that make use of a dataset. The name of a specific dataset can be passed as a string. models (`str` or `Iterable`, *optional*): Whether to return Spaces that make use of a model. The name of a specific model can be passed as a string. 
linked (`bool`, *optional*): Whether to return Spaces that make use of either a model or a dataset. sort (`Literal["last_modified"]` or `str`, *optional*): The key with which to sort the resulting models. Possible values are "last_modified", "trending_score", "created_at" and "likes". direction (`Literal[-1]` or `int`, *optional*): Direction in which to sort. The value `-1` sorts by descending order while all other values sort by ascending order. limit (`int`, *optional*): The limit on the number of Spaces fetched. Leaving this option to `None` fetches all Spaces. expand (`List[ExpandSpaceProperty_T]`, *optional*): List properties to return in the response. When used, only the properties in the list will be returned. This parameter cannot be used if `full` is passed. Possible values are `"author"`, `"cardData"`, `"datasets"`, `"disabled"`, `"lastModified"`, `"createdAt"`, `"likes"`, `"models"`, `"private"`, `"runtime"`, `"sdk"`, `"siblings"`, `"sha"`, `"subdomain"`, `"tags"`, `"trendingScore"`, `"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. full (`bool`, *optional*): Whether to fetch all Spaces data, including the `last_modified`, `siblings` and `card_data` fields. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `Iterable[SpaceInfo]`: an iterable of [`huggingface_hub.hf_api.SpaceInfo`] objects. """ if expand and full: raise ValueError("`expand` cannot be used if `full` is passed.") path = f"{self.endpoint}/api/spaces" headers = self._build_hf_headers(token=token) params: Dict[str, Any] = {} if filter is not None: params["filter"] = filter if author is not None: params["author"] = author if search is not None: params["search"] = search if sort is not None: params["sort"] = ( "lastModified" if sort == "last_modified" else "trendingScore" if sort == "trending_score" else "createdAt" if sort == "created_at" else sort ) if direction is not None: params["direction"] = direction if limit is not None: params["limit"] = limit if linked: params["linked"] = True if datasets is not None: params["datasets"] = datasets if models is not None: params["models"] = models # Request additional data if expand: params["expand"] = expand if full: params["full"] = True items = paginate(path, params=params, headers=headers) if limit is not None: items = islice(items, limit) # Do not iterate over all pages for item in items: if "siblings" not in item: item["siblings"] = None yield SpaceInfo(**item) @validate_hf_hub_args def unlike( self, repo_id: str, *, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, ) -> None: """ Unlike a given repo on the Hub (e.g. remove from favorite list). To prevent spam usage, it is not possible to `like` a repository from a script. See also [`list_liked_repos`]. Args: repo_id (`str`): The repository to unlike. Example: `"user/my-cool-model"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if unliking a dataset or space, `None` or `"model"` if unliking a model. Default is `None`. 
Raises: [`~utils.RepositoryNotFoundError`]: If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo does not exist. Example: ```python >>> from huggingface_hub import list_liked_repos, unlike >>> "gpt2" in list_liked_repos().models # we assume you have already liked gpt2 True >>> unlike("gpt2") >>> "gpt2" in list_liked_repos().models False ``` """ if repo_type is None: repo_type = constants.REPO_TYPE_MODEL response = get_session().delete( url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/like", headers=self._build_hf_headers(token=token) ) hf_raise_for_status(response) @validate_hf_hub_args def list_liked_repos( self, user: Optional[str] = None, *, token: Union[bool, str, None] = None, ) -> UserLikes: """ List all public repos liked by a user on huggingface.co. This list is public so token is optional. If `user` is not passed, it defaults to the logged in user. See also [`unlike`]. Args: user (`str`, *optional*): Name of the user for which you want to fetch the likes. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`UserLikes`]: object containing the user name and 3 lists of repo ids (1 for models, 1 for datasets and 1 for Spaces). Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If `user` is not passed and no token found (either from argument or from machine). Example: ```python >>> from huggingface_hub import list_liked_repos >>> likes = list_liked_repos("julien-c") >>> likes.user "julien-c" >>> likes.models ["osanseviero/streamlit_1.15", "Xhaheen/ChatGPT_HF", ...] ``` """ # User is either provided explicitly or retrieved from current token. if user is None: me = self.whoami(token=token) if me["type"] == "user": user = me["name"] else: raise ValueError( "Cannot list liked repos. You must provide a 'user' as input or be logged in as a user." ) path = f"{self.endpoint}/api/users/{user}/likes" headers = self._build_hf_headers(token=token) likes = list(paginate(path, params={}, headers=headers)) # Looping over a list of items similar to: # { # 'createdAt': '2021-09-09T21:53:27.000Z', # 'repo': { # 'name': 'PaddlePaddle/PaddleOCR', # 'type': 'space' # } # } # Let's loop 3 times over the received list. Less efficient but more straightforward to read. return UserLikes( user=user, total=len(likes), models=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "model"], datasets=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "dataset"], spaces=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "space"], ) @validate_hf_hub_args def list_repo_likers( self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None, ) -> Iterable[User]: """ List all users who liked a given repo on the hugging Face Hub. See also [`list_liked_repos`]. Args: repo_id (`str`): The repository to retrieve . Example: `"user/my-cool-model"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. 
repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. Returns: `Iterable[User]`: an iterable of [`huggingface_hub.hf_api.User`] objects. """ # Construct the API endpoint if repo_type is None: repo_type = constants.REPO_TYPE_MODEL path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/likers" for liker in paginate(path, params={}, headers=self._build_hf_headers(token=token)): yield User(username=liker["user"], fullname=liker["fullname"], avatar_url=liker["avatarUrl"]) @validate_hf_hub_args def model_info( self, repo_id: str, *, revision: Optional[str] = None, timeout: Optional[float] = None, securityStatus: Optional[bool] = None, files_metadata: bool = False, expand: Optional[List[ExpandModelProperty_T]] = None, token: Union[bool, str, None] = None, ) -> ModelInfo: """ Get info on one specific model on huggingface.co Model can be private if you pass an acceptable token or are logged in. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. revision (`str`, *optional*): The revision of the model repository from which to get the information. timeout (`float`, *optional*): Whether to set a timeout for the request to the Hub. securityStatus (`bool`, *optional*): Whether to retrieve the security status from the model repository as well. The security status will be returned in the `security_repo_status` field. files_metadata (`bool`, *optional*): Whether or not to retrieve metadata for files in the repository (size, LFS metadata, etc). Defaults to `False`. expand (`List[ExpandModelProperty_T]`, *optional*): List properties to return in the response. When used, only the properties in the list will be returned. This parameter cannot be used if `securityStatus` or `files_metadata` are passed. Possible values are `"author"`, `"baseModels"`, `"cardData"`, `"childrenModelCount"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"gguf"`, `"inference"`, `"inferenceProviderMapping"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"`, `"trendingScore"`, `"widgetData"`, `"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`huggingface_hub.hf_api.ModelInfo`]: The model repository information. Raises the following errors: - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. - [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. 
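Example (a minimal usage sketch; `"gpt2"` is an arbitrary public model id used purely for illustration):

```py
>>> from huggingface_hub import HfApi
>>> api = HfApi()

>>> # Fetch the full metadata of a model repo
>>> info = api.model_info("gpt2")
>>> info.id
'gpt2'

>>> # Request only a subset of properties to keep the response small
>>> info = api.model_info("gpt2", expand=["downloads", "likes", "lastModified"])
```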
""" if expand and (securityStatus or files_metadata): raise ValueError("`expand` cannot be used if `securityStatus` or `files_metadata` are set.") headers = self._build_hf_headers(token=token) path = ( f"{self.endpoint}/api/models/{repo_id}" if revision is None else (f"{self.endpoint}/api/models/{repo_id}/revision/{quote(revision, safe='')}") ) params: Dict = {} if securityStatus: params["securityStatus"] = True if files_metadata: params["blobs"] = True if expand: params["expand"] = expand r = get_session().get(path, headers=headers, timeout=timeout, params=params) hf_raise_for_status(r) data = r.json() return ModelInfo(**data) @validate_hf_hub_args def dataset_info( self, repo_id: str, *, revision: Optional[str] = None, timeout: Optional[float] = None, files_metadata: bool = False, expand: Optional[List[ExpandDatasetProperty_T]] = None, token: Union[bool, str, None] = None, ) -> DatasetInfo: """ Get info on one specific dataset on huggingface.co. Dataset can be private if you pass an acceptable token. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. revision (`str`, *optional*): The revision of the dataset repository from which to get the information. timeout (`float`, *optional*): Whether to set a timeout for the request to the Hub. files_metadata (`bool`, *optional*): Whether or not to retrieve metadata for files in the repository (size, LFS metadata, etc). Defaults to `False`. expand (`List[ExpandDatasetProperty_T]`, *optional*): List properties to return in the response. When used, only the properties in the list will be returned. This parameter cannot be used if `files_metadata` is passed. Possible values are `"author"`, `"cardData"`, `"citation"`, `"createdAt"`, `"disabled"`, `"description"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"lastModified"`, `"likes"`, `"paperswithcode_id"`, `"private"`, `"siblings"`, `"sha"`, `"tags"`, `"trendingScore"`,`"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`hf_api.DatasetInfo`]: The dataset repository information. Raises the following errors: - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. - [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. """ if expand and files_metadata: raise ValueError("`expand` cannot be used if `files_metadata` is set.") headers = self._build_hf_headers(token=token) path = ( f"{self.endpoint}/api/datasets/{repo_id}" if revision is None else (f"{self.endpoint}/api/datasets/{repo_id}/revision/{quote(revision, safe='')}") ) params: Dict = {} if files_metadata: params["blobs"] = True if expand: params["expand"] = expand r = get_session().get(path, headers=headers, timeout=timeout, params=params) hf_raise_for_status(r) data = r.json() return DatasetInfo(**data) @validate_hf_hub_args def space_info( self, repo_id: str, *, revision: Optional[str] = None, timeout: Optional[float] = None, files_metadata: bool = False, expand: Optional[List[ExpandSpaceProperty_T]] = None, token: Union[bool, str, None] = None, ) -> SpaceInfo: """ Get info on one specific Space on huggingface.co. 
Space can be private if you pass an acceptable token. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. revision (`str`, *optional*): The revision of the space repository from which to get the information. timeout (`float`, *optional*): Whether to set a timeout for the request to the Hub. files_metadata (`bool`, *optional*): Whether or not to retrieve metadata for files in the repository (size, LFS metadata, etc). Defaults to `False`. expand (`List[ExpandSpaceProperty_T]`, *optional*): List properties to return in the response. When used, only the properties in the list will be returned. This parameter cannot be used if `full` is passed. Possible values are `"author"`, `"cardData"`, `"createdAt"`, `"datasets"`, `"disabled"`, `"lastModified"`, `"likes"`, `"models"`, `"private"`, `"runtime"`, `"sdk"`, `"siblings"`, `"sha"`, `"subdomain"`, `"tags"`, `"trendingScore"`, `"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`~hf_api.SpaceInfo`]: The space repository information. Raises the following errors: - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. - [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. """ if expand and files_metadata: raise ValueError("`expand` cannot be used if `files_metadata` is set.") headers = self._build_hf_headers(token=token) path = ( f"{self.endpoint}/api/spaces/{repo_id}" if revision is None else (f"{self.endpoint}/api/spaces/{repo_id}/revision/{quote(revision, safe='')}") ) params: Dict = {} if files_metadata: params["blobs"] = True if expand: params["expand"] = expand r = get_session().get(path, headers=headers, timeout=timeout, params=params) hf_raise_for_status(r) data = r.json() return SpaceInfo(**data) @validate_hf_hub_args def repo_info( self, repo_id: str, *, revision: Optional[str] = None, repo_type: Optional[str] = None, timeout: Optional[float] = None, files_metadata: bool = False, expand: Optional[Union[ExpandModelProperty_T, ExpandDatasetProperty_T, ExpandSpaceProperty_T]] = None, token: Union[bool, str, None] = None, ) -> Union[ModelInfo, DatasetInfo, SpaceInfo]: """ Get the info object for a given repo of a given type. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. revision (`str`, *optional*): The revision of the repository from which to get the information. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, `None` or `"model"` if getting repository info from a model. Default is `None`. timeout (`float`, *optional*): Whether to set a timeout for the request to the Hub. expand (`ExpandModelProperty_T` or `ExpandDatasetProperty_T` or `ExpandSpaceProperty_T`, *optional*): List properties to return in the response. When used, only the properties in the list will be returned. This parameter cannot be used if `files_metadata` is passed. For an exhaustive list of available properties, check out [`model_info`], [`dataset_info`] or [`space_info`]. 
files_metadata (`bool`, *optional*): Whether or not to retrieve metadata for files in the repository (size, LFS metadata, etc). Defaults to `False`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `Union[SpaceInfo, DatasetInfo, ModelInfo]`: The repository information, as a [`huggingface_hub.hf_api.DatasetInfo`], [`huggingface_hub.hf_api.ModelInfo`] or [`huggingface_hub.hf_api.SpaceInfo`] object. Raises the following errors: - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. - [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. """ if repo_type is None or repo_type == "model": method = self.model_info elif repo_type == "dataset": method = self.dataset_info # type: ignore elif repo_type == "space": method = self.space_info # type: ignore else: raise ValueError("Unsupported repo type.") return method( repo_id, revision=revision, token=token, timeout=timeout, expand=expand, # type: ignore[arg-type] files_metadata=files_metadata, ) @validate_hf_hub_args def repo_exists( self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[str, bool, None] = None, ) -> bool: """ Checks if a repository exists on the Hugging Face Hub. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, `None` or `"model"` if getting repository info from a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: True if the repository exists, False otherwise. Examples: ```py >>> from huggingface_hub import repo_exists >>> repo_exists("google/gemma-7b") True >>> repo_exists("google/not-a-repo") False ``` """ try: self.repo_info(repo_id=repo_id, repo_type=repo_type, token=token) return True except GatedRepoError: return True # we don't have access but it exists except RepositoryNotFoundError: return False @validate_hf_hub_args def revision_exists( self, repo_id: str, revision: str, *, repo_type: Optional[str] = None, token: Union[str, bool, None] = None, ) -> bool: """ Checks if a specific revision exists on a repo on the Hugging Face Hub. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. revision (`str`): The revision of the repository to check. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, `None` or `"model"` if getting repository info from a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: True if the repository and the revision exists, False otherwise. 
Examples: ```py >>> from huggingface_hub import revision_exists >>> revision_exists("google/gemma-7b", "float16") True >>> revision_exists("google/gemma-7b", "not-a-revision") False ``` """ try: self.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type, token=token) return True except RevisionNotFoundError: return False except RepositoryNotFoundError: return False @validate_hf_hub_args def file_exists( self, repo_id: str, filename: str, *, repo_type: Optional[str] = None, revision: Optional[str] = None, token: Union[str, bool, None] = None, ) -> bool: """ Checks if a file exists in a repository on the Hugging Face Hub. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. filename (`str`): The name of the file to check, for example: `"config.json"` repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, `None` or `"model"` if getting repository info from a model. Default is `None`. revision (`str`, *optional*): The revision of the repository from which to get the information. Defaults to `"main"` branch. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: True if the file exists, False otherwise. Examples: ```py >>> from huggingface_hub import file_exists >>> file_exists("bigcode/starcoder", "config.json") True >>> file_exists("bigcode/starcoder", "not-a-file") False >>> file_exists("bigcode/not-a-repo", "config.json") False ``` """ url = hf_hub_url( repo_id=repo_id, repo_type=repo_type, revision=revision, filename=filename, endpoint=self.endpoint ) try: if token is None: token = self.token get_hf_file_metadata(url, token=token) return True except GatedRepoError: # raise specifically on gated repo raise except (RepositoryNotFoundError, EntryNotFoundError, RevisionNotFoundError): return False @validate_hf_hub_args def list_repo_files( self, repo_id: str, *, revision: Optional[str] = None, repo_type: Optional[str] = None, token: Union[str, bool, None] = None, ) -> List[str]: """ Get the list of files in a given repo. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. revision (`str`, *optional*): The revision of the repository from which to get the information. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `List[str]`: the list of files in a given repository. """ return [ f.rfilename for f in self.list_repo_tree( repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type, token=token ) if isinstance(f, RepoFile) ] @validate_hf_hub_args def list_repo_tree( self, repo_id: str, path_in_repo: Optional[str] = None, *, recursive: bool = False, expand: bool = False, revision: Optional[str] = None, repo_type: Optional[str] = None, token: Union[str, bool, None] = None, ) -> Iterable[Union[RepoFile, RepoFolder]]: """ List a repo tree's files and folders and get information about them. 
Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. path_in_repo (`str`, *optional*): Relative path of the tree (folder) in the repo, for example: `"checkpoints/1fec34a/results"`. Will default to the root tree (folder) of the repository. recursive (`bool`, *optional*, defaults to `False`): Whether to list tree's files and folders recursively. expand (`bool`, *optional*, defaults to `False`): Whether to fetch more information about the tree's files and folders (e.g. last commit and files' security scan results). This operation is more expensive for the server so only 50 results are returned per page (instead of 1000). As pagination is implemented in `huggingface_hub`, this is transparent for you except for the time it takes to get the results. revision (`str`, *optional*): The revision of the repository from which to get the tree. Defaults to `"main"` branch. repo_type (`str`, *optional*): The type of the repository from which to get the tree (`"model"`, `"dataset"` or `"space"`. Defaults to `"model"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `Iterable[Union[RepoFile, RepoFolder]]`: The information about the tree's files and folders, as an iterable of [`RepoFile`] and [`RepoFolder`] objects. The order of the files and folders is not guaranteed. Raises: [`~utils.RepositoryNotFoundError`]: If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo does not exist. [`~utils.RevisionNotFoundError`]: If revision is not found (error 404) on the repo. [`~utils.EntryNotFoundError`]: If the tree (folder) does not exist (error 404) on the repo. Examples: Get information about a repo's tree. 
```py >>> from huggingface_hub import list_repo_tree >>> repo_tree = list_repo_tree("lysandre/arxiv-nlp") >>> repo_tree >>> list(repo_tree) [ RepoFile(path='.gitattributes', size=391, blob_id='ae8c63daedbd4206d7d40126955d4e6ab1c80f8f', lfs=None, last_commit=None, security=None), RepoFile(path='README.md', size=391, blob_id='43bd404b159de6fba7c2f4d3264347668d43af25', lfs=None, last_commit=None, security=None), RepoFile(path='config.json', size=554, blob_id='2f9618c3a19b9a61add74f70bfb121335aeef666', lfs=None, last_commit=None, security=None), RepoFile( path='flax_model.msgpack', size=497764107, blob_id='8095a62ccb4d806da7666fcda07467e2d150218e', lfs={'size': 497764107, 'sha256': 'd88b0d6a6ff9c3f8151f9d3228f57092aaea997f09af009eefd7373a77b5abb9', 'pointer_size': 134}, last_commit=None, security=None ), RepoFile(path='merges.txt', size=456318, blob_id='226b0752cac7789c48f0cb3ec53eda48b7be36cc', lfs=None, last_commit=None, security=None), RepoFile( path='pytorch_model.bin', size=548123560, blob_id='64eaa9c526867e404b68f2c5d66fd78e27026523', lfs={'size': 548123560, 'sha256': '9be78edb5b928eba33aa88f431551348f7466ba9f5ef3daf1d552398722a5436', 'pointer_size': 134}, last_commit=None, security=None ), RepoFile(path='vocab.json', size=898669, blob_id='b00361fece0387ca34b4b8b8539ed830d644dbeb', lfs=None, last_commit=None, security=None)] ] ``` Get even more information about a repo's tree (last commit and files' security scan results) ```py >>> from huggingface_hub import list_repo_tree >>> repo_tree = list_repo_tree("prompthero/openjourney-v4", expand=True) >>> list(repo_tree) [ RepoFolder( path='feature_extractor', tree_id='aa536c4ea18073388b5b0bc791057a7296a00398', last_commit={ 'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190', 'title': 'Upload diffusers weights (#48)', 'date': datetime.datetime(2023, 3, 21, 9, 5, 27, tzinfo=datetime.timezone.utc) } ), RepoFolder( path='safety_checker', tree_id='65aef9d787e5557373fdf714d6c34d4fcdd70440', last_commit={ 'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190', 'title': 'Upload diffusers weights (#48)', 'date': datetime.datetime(2023, 3, 21, 9, 5, 27, tzinfo=datetime.timezone.utc) } ), RepoFile( path='model_index.json', size=582, blob_id='d3d7c1e8c3e78eeb1640b8e2041ee256e24c9ee1', lfs=None, last_commit={ 'oid': 'b195ed2d503f3eb29637050a886d77bd81d35f0e', 'title': 'Fix deprecation warning by changing `CLIPFeatureExtractor` to `CLIPImageProcessor`. (#54)', 'date': datetime.datetime(2023, 5, 15, 21, 41, 59, tzinfo=datetime.timezone.utc) }, security={ 'safe': True, 'av_scan': {'virusFound': False, 'virusNames': None}, 'pickle_import_scan': None } ) ... ] ``` """ repo_type = repo_type or constants.REPO_TYPE_MODEL revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION headers = self._build_hf_headers(token=token) encoded_path_in_repo = "/" + quote(path_in_repo, safe="") if path_in_repo else "" tree_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tree/{revision}{encoded_path_in_repo}" for path_info in paginate(path=tree_url, headers=headers, params={"recursive": recursive, "expand": expand}): yield (RepoFile(**path_info) if path_info["type"] == "file" else RepoFolder(**path_info)) @validate_hf_hub_args def list_repo_refs( self, repo_id: str, *, repo_type: Optional[str] = None, include_pull_requests: bool = False, token: Union[str, bool, None] = None, ) -> GitRefs: """ Get the list of refs of a given repo (both tags and branches). 
Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if listing refs from a dataset or a Space, `None` or `"model"` if listing from a model. Default is `None`. include_pull_requests (`bool`, *optional*): Whether to include refs from pull requests in the list. Defaults to `False`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Example: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.list_repo_refs("gpt2") GitRefs(branches=[GitRefInfo(name='main', ref='refs/heads/main', target_commit='e7da7f221d5bf496a48136c0cd264e630fe9fcc8')], converts=[], tags=[]) >>> api.list_repo_refs("bigcode/the-stack", repo_type='dataset') GitRefs( branches=[ GitRefInfo(name='main', ref='refs/heads/main', target_commit='18edc1591d9ce72aa82f56c4431b3c969b210ae3'), GitRefInfo(name='v1.1.a1', ref='refs/heads/v1.1.a1', target_commit='f9826b862d1567f3822d3d25649b0d6d22ace714') ], converts=[], tags=[ GitRefInfo(name='v1.0', ref='refs/tags/v1.0', target_commit='c37a8cd1e382064d8aced5e05543c5f7753834da') ] ) ``` Returns: [`GitRefs`]: object containing all information about branches and tags for a repo on the Hub. """ repo_type = repo_type or constants.REPO_TYPE_MODEL response = get_session().get( f"{self.endpoint}/api/{repo_type}s/{repo_id}/refs", headers=self._build_hf_headers(token=token), params={"include_prs": 1} if include_pull_requests else {}, ) hf_raise_for_status(response) data = response.json() def _format_as_git_ref_info(item: Dict) -> GitRefInfo: return GitRefInfo(name=item["name"], ref=item["ref"], target_commit=item["targetCommit"]) return GitRefs( branches=[_format_as_git_ref_info(item) for item in data["branches"]], converts=[_format_as_git_ref_info(item) for item in data["converts"]], tags=[_format_as_git_ref_info(item) for item in data["tags"]], pull_requests=[_format_as_git_ref_info(item) for item in data["pullRequests"]] if include_pull_requests else None, ) @validate_hf_hub_args def list_repo_commits( self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None, revision: Optional[str] = None, formatted: bool = False, ) -> List[GitCommitInfo]: """ Get the list of commits of a given revision for a repo on the Hub. Commits are sorted by date (last commit first). Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if listing commits from a dataset or a Space, `None` or `"model"` if listing from a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. revision (`str`, *optional*): The git revision to commit from. Defaults to the head of the `"main"` branch. formatted (`bool`): Whether to return the HTML-formatted title and description of the commits. Defaults to False. 
Example: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() # Commits are sorted by date (last commit first) >>> initial_commit = api.list_repo_commits("gpt2")[-1] # Initial commit is always a system commit containing the `.gitattributes` file. >>> initial_commit GitCommitInfo( commit_id='9b865efde13a30c13e0a33e536cf3e4a5a9d71d8', authors=['system'], created_at=datetime.datetime(2019, 2, 18, 10, 36, 15, tzinfo=datetime.timezone.utc), title='initial commit', message='', formatted_title=None, formatted_message=None ) # Create an empty branch by deriving from initial commit >>> api.create_branch("gpt2", "new_empty_branch", revision=initial_commit.commit_id) ``` Returns: List[[`GitCommitInfo`]]: list of objects containing information about the commits for a repo on the Hub. Raises: [`~utils.RepositoryNotFoundError`]: If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo does not exist. [`~utils.RevisionNotFoundError`]: If revision is not found (error 404) on the repo. """ repo_type = repo_type or constants.REPO_TYPE_MODEL revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION # Paginate over results and return the list of commits. return [ GitCommitInfo( commit_id=item["id"], authors=[author["user"] for author in item["authors"]], created_at=parse_datetime(item["date"]), title=item["title"], message=item["message"], formatted_title=item.get("formatted", {}).get("title"), formatted_message=item.get("formatted", {}).get("message"), ) for item in paginate( f"{self.endpoint}/api/{repo_type}s/{repo_id}/commits/{revision}", headers=self._build_hf_headers(token=token), params={"expand[]": "formatted"} if formatted else {}, ) ] @validate_hf_hub_args def get_paths_info( self, repo_id: str, paths: Union[List[str], str], *, expand: bool = False, revision: Optional[str] = None, repo_type: Optional[str] = None, token: Union[str, bool, None] = None, ) -> List[Union[RepoFile, RepoFolder]]: """ Get information about a repo's paths. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. paths (`Union[List[str], str]`, *optional*): The paths to get information about. If a path do not exist, it is ignored without raising an exception. expand (`bool`, *optional*, defaults to `False`): Whether to fetch more information about the paths (e.g. last commit and files' security scan results). This operation is more expensive for the server so only 50 results are returned per page (instead of 1000). As pagination is implemented in `huggingface_hub`, this is transparent for you except for the time it takes to get the results. revision (`str`, *optional*): The revision of the repository from which to get the information. Defaults to `"main"` branch. repo_type (`str`, *optional*): The type of the repository from which to get the information (`"model"`, `"dataset"` or `"space"`. Defaults to `"model"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `List[Union[RepoFile, RepoFolder]]`: The information about the paths, as a list of [`RepoFile`] and [`RepoFolder`] objects. Raises: [`~utils.RepositoryNotFoundError`]: If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo does not exist. 
[`~utils.RevisionNotFoundError`]: If revision is not found (error 404) on the repo. Example: ```py >>> from huggingface_hub import get_paths_info >>> paths_info = get_paths_info("allenai/c4", ["README.md", "en"], repo_type="dataset") >>> paths_info [ RepoFile(path='README.md', size=2379, blob_id='f84cb4c97182890fc1dbdeaf1a6a468fd27b4fff', lfs=None, last_commit=None, security=None), RepoFolder(path='en', tree_id='dc943c4c40f53d02b31ced1defa7e5f438d5862e', last_commit=None) ] ``` """ repo_type = repo_type or constants.REPO_TYPE_MODEL revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION headers = self._build_hf_headers(token=token) response = get_session().post( f"{self.endpoint}/api/{repo_type}s/{repo_id}/paths-info/{revision}", data={ "paths": paths if isinstance(paths, list) else [paths], "expand": expand, }, headers=headers, ) hf_raise_for_status(response) paths_info = response.json() return [ RepoFile(**path_info) if path_info["type"] == "file" else RepoFolder(**path_info) for path_info in paths_info ] @validate_hf_hub_args def super_squash_history( self, repo_id: str, *, branch: Optional[str] = None, commit_message: Optional[str] = None, repo_type: Optional[str] = None, token: Union[str, bool, None] = None, ) -> None: """Squash commit history on a branch for a repo on the Hub. Squashing the repo history is useful when you know you'll make hundreds of commits and you don't want to clutter the history. Squashing commits can only be performed from the head of a branch. Once squashed, the commit history cannot be retrieved. This is a non-revertible operation. Once the history of a branch has been squashed, it is not possible to merge it back into another branch since their history will have diverged. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. branch (`str`, *optional*): The branch to squash. Defaults to the head of the `"main"` branch. commit_message (`str`, *optional*): The commit message to use for the squashed commit. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if listing commits from a dataset or a Space, `None` or `"model"` if listing from a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Raises: [`~utils.RepositoryNotFoundError`]: If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo does not exist. [`~utils.RevisionNotFoundError`]: If the branch to squash cannot be found. [`~utils.BadRequestError`]: If invalid reference for a branch. You cannot squash history on tags. Example: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() # Create repo >>> repo_id = api.create_repo("test-squash").repo_id # Make a lot of commits. 
>>> api.upload_file(repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"content") >>> api.upload_file(repo_id=repo_id, path_in_repo="lfs.bin", path_or_fileobj=b"content") >>> api.upload_file(repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"another_content") # Squash history >>> api.super_squash_history(repo_id=repo_id) ``` """ if repo_type is None: repo_type = constants.REPO_TYPE_MODEL if repo_type not in constants.REPO_TYPES: raise ValueError("Invalid repo type") if branch is None: branch = constants.DEFAULT_REVISION # Prepare request url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/super-squash/{quote(branch, safe='')}" headers = self._build_hf_headers(token=token) commit_message = commit_message or f"Super-squash branch '{branch}' using huggingface_hub" # Super-squash response = get_session().post(url=url, headers=headers, json={"message": commit_message}) hf_raise_for_status(response) @validate_hf_hub_args def list_lfs_files( self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None, ) -> Iterable[LFSFileInfo]: """ List all LFS files in a repo on the Hub. This is primarily useful to count how much storage a repo is using and to eventually clean up large files with [`permanently_delete_lfs_files`]. Note that this would be a permanent action that will affect all commits referencing this deleted files and that cannot be undone. Args: repo_id (`str`): The repository for which you are listing LFS files. repo_type (`str`, *optional*): Type of repository. Set to `"dataset"` or `"space"` if listing from a dataset or space, `None` or `"model"` if listing from a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `Iterable[LFSFileInfo]`: An iterator of [`LFSFileInfo`] objects. Example: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> lfs_files = api.list_lfs_files("username/my-cool-repo") # Filter files files to delete based on a combination of `filename`, `pushed_at`, `ref` or `size`. # e.g. select only LFS files in the "checkpoints" folder >>> lfs_files_to_delete = (lfs_file for lfs_file in lfs_files if lfs_file.filename.startswith("checkpoints/")) # Permanently delete LFS files >>> api.permanently_delete_lfs_files("username/my-cool-repo", lfs_files_to_delete) ``` """ # Prepare request if repo_type is None: repo_type = constants.REPO_TYPE_MODEL url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/lfs-files" headers = self._build_hf_headers(token=token) # Paginate over LFS items for item in paginate(url, params={}, headers=headers): yield LFSFileInfo(**item) @validate_hf_hub_args def permanently_delete_lfs_files( self, repo_id: str, lfs_files: Iterable[LFSFileInfo], *, rewrite_history: bool = True, repo_type: Optional[str] = None, token: Union[bool, str, None] = None, ) -> None: """ Permanently delete LFS files from a repo on the Hub. This is a permanent action that will affect all commits referencing the deleted files and might corrupt your repository. This is a non-revertible operation. Use it only if you know what you are doing. Args: repo_id (`str`): The repository for which you are listing LFS files. lfs_files (`Iterable[LFSFileInfo]`): An iterable of [`LFSFileInfo`] items to permanently delete from the repo. 
Use [`list_lfs_files`] to list all LFS files from a repo. rewrite_history (`bool`, *optional*, default to `True`): Whether to rewrite repository history to remove file pointers referencing the deleted LFS files (recommended). repo_type (`str`, *optional*): Type of repository. Set to `"dataset"` or `"space"` if listing from a dataset or space, `None` or `"model"` if listing from a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Example: ```py >>> from huggingface_hub import HfApi >>> api = HfApi() >>> lfs_files = api.list_lfs_files("username/my-cool-repo") # Filter files files to delete based on a combination of `filename`, `pushed_at`, `ref` or `size`. # e.g. select only LFS files in the "checkpoints" folder >>> lfs_files_to_delete = (lfs_file for lfs_file in lfs_files if lfs_file.filename.startswith("checkpoints/")) # Permanently delete LFS files >>> api.permanently_delete_lfs_files("username/my-cool-repo", lfs_files_to_delete) ``` """ # Prepare request if repo_type is None: repo_type = constants.REPO_TYPE_MODEL url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/lfs-files/batch" headers = self._build_hf_headers(token=token) # Delete LFS items by batches of 1000 for batch in chunk_iterable(lfs_files, 1000): shas = [item.file_oid for item in batch] if len(shas) == 0: return payload = { "deletions": { "sha": shas, "rewriteHistory": rewrite_history, } } response = get_session().post(url, headers=headers, json=payload) hf_raise_for_status(response) @validate_hf_hub_args def create_repo( self, repo_id: str, *, token: Union[str, bool, None] = None, private: Optional[bool] = None, repo_type: Optional[str] = None, exist_ok: bool = False, resource_group_id: Optional[str] = None, space_sdk: Optional[str] = None, space_hardware: Optional[SpaceHardware] = None, space_storage: Optional[SpaceStorage] = None, space_sleep_time: Optional[int] = None, space_secrets: Optional[List[Dict[str, str]]] = None, space_variables: Optional[List[Dict[str, str]]] = None, ) -> RepoUrl: """Create an empty repo on the HuggingFace Hub. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. private (`bool`, *optional*): Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. exist_ok (`bool`, *optional*, defaults to `False`): If `True`, do not raise an error if repo already exists. resource_group_id (`str`, *optional*): Resource group in which to create the repo. Resource groups is only available for organizations and allow to define which members of the organization can access the resource. The ID of a resource group can be found in the URL of the resource's page on the Hub (e.g. `"66670e5163145ca562cb1988"`). 
To learn more about resource groups, see https://huggingface.co/docs/hub/en/security-resource-groups. space_sdk (`str`, *optional*): Choice of SDK to use if repo_type is "space". Can be "streamlit", "gradio", "docker", or "static". space_hardware (`SpaceHardware` or `str`, *optional*): Choice of Hardware if repo_type is "space". See [`SpaceHardware`] for a complete list. space_storage (`SpaceStorage` or `str`, *optional*): Choice of persistent storage tier. Example: `"small"`. See [`SpaceStorage`] for a complete list. space_sleep_time (`int`, *optional*): Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure the sleep time (value is fixed to 48 hours of inactivity). See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. space_secrets (`List[Dict[str, str]]`, *optional*): A list of secret keys to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. space_variables (`List[Dict[str, str]]`, *optional*): A list of public environment variables to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables. Returns: [`RepoUrl`]: URL to the newly created repo. Value is a subclass of `str` containing attributes like `endpoint`, `repo_type` and `repo_id`. """ organization, name = repo_id.split("/") if "/" in repo_id else (None, repo_id) path = f"{self.endpoint}/api/repos/create" if repo_type not in constants.REPO_TYPES: raise ValueError("Invalid repo type") json: Dict[str, Any] = {"name": name, "organization": organization} if private is not None: json["private"] = private if repo_type is not None: json["type"] = repo_type if repo_type == "space": if space_sdk is None: raise ValueError( "No space_sdk provided. `create_repo` expects space_sdk to be one" f" of {constants.SPACES_SDK_TYPES} when repo_type is 'space'`" ) if space_sdk not in constants.SPACES_SDK_TYPES: raise ValueError(f"Invalid space_sdk. Please choose one of {constants.SPACES_SDK_TYPES}.") json["sdk"] = space_sdk if space_sdk is not None and repo_type != "space": warnings.warn("Ignoring provided space_sdk because repo_type is not 'space'.") function_args = [ "space_hardware", "space_storage", "space_sleep_time", "space_secrets", "space_variables", ] json_keys = ["hardware", "storageTier", "sleepTimeSeconds", "secrets", "variables"] values = [space_hardware, space_storage, space_sleep_time, space_secrets, space_variables] if repo_type == "space": json.update({k: v for k, v in zip(json_keys, values) if v is not None}) else: provided_space_args = [key for key, value in zip(function_args, values) if value is not None] if provided_space_args: warnings.warn(f"Ignoring provided {', '.join(provided_space_args)} because repo_type is not 'space'.") if getattr(self, "_lfsmultipartthresh", None): # Testing purposes only. 
# See https://github.com/huggingface/huggingface_hub/pull/733/files#r820604472 json["lfsmultipartthresh"] = self._lfsmultipartthresh # type: ignore if resource_group_id is not None: json["resourceGroupId"] = resource_group_id headers = self._build_hf_headers(token=token) while True: r = get_session().post(path, headers=headers, json=json) if r.status_code == 409 and "Cannot create repo: another conflicting operation is in progress" in r.text: # Since https://github.com/huggingface/moon-landing/pull/7272 (private repo), it is not possible to # concurrently create repos on the Hub for a same user. This is rarely an issue, except when running # tests. To avoid any inconvenience, we retry to create the repo for this specific error. # NOTE: This could have being fixed directly in the tests but adding it here should fixed CIs for all # dependent libraries. # NOTE: If a fix is implemented server-side, we should be able to remove this retry mechanism. logger.debug("Create repo failed due to a concurrency issue. Retrying...") continue break try: hf_raise_for_status(r) except HTTPError as err: if exist_ok and err.response.status_code == 409: # Repo already exists and `exist_ok=True` pass elif exist_ok and err.response.status_code == 403: # No write permission on the namespace but repo might already exist try: self.repo_info(repo_id=repo_id, repo_type=repo_type, token=token) if repo_type is None or repo_type == constants.REPO_TYPE_MODEL: return RepoUrl(f"{self.endpoint}/{repo_id}") return RepoUrl(f"{self.endpoint}/{repo_type}/{repo_id}") except HfHubHTTPError: raise err else: raise d = r.json() return RepoUrl(d["url"], endpoint=self.endpoint) @validate_hf_hub_args def delete_repo( self, repo_id: str, *, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, missing_ok: bool = False, ) -> None: """ Delete a repo from the HuggingFace Hub. CAUTION: this is irreversible. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. missing_ok (`bool`, *optional*, defaults to `False`): If `True`, do not raise an error if repo does not exist. Raises: [`~utils.RepositoryNotFoundError`] If the repository to delete from cannot be found and `missing_ok` is set to False (default). """ organization, name = repo_id.split("/") if "/" in repo_id else (None, repo_id) path = f"{self.endpoint}/api/repos/delete" if repo_type not in constants.REPO_TYPES: raise ValueError("Invalid repo type") json = {"name": name, "organization": organization} if repo_type is not None: json["type"] = repo_type headers = self._build_hf_headers(token=token) r = get_session().delete(path, headers=headers, json=json) try: hf_raise_for_status(r) except RepositoryNotFoundError: if not missing_ok: raise @_deprecate_method(version="0.32", message="Please use `update_repo_settings` instead.") @validate_hf_hub_args def update_repo_visibility( self, repo_id: str, private: bool = False, *, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, ) -> Dict[str, bool]: """Update the visibility setting of a repository. Deprecated. Use `update_repo_settings` instead. 
Args: repo_id (`str`, *optional*): A namespace (user or an organization) and a repo name separated by a `/`. private (`bool`, *optional*, defaults to `False`): Whether the repository should be private. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. Returns: The HTTP response in json. Raises the following errors: - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. """ if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: repo_type = constants.REPO_TYPE_MODEL # default repo type r = get_session().put( url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/settings", headers=self._build_hf_headers(token=token), json={"private": private}, ) hf_raise_for_status(r) return r.json() @validate_hf_hub_args def update_repo_settings( self, repo_id: str, *, gated: Optional[Literal["auto", "manual", False]] = None, private: Optional[bool] = None, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, xet_enabled: Optional[bool] = None, ) -> None: """ Update the settings of a repository, including gated access and visibility. To give more control over how repos are used, the Hub allows repo authors to enable access requests for their repos, and also to set the visibility of the repo to private. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a /. gated (`Literal["auto", "manual", False]`, *optional*): The gated status for the repository. If set to `None` (default), the `gated` setting of the repository won't be updated. * "auto": The repository is gated, and access requests are automatically approved or denied based on predefined criteria. * "manual": The repository is gated, and access requests require manual approval. * False : The repository is not gated, and anyone can access it. private (`bool`, *optional*): Whether the repository should be private. token (`Union[str, bool, None]`, *optional*): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass False. repo_type (`str`, *optional*): The type of the repository to update settings from (`"model"`, `"dataset"` or `"space"`). Defaults to `"model"`. xet_enabled (`bool`, *optional*): Whether the repository should be enabled for Xet Storage. Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If gated is not one of "auto", "manual", or False. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If repo_type is not one of the values in constants.REPO_TYPES. [`~utils.HfHubHTTPError`]: If the request to the Hugging Face Hub API fails. [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. 
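Example (a minimal sketch; `"user/my-cool-model"` is a placeholder repo id and the calls require write access to it):

```py
>>> from huggingface_hub import HfApi
>>> api = HfApi()

>>> # Make the repo private and require manual approval of access requests
>>> api.update_repo_settings("user/my-cool-model", private=True, gated="manual")

>>> # Turn gating off again
>>> api.update_repo_settings("user/my-cool-model", gated=False)
```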
""" if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: repo_type = constants.REPO_TYPE_MODEL # default repo type # Prepare the JSON payload for the PUT request payload: Dict = {} if gated is not None: if gated not in ["auto", "manual", False]: raise ValueError(f"Invalid gated status, must be one of 'auto', 'manual', or False. Got '{gated}'.") payload["gated"] = gated if private is not None: payload["private"] = private if xet_enabled is not None: payload["xetEnabled"] = xet_enabled if len(payload) == 0: raise ValueError("At least one setting must be updated.") # Build headers headers = self._build_hf_headers(token=token) r = get_session().put( url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/settings", headers=headers, json=payload, ) hf_raise_for_status(r) def move_repo( self, from_id: str, to_id: str, *, repo_type: Optional[str] = None, token: Union[str, bool, None] = None, ): """ Moving a repository from namespace1/repo_name1 to namespace2/repo_name2 Note there are certain limitations. For more information about moving repositories, please see https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo. Args: from_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. Original repository identifier. to_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. Final repository identifier. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Raises the following errors: - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. """ if len(from_id.split("/")) != 2: raise ValueError(f"Invalid repo_id: {from_id}. It should have a namespace (:namespace:/:repo_name:)") if len(to_id.split("/")) != 2: raise ValueError(f"Invalid repo_id: {to_id}. It should have a namespace (:namespace:/:repo_name:)") if repo_type is None: repo_type = constants.REPO_TYPE_MODEL # Hub won't accept `None`. json = {"fromRepo": from_id, "toRepo": to_id, "type": repo_type} path = f"{self.endpoint}/api/repos/move" headers = self._build_hf_headers(token=token) r = get_session().post(path, headers=headers, json=json) try: hf_raise_for_status(r) except HfHubHTTPError as e: e.append_to_message( "\nFor additional documentation please see" " https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo." ) raise @overload def create_commit( # type: ignore self, repo_id: str, operations: Iterable[CommitOperation], *, commit_message: str, commit_description: Optional[str] = None, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, create_pr: Optional[bool] = None, num_threads: int = 5, parent_commit: Optional[str] = None, run_as_future: Literal[False] = ..., ) -> CommitInfo: ... 
@overload def create_commit( self, repo_id: str, operations: Iterable[CommitOperation], *, commit_message: str, commit_description: Optional[str] = None, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, create_pr: Optional[bool] = None, num_threads: int = 5, parent_commit: Optional[str] = None, run_as_future: Literal[True] = ..., ) -> Future[CommitInfo]: ... @validate_hf_hub_args @future_compatible def create_commit( self, repo_id: str, operations: Iterable[CommitOperation], *, commit_message: str, commit_description: Optional[str] = None, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, create_pr: Optional[bool] = None, num_threads: int = 5, parent_commit: Optional[str] = None, run_as_future: bool = False, ) -> Union[CommitInfo, Future[CommitInfo]]: """ Creates a commit in the given repo, deleting & uploading files as needed. The input list of `CommitOperation` will be mutated during the commit process. Do not reuse the same objects for multiple commits. `create_commit` assumes that the repo already exists on the Hub. If you get a Client error 404, please make sure you are authenticated and that `repo_id` and `repo_type` are set correctly. If repo does not exist, create it first using [`~hf_api.create_repo`]. `create_commit` is limited to 25k LFS files and a 1GB payload for regular files. Args: repo_id (`str`): The repository in which the commit will be created, for example: `"username/custom_transformers"` operations (`Iterable` of [`~hf_api.CommitOperation`]): An iterable of operations to include in the commit, either: - [`~hf_api.CommitOperationAdd`] to upload a file - [`~hf_api.CommitOperationDelete`] to delete a file - [`~hf_api.CommitOperationCopy`] to copy a file Operation objects will be mutated to include information relative to the upload. Do not reuse the same objects for multiple commits. commit_message (`str`): The summary (first line) of the commit that will be created. commit_description (`str`, *optional*): The description of the commit that will be created token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. revision (`str`, *optional*): The git revision to commit from. Defaults to the head of the `"main"` branch. create_pr (`boolean`, *optional*): Whether or not to create a Pull Request with that commit. Defaults to `False`. If `revision` is not set, PR is opened against the `"main"` branch. If `revision` is set and is a branch, PR is opened against this branch. If `revision` is set and is not a branch name (example: a commit oid), an `RevisionNotFoundError` is returned by the server. num_threads (`int`, *optional*): Number of concurrent threads for uploading files. Defaults to 5. Setting it to 2 means at most 2 files will be uploaded concurrently. parent_commit (`str`, *optional*): The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. 
If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be especially useful if the repo is updated / committed to concurrently. run_as_future (`bool`, *optional*): Whether or not to run this method in the background. Background jobs are run sequentially without blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) object. Defaults to `False`. Returns: [`CommitInfo`] or `Future`: Instance of [`CommitInfo`] containing information about the newly created commit (commit hash, commit url, pr url, commit message,...). If `run_as_future=True` is passed, returns a Future object which will contain the result when executed. Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If commit message is empty. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If parent commit is not a valid commit OID. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If a README.md file with an invalid metadata section is committed. In this case, the commit will fail early, before trying to upload any file. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If `create_pr` is `True` and revision is neither `None` nor `"main"`. [`~utils.RepositoryNotFoundError`]: If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo does not exist. """ if parent_commit is not None and not constants.REGEX_COMMIT_OID.fullmatch(parent_commit): raise ValueError( f"`parent_commit` is not a valid commit OID. It must match the following regex: {constants.REGEX_COMMIT_OID}" ) if commit_message is None or len(commit_message) == 0: raise ValueError("`commit_message` can't be empty, please pass a value.") commit_description = commit_description if commit_description is not None else "" repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") unquoted_revision = revision or constants.DEFAULT_REVISION revision = quote(unquoted_revision, safe="") create_pr = create_pr if create_pr is not None else False headers = self._build_hf_headers(token=token) operations = list(operations) additions = [op for op in operations if isinstance(op, CommitOperationAdd)] copies = [op for op in operations if isinstance(op, CommitOperationCopy)] nb_additions = len(additions) nb_copies = len(copies) nb_deletions = len(operations) - nb_additions - nb_copies for addition in additions: if addition._is_committed: raise ValueError( f"CommitOperationAdd {addition} has already been committed and cannot be reused. Please create a" " new CommitOperationAdd object if you want to create a new commit." ) if repo_type != "dataset": for addition in additions: if addition.path_in_repo.endswith((".arrow", ".parquet")): warnings.warn( f"It seems that you are about to commit a data file ({addition.path_in_repo}) to a {repo_type}" " repository. Are you sure this is intended? If you are trying to upload a dataset, please" " set `repo_type='dataset'` or `--repo-type=dataset` in a CLI." ) logger.debug( f"About to commit to the hub: {len(additions)} addition(s), {len(copies)} copy(ies) and" f" {nb_deletions} deletion(s)."
) # If updating a README.md file, make sure the metadata format is valid # It's better to fail early than to fail after all the files have been uploaded. for addition in additions: if addition.path_in_repo == "README.md": with addition.as_file() as file: content = file.read().decode() self._validate_yaml(content, repo_type=repo_type, token=token) # Skip other additions after `README.md` has been processed break # If updating twice the same file or update then delete a file in a single commit _warn_on_overwriting_operations(operations) self.preupload_lfs_files( repo_id=repo_id, additions=additions, token=token, repo_type=repo_type, revision=unquoted_revision, # first-class methods take unquoted revision create_pr=create_pr, num_threads=num_threads, free_memory=False, # do not remove `CommitOperationAdd.path_or_fileobj` on LFS files for "normal" users ) files_to_copy = _fetch_files_to_copy( copies=copies, repo_type=repo_type, repo_id=repo_id, headers=headers, revision=unquoted_revision, endpoint=self.endpoint, ) # Remove no-op operations (files that have not changed) operations_without_no_op = [] for operation in operations: if ( isinstance(operation, CommitOperationAdd) and operation._remote_oid is not None and operation._remote_oid == operation._local_oid ): # File already exists on the Hub and has not changed: we can skip it. logger.debug(f"Skipping upload for '{operation.path_in_repo}' as the file has not changed.") continue if ( isinstance(operation, CommitOperationCopy) and operation._dest_oid is not None and operation._dest_oid == operation._src_oid ): # Source and destination files are identical - skip logger.debug( f"Skipping copy for '{operation.src_path_in_repo}' -> '{operation.path_in_repo}' as the content of the source file is the same as the destination file." ) continue operations_without_no_op.append(operation) if len(operations) != len(operations_without_no_op): logger.info( f"Removing {len(operations) - len(operations_without_no_op)} file(s) from commit that have not changed." ) # Return early if empty commit if len(operations_without_no_op) == 0: logger.warning("No files have been modified since last commit. 
Skipping to prevent empty commit.") # Get latest commit info try: info = self.repo_info(repo_id=repo_id, repo_type=repo_type, revision=unquoted_revision, token=token) except RepositoryNotFoundError as e: e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE) raise # Return commit info based on latest commit url_prefix = self.endpoint if repo_type is not None and repo_type != constants.REPO_TYPE_MODEL: url_prefix = f"{url_prefix}/{repo_type}s" return CommitInfo( commit_url=f"{url_prefix}/{repo_id}/commit/{info.sha}", commit_message=commit_message, commit_description=commit_description, oid=info.sha, # type: ignore[arg-type] ) commit_payload = _prepare_commit_payload( operations=operations, files_to_copy=files_to_copy, commit_message=commit_message, commit_description=commit_description, parent_commit=parent_commit, ) commit_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/commit/{revision}" def _payload_as_ndjson() -> Iterable[bytes]: for item in commit_payload: yield json.dumps(item).encode() yield b"\n" headers = { # See https://github.com/huggingface/huggingface_hub/issues/1085#issuecomment-1265208073 "Content-Type": "application/x-ndjson", **headers, } data = b"".join(_payload_as_ndjson()) params = {"create_pr": "1"} if create_pr else None try: commit_resp = get_session().post(url=commit_url, headers=headers, data=data, params=params) hf_raise_for_status(commit_resp, endpoint_name="commit") except RepositoryNotFoundError as e: e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE) raise except EntryNotFoundError as e: if nb_deletions > 0 and "A file with this name doesn't exist" in str(e): e.append_to_message( "\nMake sure to differentiate file and folder paths in delete" " operations with a trailing '/' or using `is_folder=True/False`." ) raise # Mark additions as committed (cannot be reused in another commit) for addition in additions: addition._is_committed = True commit_data = commit_resp.json() return CommitInfo( commit_url=commit_data["commitUrl"], commit_message=commit_message, commit_description=commit_description, oid=commit_data["commitOid"], pr_url=commit_data["pullRequestUrl"] if create_pr else None, ) def preupload_lfs_files( self, repo_id: str, additions: Iterable[CommitOperationAdd], *, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, create_pr: Optional[bool] = None, num_threads: int = 5, free_memory: bool = True, gitignore_content: Optional[str] = None, ): """Pre-upload LFS files to S3 in preparation on a future commit. This method is useful if you are generating the files to upload on-the-fly and you don't want to store them in memory before uploading them all at once. This is a power-user method. You shouldn't need to call it directly to make a normal commit. Use [`create_commit`] directly instead. Commit operations will be mutated during the process. In particular, the attached `path_or_fileobj` will be removed after the upload to save memory (and replaced by an empty `bytes` object). Do not reuse the same objects except to pass them to [`create_commit`]. If you don't want to remove the attached content from the commit operation object, pass `free_memory=False`. Args: repo_id (`str`): The repository in which you will commit the files, for example: `"username/custom_transformers"`. operations (`Iterable` of [`CommitOperationAdd`]): The list of files to upload. Warning: the objects in this list will be mutated to include information relative to the upload. Do not reuse the same objects for multiple commits. 
token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. repo_type (`str`, *optional*): The type of repository to upload to (e.g. `"model"` -default-, `"dataset"` or `"space"`). revision (`str`, *optional*): The git revision to commit from. Defaults to the head of the `"main"` branch. create_pr (`boolean`, *optional*): Whether or not you plan to create a Pull Request with that commit. Defaults to `False`. num_threads (`int`, *optional*): Number of concurrent threads for uploading files. Defaults to 5. Setting it to 2 means at most 2 files will be uploaded concurrently. gitignore_content (`str`, *optional*): The content of the `.gitignore` file to know which files should be ignored. The order of priority is to first check if `gitignore_content` is passed, then check if the `.gitignore` file is present in the list of files to commit and finally default to the `.gitignore` file already hosted on the Hub (if any). Example: ```py >>> from huggingface_hub import CommitOperationAdd, preupload_lfs_files, create_commit, create_repo >>> repo_id = create_repo("test_preupload").repo_id # Generate and preupload LFS files one by one >>> operations = [] # List of all `CommitOperationAdd` objects that will be generated >>> for i in range(5): ... content = ... # generate binary content ... addition = CommitOperationAdd(path_in_repo=f"shard_{i}_of_5.bin", path_or_fileobj=content) ... preupload_lfs_files(repo_id, additions=[addition]) # upload + free memory ... operations.append(addition) # Create commit >>> create_commit(repo_id, operations=operations, commit_message="Commit all shards") ``` """ repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION create_pr = create_pr if create_pr is not None else False headers = self._build_hf_headers(token=token) # Check if a `gitignore` file is being committed to the Hub. 
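        # (Resolution order, as documented in the docstring above: an explicit `gitignore_content`
        # argument takes precedence, then a `.gitignore` file found among the additions, and finally
        # the `.gitignore` file already hosted on the Hub, if any.)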
additions = list(additions) if gitignore_content is None: for addition in additions: if addition.path_in_repo == ".gitignore": with addition.as_file() as f: gitignore_content = f.read().decode() break # Filter out already uploaded files new_additions = [addition for addition in additions if not addition._is_uploaded] # Check which new files are LFS try: _fetch_upload_modes( additions=new_additions, repo_type=repo_type, repo_id=repo_id, headers=headers, revision=revision, endpoint=self.endpoint, create_pr=create_pr or False, gitignore_content=gitignore_content, ) except RepositoryNotFoundError as e: e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE) raise # Filter out regular files new_lfs_additions = [addition for addition in new_additions if addition._upload_mode == "lfs"] # Filter out files listed in .gitignore new_lfs_additions_to_upload = [] for addition in new_lfs_additions: if addition._should_ignore: logger.debug(f"Skipping upload for LFS file '{addition.path_in_repo}' (ignored by gitignore file).") else: new_lfs_additions_to_upload.append(addition) if len(new_lfs_additions) != len(new_lfs_additions_to_upload): logger.info( f"Skipped upload for {len(new_lfs_additions) - len(new_lfs_additions_to_upload)} LFS file(s) " "(ignored by gitignore file)." ) # Prepare upload parameters upload_kwargs = { "additions": new_lfs_additions_to_upload, "repo_type": repo_type, "repo_id": repo_id, "headers": headers, "endpoint": self.endpoint, # If `create_pr`, we don't want to check user permission on the revision as users with read permission # should still be able to create PRs even if they don't have write permission on the target branch of the # PR (i.e. `revision`). "revision": revision if not create_pr else None, } # Upload files using Xet protocol if all of the following are true: # - xet is enabled for the repo, # - the files are provided as str or paths objects, # - the library is installed. # Otherwise, default back to LFS. xet_enabled = self.repo_info( repo_id=repo_id, repo_type=repo_type, revision=unquote(revision) if revision is not None else revision, expand="xetEnabled", token=token, ).xet_enabled has_buffered_io_data = any( isinstance(addition.path_or_fileobj, io.BufferedIOBase) for addition in new_lfs_additions_to_upload ) if xet_enabled and not has_buffered_io_data and is_xet_available(): logger.info("Uploading files using Xet Storage..") _upload_xet_files(**upload_kwargs, create_pr=create_pr) # type: ignore [arg-type] else: if xet_enabled and is_xet_available(): if has_buffered_io_data: logger.warning( "Uploading files as a binary IO buffer is not supported by Xet Storage. " "Falling back to HTTP upload." ) _upload_lfs_files(**upload_kwargs, num_threads=num_threads) # type: ignore [arg-type] for addition in new_lfs_additions_to_upload: addition._is_uploaded = True if free_memory: addition.path_or_fileobj = b"" @overload def upload_file( # type: ignore self, *, path_or_fileobj: Union[str, Path, bytes, BinaryIO], path_in_repo: str, repo_id: str, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, commit_message: Optional[str] = None, commit_description: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, run_as_future: Literal[False] = ..., ) -> CommitInfo: ... 
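    # Note: as with `create_commit` above, `upload_file` is declared with two `@overload` signatures
    # purely for type-checkers: `run_as_future=False` (the default) resolves to a `CommitInfo` return
    # type, while `run_as_future=True` resolves to a `Future[CommitInfo]` handled by the
    # `@future_compatible` decorator (background jobs run sequentially without blocking the main
    # thread, as described in the docstring).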
@overload def upload_file( self, *, path_or_fileobj: Union[str, Path, bytes, BinaryIO], path_in_repo: str, repo_id: str, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, commit_message: Optional[str] = None, commit_description: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, run_as_future: Literal[True] = ..., ) -> Future[CommitInfo]: ... @validate_hf_hub_args @future_compatible def upload_file( self, *, path_or_fileobj: Union[str, Path, bytes, BinaryIO], path_in_repo: str, repo_id: str, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, commit_message: Optional[str] = None, commit_description: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, run_as_future: bool = False, ) -> Union[CommitInfo, Future[CommitInfo]]: """ Upload a local file (up to 50 GB) to the given repo. The upload is done through a HTTP post request, and doesn't require git or git-lfs to be installed. Args: path_or_fileobj (`str`, `Path`, `bytes`, or `IO`): Path to a file on the local machine or binary data stream / fileobj / buffer. path_in_repo (`str`): Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"` repo_id (`str`): The repository to which the file will be uploaded, for example: `"username/custom_transformers"` token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. revision (`str`, *optional*): The git revision to commit from. Defaults to the head of the `"main"` branch. commit_message (`str`, *optional*): The summary / title / first line of the generated commit commit_description (`str` *optional*) The description of the generated commit create_pr (`boolean`, *optional*): Whether or not to create a Pull Request with that commit. Defaults to `False`. If `revision` is not set, PR is opened against the `"main"` branch. If `revision` is set and is a branch, PR is opened against this branch. If `revision` is set and is not a branch name (example: a commit oid), an `RevisionNotFoundError` is returned by the server. parent_commit (`str`, *optional*): The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be especially useful if the repo is updated / committed to concurrently. run_as_future (`bool`, *optional*): Whether or not to run this method in the background. Background jobs are run sequentially without blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) object. Defaults to `False`. Returns: [`CommitInfo`] or `Future`: Instance of [`CommitInfo`] containing information about the newly created commit (commit hash, commit url, pr url, commit message,...). 
If `run_as_future=True` is passed, returns a Future object which will contain the result when executed. Raises the following errors: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. - [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. `upload_file` assumes that the repo already exists on the Hub. If you get a Client error 404, please make sure you are authenticated and that `repo_id` and `repo_type` are set correctly. If repo does not exist, create it first using [`~hf_api.create_repo`]. Example: ```python >>> from huggingface_hub import upload_file >>> with open("./local/filepath", "rb") as fobj: ... upload_file( ... path_or_fileobj=fileobj, ... path_in_repo="remote/file/path.h5", ... repo_id="username/my-dataset", ... repo_type="dataset", ... token="my_token", ... ) "https://huggingface.co/datasets/username/my-dataset/blob/main/remote/file/path.h5" >>> upload_file( ... path_or_fileobj=".\\\\local\\\\file\\\\path", ... path_in_repo="remote/file/path.h5", ... repo_id="username/my-model", ... token="my_token", ... ) "https://huggingface.co/username/my-model/blob/main/remote/file/path.h5" >>> upload_file( ... path_or_fileobj=".\\\\local\\\\file\\\\path", ... path_in_repo="remote/file/path.h5", ... repo_id="username/my-model", ... token="my_token", ... create_pr=True, ... ) "https://huggingface.co/username/my-model/blob/refs%2Fpr%2F1/remote/file/path.h5" ``` """ if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") commit_message = ( commit_message if commit_message is not None else f"Upload {path_in_repo} with huggingface_hub" ) operation = CommitOperationAdd( path_or_fileobj=path_or_fileobj, path_in_repo=path_in_repo, ) commit_info = self.create_commit( repo_id=repo_id, repo_type=repo_type, operations=[operation], commit_message=commit_message, commit_description=commit_description, token=token, revision=revision, create_pr=create_pr, parent_commit=parent_commit, ) if commit_info.pr_url is not None: revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") if repo_type in constants.REPO_TYPES_URL_PREFIXES: repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id revision = revision if revision is not None else constants.DEFAULT_REVISION return CommitInfo( commit_url=commit_info.commit_url, commit_message=commit_info.commit_message, commit_description=commit_info.commit_description, oid=commit_info.oid, pr_url=commit_info.pr_url, # Similar to `hf_hub_url` but it's "blob" instead of "resolve" # TODO: remove this in v1.0 _url=f"{self.endpoint}/{repo_id}/blob/{revision}/{path_in_repo}", ) @overload def upload_folder( # type: ignore self, *, repo_id: str, folder_path: Union[str, Path], path_in_repo: Optional[str] = None, commit_message: Optional[str] = None, commit_description: Optional[str] = None, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, delete_patterns: 
Optional[Union[List[str], str]] = None, run_as_future: Literal[False] = ..., ) -> CommitInfo: ... @overload def upload_folder( # type: ignore self, *, repo_id: str, folder_path: Union[str, Path], path_in_repo: Optional[str] = None, commit_message: Optional[str] = None, commit_description: Optional[str] = None, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, delete_patterns: Optional[Union[List[str], str]] = None, run_as_future: Literal[True] = ..., ) -> Future[CommitInfo]: ... @validate_hf_hub_args @future_compatible def upload_folder( self, *, repo_id: str, folder_path: Union[str, Path], path_in_repo: Optional[str] = None, commit_message: Optional[str] = None, commit_description: Optional[str] = None, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, delete_patterns: Optional[Union[List[str], str]] = None, run_as_future: bool = False, ) -> Union[CommitInfo, Future[CommitInfo]]: """ Upload a local folder to the given repo. The upload is done through a HTTP requests, and doesn't require git or git-lfs to be installed. The structure of the folder will be preserved. Files with the same name already present in the repository will be overwritten. Others will be left untouched. Use the `allow_patterns` and `ignore_patterns` arguments to specify which files to upload. These parameters accept either a single pattern or a list of patterns. Patterns are Standard Wildcards (globbing patterns) as documented [here](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm). If both `allow_patterns` and `ignore_patterns` are provided, both constraints apply. By default, all files from the folder are uploaded. Use the `delete_patterns` argument to specify remote files you want to delete. Input type is the same as for `allow_patterns` (see above). If `path_in_repo` is also provided, the patterns are matched against paths relative to this folder. For example, `upload_folder(..., path_in_repo="experiment", delete_patterns="logs/*")` will delete any remote file under `./experiment/logs/`. Note that the `.gitattributes` file will not be deleted even if it matches the patterns. Any `.git/` folder present in any subdirectory will be ignored. However, please be aware that the `.gitignore` file is not taken into account. Uses `HfApi.create_commit` under the hood. Args: repo_id (`str`): The repository to which the file will be uploaded, for example: `"username/custom_transformers"` folder_path (`str` or `Path`): Path to the folder to upload on the local file system path_in_repo (`str`, *optional*): Relative path of the directory in the repo, for example: `"checkpoints/1fec34a/results"`. Will default to the root folder of the repository. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. 
repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. revision (`str`, *optional*): The git revision to commit from. Defaults to the head of the `"main"` branch. commit_message (`str`, *optional*): The summary / title / first line of the generated commit. Defaults to: `f"Upload {path_in_repo} with huggingface_hub"` commit_description (`str` *optional*): The description of the generated commit create_pr (`boolean`, *optional*): Whether or not to create a Pull Request with that commit. Defaults to `False`. If `revision` is not set, PR is opened against the `"main"` branch. If `revision` is set and is a branch, PR is opened against this branch. If `revision` is set and is not a branch name (example: a commit oid), an `RevisionNotFoundError` is returned by the server. parent_commit (`str`, *optional*): The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be especially useful if the repo is updated / committed to concurrently. allow_patterns (`List[str]` or `str`, *optional*): If provided, only files matching at least one pattern are uploaded. ignore_patterns (`List[str]` or `str`, *optional*): If provided, files matching any of the patterns are not uploaded. delete_patterns (`List[str]` or `str`, *optional*): If provided, remote files matching any of the patterns will be deleted from the repo while committing new files. This is useful if you don't know which files have already been uploaded. Note: to avoid discrepancies the `.gitattributes` file is not deleted even if it matches the pattern. run_as_future (`bool`, *optional*): Whether or not to run this method in the background. Background jobs are run sequentially without blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) object. Defaults to `False`. Returns: [`CommitInfo`] or `Future`: Instance of [`CommitInfo`] containing information about the newly created commit (commit hash, commit url, pr url, commit message,...). If `run_as_future=True` is passed, returns a Future object which will contain the result when executed. Raises the following errors: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid `upload_folder` assumes that the repo already exists on the Hub. If you get a Client error 404, please make sure you are authenticated and that `repo_id` and `repo_type` are set correctly. If repo does not exist, create it first using [`~hf_api.create_repo`]. When dealing with a large folder (thousands of files or hundreds of GB), we recommend using [`~hf_api.upload_large_folder`] instead. Example: ```python # Upload checkpoints folder except the log files >>> upload_folder( ... folder_path="local/checkpoints", ... path_in_repo="remote/experiment/checkpoints", ... repo_id="username/my-dataset", ... repo_type="datasets", ... token="my_token", ... ignore_patterns="**/logs/*.txt", ... 
) # "https://huggingface.co/datasets/username/my-dataset/tree/main/remote/experiment/checkpoints" # Upload checkpoints folder including logs while deleting existing logs from the repo # Useful if you don't know exactly which log files have already being pushed >>> upload_folder( ... folder_path="local/checkpoints", ... path_in_repo="remote/experiment/checkpoints", ... repo_id="username/my-dataset", ... repo_type="datasets", ... token="my_token", ... delete_patterns="**/logs/*.txt", ... ) "https://huggingface.co/datasets/username/my-dataset/tree/main/remote/experiment/checkpoints" # Upload checkpoints folder while creating a PR >>> upload_folder( ... folder_path="local/checkpoints", ... path_in_repo="remote/experiment/checkpoints", ... repo_id="username/my-dataset", ... repo_type="datasets", ... token="my_token", ... create_pr=True, ... ) "https://huggingface.co/datasets/username/my-dataset/tree/refs%2Fpr%2F1/remote/experiment/checkpoints" ``` """ if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") # By default, upload folder to the root directory in repo. if path_in_repo is None: path_in_repo = "" # Do not upload .git folder if ignore_patterns is None: ignore_patterns = [] elif isinstance(ignore_patterns, str): ignore_patterns = [ignore_patterns] ignore_patterns += DEFAULT_IGNORE_PATTERNS delete_operations = self._prepare_folder_deletions( repo_id=repo_id, repo_type=repo_type, revision=constants.DEFAULT_REVISION if create_pr else revision, token=token, path_in_repo=path_in_repo, delete_patterns=delete_patterns, ) add_operations = self._prepare_upload_folder_additions( folder_path, path_in_repo, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, token=token, repo_type=repo_type, ) # Optimize operations: if some files will be overwritten, we don't need to delete them first if len(add_operations) > 0: added_paths = set(op.path_in_repo for op in add_operations) delete_operations = [ delete_op for delete_op in delete_operations if delete_op.path_in_repo not in added_paths ] commit_operations = delete_operations + add_operations commit_message = commit_message or "Upload folder using huggingface_hub" commit_info = self.create_commit( repo_type=repo_type, repo_id=repo_id, operations=commit_operations, commit_message=commit_message, commit_description=commit_description, token=token, revision=revision, create_pr=create_pr, parent_commit=parent_commit, ) # Create url to uploaded folder (for legacy return value) if create_pr and commit_info.pr_url is not None: revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") if repo_type in constants.REPO_TYPES_URL_PREFIXES: repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id revision = revision if revision is not None else constants.DEFAULT_REVISION return CommitInfo( commit_url=commit_info.commit_url, commit_message=commit_info.commit_message, commit_description=commit_info.commit_description, oid=commit_info.oid, pr_url=commit_info.pr_url, # Similar to `hf_hub_url` but it's "tree" instead of "resolve" # TODO: remove this in v1.0 _url=f"{self.endpoint}/{repo_id}/tree/{revision}/{path_in_repo}", ) @validate_hf_hub_args def delete_file( self, path_in_repo: str, repo_id: str, *, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, commit_message: Optional[str] = None, commit_description: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, ) -> CommitInfo: """ 
Deletes a file in the given repo. Args: path_in_repo (`str`): Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"` repo_id (`str`): The repository from which the file will be deleted, for example: `"username/custom_transformers"` token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if the file is in a dataset or space, `None` or `"model"` if in a model. Default is `None`. revision (`str`, *optional*): The git revision to commit from. Defaults to the head of the `"main"` branch. commit_message (`str`, *optional*): The summary / title / first line of the generated commit. Defaults to `f"Delete {path_in_repo} with huggingface_hub"`. commit_description (`str` *optional*) The description of the generated commit create_pr (`boolean`, *optional*): Whether or not to create a Pull Request with that commit. Defaults to `False`. If `revision` is not set, PR is opened against the `"main"` branch. If `revision` is set and is a branch, PR is opened against this branch. If `revision` is set and is not a branch name (example: a commit oid), an `RevisionNotFoundError` is returned by the server. parent_commit (`str`, *optional*): The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be especially useful if the repo is updated / committed to concurrently. Raises the following errors: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. - [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. - [`~utils.EntryNotFoundError`] If the file to download cannot be found. """ commit_message = ( commit_message if commit_message is not None else f"Delete {path_in_repo} with huggingface_hub" ) operations = [CommitOperationDelete(path_in_repo=path_in_repo)] return self.create_commit( repo_id=repo_id, repo_type=repo_type, token=token, operations=operations, revision=revision, commit_message=commit_message, commit_description=commit_description, create_pr=create_pr, parent_commit=parent_commit, ) @validate_hf_hub_args def delete_files( self, repo_id: str, delete_patterns: List[str], *, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, commit_message: Optional[str] = None, commit_description: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, ) -> CommitInfo: """ Delete files from a repository on the Hub. If a folder path is provided, the entire folder is deleted as well as all files it contained. 
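        Example (a minimal sketch; the repo id and patterns below are placeholders):

        ```py
        >>> from huggingface_hub import HfApi
        >>> api = HfApi()
        >>> api.delete_files(
        ...     repo_id="username/my-dataset",
        ...     repo_type="dataset",
        ...     delete_patterns=["logs/*.txt", "old-checkpoints/"],
        ... )
        ```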
Args: repo_id (`str`): The repository from which the folder will be deleted, for example: `"username/custom_transformers"` delete_patterns (`List[str]`): List of files or folders to delete. Each string can either be a file path, a folder path or a Unix shell-style wildcard. E.g. `["file.txt", "folder/", "data/*.parquet"]` token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. to the stored token. repo_type (`str`, *optional*): Type of the repo to delete files from. Can be `"model"`, `"dataset"` or `"space"`. Defaults to `"model"`. revision (`str`, *optional*): The git revision to commit from. Defaults to the head of the `"main"` branch. commit_message (`str`, *optional*): The summary (first line) of the generated commit. Defaults to `f"Delete files using huggingface_hub"`. commit_description (`str` *optional*) The description of the generated commit. create_pr (`boolean`, *optional*): Whether or not to create a Pull Request with that commit. Defaults to `False`. If `revision` is not set, PR is opened against the `"main"` branch. If `revision` is set and is a branch, PR is opened against this branch. If `revision` is set and is not a branch name (example: a commit oid), an `RevisionNotFoundError` is returned by the server. parent_commit (`str`, *optional*): The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be especially useful if the repo is updated / committed to concurrently. """ operations = self._prepare_folder_deletions( repo_id=repo_id, repo_type=repo_type, delete_patterns=delete_patterns, path_in_repo="", revision=revision ) if commit_message is None: commit_message = f"Delete files {' '.join(delete_patterns)} with huggingface_hub" return self.create_commit( repo_id=repo_id, repo_type=repo_type, token=token, operations=operations, revision=revision, commit_message=commit_message, commit_description=commit_description, create_pr=create_pr, parent_commit=parent_commit, ) @validate_hf_hub_args def delete_folder( self, path_in_repo: str, repo_id: str, *, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, commit_message: Optional[str] = None, commit_description: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, ) -> CommitInfo: """ Deletes a folder in the given repo. Simple wrapper around [`create_commit`] method. Args: path_in_repo (`str`): Relative folder path in the repo, for example: `"checkpoints/1fec34a"`. repo_id (`str`): The repository from which the folder will be deleted, for example: `"username/custom_transformers"` token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. to the stored token. 
repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if the folder is in a dataset or space, `None` or `"model"` if in a model. Default is `None`. revision (`str`, *optional*): The git revision to commit from. Defaults to the head of the `"main"` branch. commit_message (`str`, *optional*): The summary / title / first line of the generated commit. Defaults to `f"Delete folder {path_in_repo} with huggingface_hub"`. commit_description (`str` *optional*) The description of the generated commit. create_pr (`boolean`, *optional*): Whether or not to create a Pull Request with that commit. Defaults to `False`. If `revision` is not set, PR is opened against the `"main"` branch. If `revision` is set and is a branch, PR is opened against this branch. If `revision` is set and is not a branch name (example: a commit oid), an `RevisionNotFoundError` is returned by the server. parent_commit (`str`, *optional*): The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be especially useful if the repo is updated / committed to concurrently. """ return self.create_commit( repo_id=repo_id, repo_type=repo_type, token=token, operations=[CommitOperationDelete(path_in_repo=path_in_repo, is_folder=True)], revision=revision, commit_message=( commit_message if commit_message is not None else f"Delete folder {path_in_repo} with huggingface_hub" ), commit_description=commit_description, create_pr=create_pr, parent_commit=parent_commit, ) def upload_large_folder( self, repo_id: str, folder_path: Union[str, Path], *, repo_type: str, # Repo type is required! revision: Optional[str] = None, private: Optional[bool] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, num_workers: Optional[int] = None, print_report: bool = True, print_report_every: int = 60, ) -> None: """Upload a large folder to the Hub in the most resilient way possible. Several workers are started to upload files in an optimized way. Before being committed to a repo, files must be hashed and be pre-uploaded if they are LFS files. Workers will perform these tasks for each file in the folder. At each step, some metadata information about the upload process is saved in the folder under `.cache/.huggingface/` to be able to resume the process if interrupted. The whole process might result in several commits. Args: repo_id (`str`): The repository to which the file will be uploaded. E.g. `"HuggingFaceTB/smollm-corpus"`. folder_path (`str` or `Path`): Path to the folder to upload on the local file system. repo_type (`str`): Type of the repository. Must be one of `"model"`, `"dataset"` or `"space"`. Unlike in all other `HfApi` methods, `repo_type` is explicitly required here. This is to avoid any mistake when uploading a large folder to the Hub, and therefore prevent from having to re-upload everything. revision (`str`, `optional`): The branch to commit to. If not provided, the `main` branch will be used. private (`bool`, `optional`): Whether the repository should be private. If `None` (default), the repo will be public unless the organization's default is private. 
allow_patterns (`List[str]` or `str`, *optional*): If provided, only files matching at least one pattern are uploaded. ignore_patterns (`List[str]` or `str`, *optional*): If provided, files matching any of the patterns are not uploaded. num_workers (`int`, *optional*): Number of workers to start. Defaults to `os.cpu_count() - 2` (minimum 2). A higher number of workers may speed up the process if your machine allows it. However, on machines with a slower connection, it is recommended to keep the number of workers low to ensure better resumability. Indeed, partially uploaded files will have to be completely re-uploaded if the process is interrupted. print_report (`bool`, *optional*): Whether to print a report of the upload progress. Defaults to True. Report is printed to `sys.stdout` every X seconds (60 by defaults) and overwrites the previous report. print_report_every (`int`, *optional*): Frequency at which the report is printed. Defaults to 60 seconds. A few things to keep in mind: - Repository limits still apply: https://huggingface.co/docs/hub/repositories-recommendations - Do not start several processes in parallel. - You can interrupt and resume the process at any time. - Do not upload the same folder to several repositories. If you need to do so, you must delete the local `.cache/.huggingface/` folder first. While being much more robust to upload large folders, `upload_large_folder` is more limited than [`upload_folder`] feature-wise. In practice: - you cannot set a custom `path_in_repo`. If you want to upload to a subfolder, you need to set the proper structure locally. - you cannot set a custom `commit_message` and `commit_description` since multiple commits are created. - you cannot delete from the repo while uploading. Please make a separate commit first. - you cannot create a PR directly. Please create a PR first (from the UI or using [`create_pull_request`]) and then commit to it by passing `revision`. **Technical details:** `upload_large_folder` process is as follow: 1. (Check parameters and setup.) 2. Create repo if missing. 3. List local files to upload. 4. Start workers. Workers can perform the following tasks: - Hash a file. - Get upload mode (regular or LFS) for a list of files. - Pre-upload an LFS file. - Commit a bunch of files. Once a worker finishes a task, it will move on to the next task based on the priority list (see below) until all files are uploaded and committed. 5. While workers are up, regularly print a report to sys.stdout. Order of priority: 1. Commit if more than 5 minutes since last commit attempt (and at least 1 file). 2. Commit if at least 150 files are ready to commit. 3. Get upload mode if at least 10 files have been hashed. 4. Pre-upload LFS file if at least 1 file and no worker is pre-uploading. 5. Hash file if at least 1 file and no worker is hashing. 6. Get upload mode if at least 1 file and no worker is getting upload mode. 7. Pre-upload LFS file if at least 1 file (exception: if hf_transfer is enabled, only 1 worker can preupload LFS at a time). 8. Hash file if at least 1 file to hash. 9. Get upload mode if at least 1 file to get upload mode. 10. Commit if at least 1 file to commit and at least 1 min since last commit attempt. 11. Commit if at least 1 file to commit and all other queues are empty. Special rules: - If `hf_transfer` is enabled, only 1 LFS uploader at a time. Otherwise the CPU would be bloated by `hf_transfer`. - Only one worker can commit at a time. 
- If no tasks are available, the worker waits for 10 seconds before checking again. """ return upload_large_folder_internal( self, repo_id=repo_id, folder_path=folder_path, repo_type=repo_type, revision=revision, private=private, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, num_workers=num_workers, print_report=print_report, print_report_every=print_report_every, ) @validate_hf_hub_args def get_hf_file_metadata( self, *, url: str, token: Union[bool, str, None] = None, proxies: Optional[Dict] = None, timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT, ) -> HfFileMetadata: """Fetch metadata of a file versioned on the Hub for a given url. Args: url (`str`): File url, for example returned by [`hf_hub_url`]. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. proxies (`dict`, *optional*): Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. timeout (`float`, *optional*, defaults to 10): How many seconds to wait for the server to send metadata before giving up. Returns: A [`HfFileMetadata`] object containing metadata such as location, etag, size and commit_hash. """ if token is None: # Cannot do `token = token or self.token` as token can be `False`. token = self.token return get_hf_file_metadata( url=url, token=token, proxies=proxies, timeout=timeout, library_name=self.library_name, library_version=self.library_version, user_agent=self.user_agent, ) @validate_hf_hub_args def hf_hub_download( self, repo_id: str, filename: str, *, subfolder: Optional[str] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, cache_dir: Union[str, Path, None] = None, local_dir: Union[str, Path, None] = None, force_download: bool = False, proxies: Optional[Dict] = None, etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, token: Union[bool, str, None] = None, local_files_only: bool = False, # Deprecated args resume_download: Optional[bool] = None, force_filename: Optional[str] = None, local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", ) -> str: """Download a given file if it's not already present in the local cache. The new cache file layout looks like this: - The cache directory contains one subfolder per repo_id (namespaced by repo type) - inside each repo folder: - refs is a list of the latest known revision => commit_hash pairs - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on whether they're LFS files or not) - snapshots contains one subfolder per commit, each "commit" contains the subset of the files that have been resolved at that particular commit. Each filename is a symlink to the blob at that particular commit. ``` [ 96] . 
└── [ 160] models--julien-c--EsperBERTo-small ├── [ 160] blobs │ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd │ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e │ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812 ├── [ 96] refs │ └── [ 40] main └── [ 128] snapshots ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd └── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48 ├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd ``` If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` to store some metadata related to the downloaded files. While this mechanism is not as robust as the main cache-system, it's optimized for regularly pulling the latest version of a repository. Args: repo_id (`str`): A user or an organization name and a repo name separated by a `/`. filename (`str`): The name of the file in the repo. subfolder (`str`, *optional*): An optional value corresponding to a folder inside the repository. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if downloading from a model. Default is `None`. revision (`str`, *optional*): An optional Git revision id which can be a branch name, a tag, or a commit hash. cache_dir (`str`, `Path`, *optional*): Path to the folder where cached files are stored. local_dir (`str` or `Path`, *optional*): If provided, the downloaded file will be placed under this directory. force_download (`bool`, *optional*, defaults to `False`): Whether the file should be downloaded even if it already exists in the local cache. proxies (`dict`, *optional*): Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. etag_timeout (`float`, *optional*, defaults to `10`): When fetching ETag, how many seconds to wait for the server to send data before giving up which is passed to `requests.request`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the local cached file if it exists. Returns: `str`: Local path of file or if networking is off, last version of file cached on disk. Raises: [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. [`~utils.EntryNotFoundError`] If the file to download cannot be found. [`~utils.LocalEntryNotFoundError`] If network is disabled or unavailable and file is not found in cache. [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) If `token=True` but the token cannot be found. 
[`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) If ETag cannot be determined. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If some parameter value is invalid. """ from .file_download import hf_hub_download if token is None: # Cannot do `token = token or self.token` as token can be `False`. token = self.token return hf_hub_download( repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, revision=revision, endpoint=self.endpoint, library_name=self.library_name, library_version=self.library_version, cache_dir=cache_dir, local_dir=local_dir, local_dir_use_symlinks=local_dir_use_symlinks, user_agent=self.user_agent, force_download=force_download, force_filename=force_filename, proxies=proxies, etag_timeout=etag_timeout, resume_download=resume_download, token=token, headers=self.headers, local_files_only=local_files_only, ) @validate_hf_hub_args def snapshot_download( self, repo_id: str, *, repo_type: Optional[str] = None, revision: Optional[str] = None, cache_dir: Union[str, Path, None] = None, local_dir: Union[str, Path, None] = None, proxies: Optional[Dict] = None, etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, force_download: bool = False, token: Union[bool, str, None] = None, local_files_only: bool = False, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_workers: int = 8, tqdm_class: Optional[base_tqdm] = None, # Deprecated args local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", resume_download: Optional[bool] = None, ) -> str: """Download repo files. Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order to keep their actual filename relative to that folder. You can also filter which files to download using `allow_patterns` and `ignore_patterns`. If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` to store some metadata related to the downloaded files.While this mechanism is not as robust as the main cache-system, it's optimized for regularly pulling the latest version of a repository. An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly configured. It is also not possible to filter which files to download when cloning a repository using git. Args: repo_id (`str`): A user or an organization name and a repo name separated by a `/`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if downloading from a model. Default is `None`. revision (`str`, *optional*): An optional Git revision id which can be a branch name, a tag, or a commit hash. cache_dir (`str`, `Path`, *optional*): Path to the folder where cached files are stored. local_dir (`str` or `Path`, *optional*): If provided, the downloaded files will be placed under this directory. proxies (`dict`, *optional*): Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. etag_timeout (`float`, *optional*, defaults to `10`): When fetching ETag, how many seconds to wait for the server to send data before giving up which is passed to `requests.request`. 
force_download (`bool`, *optional*, defaults to `False`): Whether the file should be downloaded even if it already exists in the local cache. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the local cached file if it exists. allow_patterns (`List[str]` or `str`, *optional*): If provided, only files matching at least one pattern are downloaded. ignore_patterns (`List[str]` or `str`, *optional*): If provided, files matching any of the patterns are not downloaded. max_workers (`int`, *optional*): Number of concurrent threads to download files (1 thread = 1 file download). Defaults to 8. tqdm_class (`tqdm`, *optional*): If provided, overwrites the default behavior for the progress bar. Passed argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior. Note that the `tqdm_class` is not passed to each individual download. Defaults to the custom HF progress bar that can be disabled by setting `HF_HUB_DISABLE_PROGRESS_BARS` environment variable. Returns: `str`: folder path of the repo snapshot. Raises: [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) If `token=True` and the token cannot be found. [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if ETag cannot be determined. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid. """ from ._snapshot_download import snapshot_download if token is None: # Cannot do `token = token or self.token` as token can be `False`. token = self.token return snapshot_download( repo_id=repo_id, repo_type=repo_type, revision=revision, endpoint=self.endpoint, cache_dir=cache_dir, local_dir=local_dir, local_dir_use_symlinks=local_dir_use_symlinks, library_name=self.library_name, library_version=self.library_version, user_agent=self.user_agent, proxies=proxies, etag_timeout=etag_timeout, resume_download=resume_download, force_download=force_download, token=token, local_files_only=local_files_only, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, max_workers=max_workers, tqdm_class=tqdm_class, ) def get_safetensors_metadata( self, repo_id: str, *, repo_type: Optional[str] = None, revision: Optional[str] = None, token: Union[bool, str, None] = None, ) -> SafetensorsRepoMetadata: """ Parse metadata for a safetensors repo on the Hub. We first check if the repo has a single safetensors file or a sharded safetensors repo. If it's a single safetensors file, we parse the metadata from this file. If it's a sharded safetensors repo, we parse the metadata from the index file and then parse the metadata from each shard. To parse metadata from a single safetensors file, use [`parse_safetensors_file_metadata`]. For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. Args: repo_id (`str`): A user or an organization name and a repo name separated by a `/`. 
repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if the file is in a dataset or space, `None` or `"model"` if in a model. Default is `None`. revision (`str`, *optional*): The git revision to fetch the file from. Can be a branch name, a tag, or a commit hash. Defaults to the head of the `"main"` branch. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`SafetensorsRepoMetadata`]: information related to safetensors repo. Raises: [`NotASafetensorsRepoError`] If the repo is not a safetensors repo i.e. doesn't have either a `model.safetensors` or a `model.safetensors.index.json` file. [`SafetensorsParsingError`] If a safetensors file header couldn't be parsed correctly. Example: ```py # Parse repo with single weights file >>> metadata = get_safetensors_metadata("bigscience/bloomz-560m") >>> metadata SafetensorsRepoMetadata( metadata=None, sharded=False, weight_map={'h.0.input_layernorm.bias': 'model.safetensors', ...}, files_metadata={'model.safetensors': SafetensorsFileMetadata(...)} ) >>> metadata.files_metadata["model.safetensors"].metadata {'format': 'pt'} # Parse repo with sharded model >>> metadata = get_safetensors_metadata("bigscience/bloom") Parse safetensors files: 100%|██████████████████████████████████████████| 72/72 [00:12<00:00, 5.78it/s] >>> metadata SafetensorsRepoMetadata(metadata={'total_size': 352494542848}, sharded=True, weight_map={...}, files_metadata={...}) >>> len(metadata.files_metadata) 72 # All safetensors files have been fetched # Parse repo with sharded model >>> get_safetensors_metadata("runwayml/stable-diffusion-v1-5") NotASafetensorsRepoError: 'runwayml/stable-diffusion-v1-5' is not a safetensors repo. Couldn't find 'model.safetensors.index.json' or 'model.safetensors' files. 
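        # Metadata of a single shard can also be parsed directly (the shard filename below is illustrative)
        >>> parse_safetensors_file_metadata("bigscience/bloom", "model_00001-of-00072.safetensors")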
``` """ if self.file_exists( # Single safetensors file => non-sharded model repo_id=repo_id, filename=constants.SAFETENSORS_SINGLE_FILE, repo_type=repo_type, revision=revision, token=token, ): file_metadata = self.parse_safetensors_file_metadata( repo_id=repo_id, filename=constants.SAFETENSORS_SINGLE_FILE, repo_type=repo_type, revision=revision, token=token, ) return SafetensorsRepoMetadata( metadata=None, sharded=False, weight_map={ tensor_name: constants.SAFETENSORS_SINGLE_FILE for tensor_name in file_metadata.tensors.keys() }, files_metadata={constants.SAFETENSORS_SINGLE_FILE: file_metadata}, ) elif self.file_exists( # Multiple safetensors files => sharded with index repo_id=repo_id, filename=constants.SAFETENSORS_INDEX_FILE, repo_type=repo_type, revision=revision, token=token, ): # Fetch index index_file = self.hf_hub_download( repo_id=repo_id, filename=constants.SAFETENSORS_INDEX_FILE, repo_type=repo_type, revision=revision, token=token, ) with open(index_file) as f: index = json.load(f) weight_map = index.get("weight_map", {}) # Fetch metadata per shard files_metadata = {} def _parse(filename: str) -> None: files_metadata[filename] = self.parse_safetensors_file_metadata( repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision, token=token ) thread_map( _parse, set(weight_map.values()), desc="Parse safetensors files", tqdm_class=hf_tqdm, ) return SafetensorsRepoMetadata( metadata=index.get("metadata", None), sharded=True, weight_map=weight_map, files_metadata=files_metadata, ) else: # Not a safetensors repo raise NotASafetensorsRepoError( f"'{repo_id}' is not a safetensors repo. Couldn't find '{constants.SAFETENSORS_INDEX_FILE}' or '{constants.SAFETENSORS_SINGLE_FILE}' files." ) def parse_safetensors_file_metadata( self, repo_id: str, filename: str, *, repo_type: Optional[str] = None, revision: Optional[str] = None, token: Union[bool, str, None] = None, ) -> SafetensorsFileMetadata: """ Parse metadata from a safetensors file on the Hub. To parse metadata from all safetensors files in a repo at once, use [`get_safetensors_metadata`]. For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. Args: repo_id (`str`): A user or an organization name and a repo name separated by a `/`. filename (`str`): The name of the file in the repo. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if the file is in a dataset or space, `None` or `"model"` if in a model. Default is `None`. revision (`str`, *optional*): The git revision to fetch the file from. Can be a branch name, a tag, or a commit hash. Defaults to the head of the `"main"` branch. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`SafetensorsFileMetadata`]: information related to a safetensors file. Raises: [`NotASafetensorsRepoError`]: If the repo is not a safetensors repo i.e. doesn't have either a `model.safetensors` or a `model.safetensors.index.json` file. [`SafetensorsParsingError`]: If a safetensors file header couldn't be parsed correctly. """ url = hf_hub_url( repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision, endpoint=self.endpoint ) _headers = self._build_hf_headers(token=token) # 1. 
Fetch first 100kb
# Empirically, 97% of safetensors files have a metadata size < 100kb (over the top 1000 models on the Hub).
# We assume fetching 100kb is faster than making 2 GET requests. Therefore we always fetch the first 100kb to
# avoid the 2nd GET in most cases.
# See https://github.com/huggingface/huggingface_hub/pull/1855#discussion_r1404286419.
response = get_session().get(url, headers={**_headers, "range": "bytes=0-100000"})
hf_raise_for_status(response)

# 2. Parse metadata size
metadata_size = struct.unpack("<Q", response.content[:8])[0]
if metadata_size > constants.SAFETENSORS_MAX_HEADER_LENGTH:
    raise SafetensorsParsingError(
        f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision "
        f"'{revision or constants.DEFAULT_REVISION}'): safetensors header is too big. Maximum supported size is "
        f"{constants.SAFETENSORS_MAX_HEADER_LENGTH} bytes (got {metadata_size})."
    )

# 3.a. Get metadata from payload
if metadata_size <= 100000:
    metadata_as_bytes = response.content[8 : 8 + metadata_size]
else:
    # 3.b. Request full metadata
    response = get_session().get(url, headers={**_headers, "range": f"bytes=8-{metadata_size + 7}"})
    hf_raise_for_status(response)
    metadata_as_bytes = response.content

# 4. Parse json header
try:
    metadata_as_dict = json.loads(metadata_as_bytes.decode(errors="ignore"))
except json.JSONDecodeError as e:
    raise SafetensorsParsingError(
        f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision "
        f"'{revision or constants.DEFAULT_REVISION}'): header is not json-encoded string. Please make sure this is a "
        "correctly formatted safetensors file."
    ) from e

try:
    return SafetensorsFileMetadata(
        metadata=metadata_as_dict.get("__metadata__", {}),
        tensors={
            key: TensorInfo(
                dtype=tensor["dtype"],
                shape=tensor["shape"],
                data_offsets=tuple(tensor["data_offsets"]),  # type: ignore
            )
            for key, tensor in metadata_as_dict.items()
            if key != "__metadata__"
        },
    )
except (KeyError, IndexError) as e:
    raise SafetensorsParsingError(
        f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision "
        f"'{revision or constants.DEFAULT_REVISION}'): header format not recognized. Please make sure this is a correctly"
        " formatted safetensors file."
    ) from e

@validate_hf_hub_args
def create_branch(
    self,
    repo_id: str,
    *,
    branch: str,
    revision: Optional[str] = None,
    token: Union[bool, str, None] = None,
    repo_type: Optional[str] = None,
    exist_ok: bool = False,
) -> None:
    """
    Create a new branch for a repo on the Hub, starting from the specified revision (defaults to `main`).
    To find a revision suiting your needs, you can use [`list_repo_refs`] or [`list_repo_commits`].

    Args:
        repo_id (`str`):
            The repository in which the branch will be created. Example: `"user/my-cool-model"`.
        branch (`str`):
            The name of the branch to create.
        revision (`str`, *optional*):
            The git revision to create the branch from. It can be a branch name or the OID/SHA of a commit,
            as a hexadecimal string. Defaults to the head of the `"main"` branch.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved token, which is the recommended
            method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if creating a branch on a dataset or space, `None` or `"model"` if
            creating a branch on a model. Default is `None`.
        exist_ok (`bool`, *optional*, defaults to `False`):
            If `True`, do not raise an error if branch already exists.
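A minimal usage sketch (the repo id and branch name below are placeholders):

```python
>>> from huggingface_hub import HfApi
>>> api = HfApi()
# Create a branch "experiment-1" starting from the head of `main`
>>> api.create_branch(repo_id="user/my-cool-model", branch="experiment-1")
# Calling it again with `exist_ok=True` does not raise if the branch already exists
>>> api.create_branch(repo_id="user/my-cool-model", branch="experiment-1", exist_ok=True)
```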
Raises: [`~utils.RepositoryNotFoundError`]: If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo does not exist. [`~utils.BadRequestError`]: If invalid reference for a branch. Ex: `refs/pr/5` or 'refs/foo/bar'. [`~utils.HfHubHTTPError`]: If the branch already exists on the repo (error 409) and `exist_ok` is set to `False`. """ if repo_type is None: repo_type = constants.REPO_TYPE_MODEL branch = quote(branch, safe="") # Prepare request branch_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/branch/{branch}" headers = self._build_hf_headers(token=token) payload = {} if revision is not None: payload["startingPoint"] = revision # Create branch response = get_session().post(url=branch_url, headers=headers, json=payload) try: hf_raise_for_status(response) except HfHubHTTPError as e: if exist_ok and e.response.status_code == 409: return elif exist_ok and e.response.status_code == 403: # No write permission on the namespace but branch might already exist try: refs = self.list_repo_refs(repo_id=repo_id, repo_type=repo_type, token=token) for branch_ref in refs.branches: if branch_ref.name == branch: return # Branch already exists => do not raise except HfHubHTTPError: pass # We raise the original error if the branch does not exist raise @validate_hf_hub_args def delete_branch( self, repo_id: str, *, branch: str, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, ) -> None: """ Delete a branch from a repo on the Hub. Args: repo_id (`str`): The repository in which a branch will be deleted. Example: `"user/my-cool-model"`. branch (`str`): The name of the branch to delete. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if creating a branch on a dataset or space, `None` or `"model"` if tagging a model. Default is `None`. Raises: [`~utils.RepositoryNotFoundError`]: If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo does not exist. [`~utils.HfHubHTTPError`]: If trying to delete a protected branch. Ex: `main` cannot be deleted. [`~utils.HfHubHTTPError`]: If trying to delete a branch that does not exist. """ if repo_type is None: repo_type = constants.REPO_TYPE_MODEL branch = quote(branch, safe="") # Prepare request branch_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/branch/{branch}" headers = self._build_hf_headers(token=token) # Delete branch response = get_session().delete(url=branch_url, headers=headers) hf_raise_for_status(response) @validate_hf_hub_args def create_tag( self, repo_id: str, *, tag: str, tag_message: Optional[str] = None, revision: Optional[str] = None, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, exist_ok: bool = False, ) -> None: """ Tag a given commit of a repo on the Hub. Args: repo_id (`str`): The repository in which a commit will be tagged. Example: `"user/my-cool-model"`. tag (`str`): The name of the tag to create. tag_message (`str`, *optional*): The description of the tag to create. revision (`str`, *optional*): The git revision to tag. It can be a branch name or the OID/SHA of a commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. Defaults to the head of the `"main"` branch. 
token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if tagging a dataset or space, `None` or `"model"` if tagging a model. Default is `None`. exist_ok (`bool`, *optional*, defaults to `False`): If `True`, do not raise an error if tag already exists. Raises: [`~utils.RepositoryNotFoundError`]: If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo does not exist. [`~utils.RevisionNotFoundError`]: If revision is not found (error 404) on the repo. [`~utils.HfHubHTTPError`]: If the branch already exists on the repo (error 409) and `exist_ok` is set to `False`. """ if repo_type is None: repo_type = constants.REPO_TYPE_MODEL revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION # Prepare request tag_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{revision}" headers = self._build_hf_headers(token=token) payload = {"tag": tag} if tag_message is not None: payload["message"] = tag_message # Tag response = get_session().post(url=tag_url, headers=headers, json=payload) try: hf_raise_for_status(response) except HfHubHTTPError as e: if not (e.response.status_code == 409 and exist_ok): raise @validate_hf_hub_args def delete_tag( self, repo_id: str, *, tag: str, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, ) -> None: """ Delete a tag from a repo on the Hub. Args: repo_id (`str`): The repository in which a tag will be deleted. Example: `"user/my-cool-model"`. tag (`str`): The name of the tag to delete. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if tagging a dataset or space, `None` or `"model"` if tagging a model. Default is `None`. Raises: [`~utils.RepositoryNotFoundError`]: If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo does not exist. [`~utils.RevisionNotFoundError`]: If tag is not found. """ if repo_type is None: repo_type = constants.REPO_TYPE_MODEL tag = quote(tag, safe="") # Prepare request tag_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{tag}" headers = self._build_hf_headers(token=token) # Un-tag response = get_session().delete(url=tag_url, headers=headers) hf_raise_for_status(response) @validate_hf_hub_args def get_full_repo_name( self, model_id: str, *, organization: Optional[str] = None, token: Union[bool, str, None] = None, ): """ Returns the repository name for a given model ID and optional organization. Args: model_id (`str`): The name of the model. organization (`str`, *optional*): If passed, the repository name will be in the organization namespace instead of the user namespace. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. 
Returns: `str`: The repository name in the user's namespace ({username}/{model_id}) if no organization is passed, and under the organization namespace ({organization}/{model_id}) otherwise. """ if organization is None: if "/" in model_id: username = model_id.split("/")[0] else: username = self.whoami(token=token)["name"] # type: ignore return f"{username}/{model_id}" else: return f"{organization}/{model_id}" @validate_hf_hub_args def get_repo_discussions( self, repo_id: str, *, author: Optional[str] = None, discussion_type: Optional[constants.DiscussionTypeFilter] = None, discussion_status: Optional[constants.DiscussionStatusFilter] = None, repo_type: Optional[str] = None, token: Union[bool, str, None] = None, ) -> Iterator[Discussion]: """ Fetches Discussions and Pull Requests for the given repo. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. author (`str`, *optional*): Pass a value to filter by discussion author. `None` means no filter. Default is `None`. discussion_type (`str`, *optional*): Set to `"pull_request"` to fetch only pull requests, `"discussion"` to fetch only discussions. Set to `"all"` or `None` to fetch both. Default is `None`. discussion_status (`str`, *optional*): Set to `"open"` (respectively `"closed"`) to fetch only open (respectively closed) discussions. Set to `"all"` or `None` to fetch both. Default is `None`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if fetching from a dataset or space, `None` or `"model"` if fetching from a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `Iterator[Discussion]`: An iterator of [`Discussion`] objects. Example: Collecting all discussions of a repo in a list: ```python >>> from huggingface_hub import get_repo_discussions >>> discussions_list = list(get_repo_discussions(repo_id="bert-base-uncased")) ``` Iterating over discussions of a repo: ```python >>> from huggingface_hub import get_repo_discussions >>> for discussion in get_repo_discussions(repo_id="bert-base-uncased"): ... 
print(discussion.num, discussion.title) ``` """ if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: repo_type = constants.REPO_TYPE_MODEL if discussion_type is not None and discussion_type not in constants.DISCUSSION_TYPES: raise ValueError(f"Invalid discussion_type, must be one of {constants.DISCUSSION_TYPES}") if discussion_status is not None and discussion_status not in constants.DISCUSSION_STATUS: raise ValueError(f"Invalid discussion_status, must be one of {constants.DISCUSSION_STATUS}") headers = self._build_hf_headers(token=token) path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions" params: Dict[str, Union[str, int]] = {} if discussion_type is not None: params["type"] = discussion_type if discussion_status is not None: params["status"] = discussion_status if author is not None: params["author"] = author def _fetch_discussion_page(page_index: int): params["p"] = page_index resp = get_session().get(path, headers=headers, params=params) hf_raise_for_status(resp) paginated_discussions = resp.json() total = paginated_discussions["count"] start = paginated_discussions["start"] discussions = paginated_discussions["discussions"] has_next = (start + len(discussions)) < total return discussions, has_next has_next, page_index = True, 0 while has_next: discussions, has_next = _fetch_discussion_page(page_index=page_index) for discussion in discussions: yield Discussion( title=discussion["title"], num=discussion["num"], author=discussion.get("author", {}).get("name", "deleted"), created_at=parse_datetime(discussion["createdAt"]), status=discussion["status"], repo_id=discussion["repo"]["name"], repo_type=discussion["repo"]["type"], is_pull_request=discussion["isPullRequest"], endpoint=self.endpoint, ) page_index = page_index + 1 @validate_hf_hub_args def get_discussion_details( self, repo_id: str, discussion_num: int, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None, ) -> DiscussionWithDetails: """Fetches a Discussion's / Pull Request 's details from the Hub. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. discussion_num (`int`): The number of the Discussion or Pull Request . Must be a strictly positive integer. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`DiscussionWithDetails`] Raises the following errors: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. 
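Example (a minimal sketch; the repo id and discussion number are placeholders):

```python
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> details = api.get_discussion_details(repo_id="username/repo_name", discussion_num=34)
>>> details.title, details.status, details.is_pull_request  # a few of the returned fields
```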
""" if not isinstance(discussion_num, int) or discussion_num <= 0: raise ValueError("Invalid discussion_num, must be a positive integer") if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: repo_type = constants.REPO_TYPE_MODEL path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions/{discussion_num}" headers = self._build_hf_headers(token=token) resp = get_session().get(path, params={"diff": "1"}, headers=headers) hf_raise_for_status(resp) discussion_details = resp.json() is_pull_request = discussion_details["isPullRequest"] target_branch = discussion_details["changes"]["base"] if is_pull_request else None conflicting_files = discussion_details["filesWithConflicts"] if is_pull_request else None merge_commit_oid = discussion_details["changes"].get("mergeCommitId", None) if is_pull_request else None return DiscussionWithDetails( title=discussion_details["title"], num=discussion_details["num"], author=discussion_details.get("author", {}).get("name", "deleted"), created_at=parse_datetime(discussion_details["createdAt"]), status=discussion_details["status"], repo_id=discussion_details["repo"]["name"], repo_type=discussion_details["repo"]["type"], is_pull_request=discussion_details["isPullRequest"], events=[deserialize_event(evt) for evt in discussion_details["events"]], conflicting_files=conflicting_files, target_branch=target_branch, merge_commit_oid=merge_commit_oid, diff=discussion_details.get("diff"), endpoint=self.endpoint, ) @validate_hf_hub_args def create_discussion( self, repo_id: str, title: str, *, token: Union[bool, str, None] = None, description: Optional[str] = None, repo_type: Optional[str] = None, pull_request: bool = False, ) -> DiscussionWithDetails: """Creates a Discussion or Pull Request. Pull Requests created programmatically will be in `"draft"` status. Creating a Pull Request with changes can also be done at once with [`HfApi.create_commit`]. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. title (`str`): The title of the discussion. It can be up to 200 characters long, and must be at least 3 characters long. Leading and trailing whitespaces will be stripped. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. description (`str`, *optional*): An optional description for the Pull Request. Defaults to `"Discussion opened with the huggingface_hub Python library"` pull_request (`bool`, *optional*): Whether to create a Pull Request or discussion. If `True`, creates a Pull Request. If `False`, creates a discussion. Defaults to `False`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. Returns: [`DiscussionWithDetails`] Raises the following errors: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. 
""" if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: repo_type = constants.REPO_TYPE_MODEL if description is not None: description = description.strip() description = ( description if description else ( f"{'Pull Request' if pull_request else 'Discussion'} opened with the" " [huggingface_hub Python" " library](https://huggingface.co/docs/huggingface_hub)" ) ) headers = self._build_hf_headers(token=token) resp = get_session().post( f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions", json={ "title": title.strip(), "description": description, "pullRequest": pull_request, }, headers=headers, ) hf_raise_for_status(resp) num = resp.json()["num"] return self.get_discussion_details( repo_id=repo_id, repo_type=repo_type, discussion_num=num, token=token, ) @validate_hf_hub_args def create_pull_request( self, repo_id: str, title: str, *, token: Union[bool, str, None] = None, description: Optional[str] = None, repo_type: Optional[str] = None, ) -> DiscussionWithDetails: """Creates a Pull Request . Pull Requests created programmatically will be in `"draft"` status. Creating a Pull Request with changes can also be done at once with [`HfApi.create_commit`]; This is a wrapper around [`HfApi.create_discussion`]. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. title (`str`): The title of the discussion. It can be up to 200 characters long, and must be at least 3 characters long. Leading and trailing whitespaces will be stripped. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. description (`str`, *optional*): An optional description for the Pull Request. Defaults to `"Discussion opened with the huggingface_hub Python library"` repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. Returns: [`DiscussionWithDetails`] Raises the following errors: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. 
""" return self.create_discussion( repo_id=repo_id, title=title, token=token, description=description, repo_type=repo_type, pull_request=True, ) def _post_discussion_changes( self, *, repo_id: str, discussion_num: int, resource: str, body: Optional[dict] = None, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, ) -> requests.Response: """Internal utility to POST changes to a Discussion or Pull Request""" if not isinstance(discussion_num, int) or discussion_num <= 0: raise ValueError("Invalid discussion_num, must be a positive integer") if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: repo_type = constants.REPO_TYPE_MODEL repo_id = f"{repo_type}s/{repo_id}" path = f"{self.endpoint}/api/{repo_id}/discussions/{discussion_num}/{resource}" headers = self._build_hf_headers(token=token) resp = requests.post(path, headers=headers, json=body) hf_raise_for_status(resp) return resp @validate_hf_hub_args def comment_discussion( self, repo_id: str, discussion_num: int, comment: str, *, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, ) -> DiscussionComment: """Creates a new comment on the given Discussion. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. discussion_num (`int`): The number of the Discussion or Pull Request . Must be a strictly positive integer. comment (`str`): The content of the comment to create. Comments support markdown formatting. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`DiscussionComment`]: the newly created comment Examples: ```python >>> comment = \"\"\" ... Hello @otheruser! ... ... # This is a title ... ... **This is bold**, *this is italic* and ~this is strikethrough~ ... And [this](http://url) is a link ... \"\"\" >>> HfApi().comment_discussion( ... repo_id="username/repo_name", ... discussion_num=34 ... comment=comment ... ) # DiscussionComment(id='deadbeef0000000', type='comment', ...) ``` Raises the following errors: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. """ resp = self._post_discussion_changes( repo_id=repo_id, repo_type=repo_type, discussion_num=discussion_num, token=token, resource="comment", body={"comment": comment}, ) return deserialize_event(resp.json()["newMessage"]) # type: ignore @validate_hf_hub_args def rename_discussion( self, repo_id: str, discussion_num: int, new_title: str, *, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, ) -> DiscussionTitleChange: """Renames a Discussion. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. discussion_num (`int`): The number of the Discussion or Pull Request . 
Must be a strictly positive integer.
new_title (`str`):
    The new title for the discussion.
repo_type (`str`, *optional*):
    Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading
    to a model. Default is `None`.
token (Union[bool, str, None], optional):
    A valid user access token (string). Defaults to the locally saved token, which is the recommended
    method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
    To disable authentication, pass `False`.

Returns:
    [`DiscussionTitleChange`]: the title change event

Examples:

```python
>>> new_title = "New title, fixing a typo"
>>> HfApi().rename_discussion(
...     repo_id="username/repo_name",
...     discussion_num=34,
...     new_title=new_title,
... )
# DiscussionTitleChange(id='deadbeef0000000', type='title-change', ...)
```

Raises the following errors:

    - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
        if the HuggingFace API returned an error
    - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
        if some parameter value is invalid
    - [`~utils.RepositoryNotFoundError`]
        If the repository to download from cannot be found. This may be because it doesn't exist,
        or because it is set to `private` and you do not have access.
"""
resp = self._post_discussion_changes(
    repo_id=repo_id,
    repo_type=repo_type,
    discussion_num=discussion_num,
    token=token,
    resource="title",
    body={"title": new_title},
)
return deserialize_event(resp.json()["newTitle"])  # type: ignore

@validate_hf_hub_args
def change_discussion_status(
    self,
    repo_id: str,
    discussion_num: int,
    new_status: Literal["open", "closed"],
    *,
    token: Union[bool, str, None] = None,
    comment: Optional[str] = None,
    repo_type: Optional[str] = None,
) -> DiscussionStatusChange:
    """Closes or re-opens a Discussion or Pull Request.

    Args:
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated by a `/`.
        discussion_num (`int`):
            The number of the Discussion or Pull Request. Must be a strictly positive integer.
        new_status (`str`):
            The new status for the discussion, either `"open"` or `"closed"`.
        comment (`str`, *optional*):
            An optional comment to post with the status change.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if
            uploading to a model. Default is `None`.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved token, which is the recommended
            method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        [`DiscussionStatusChange`]: the status change event

    Examples:

    ```python
    >>> comment = "Closing this discussion, the issue is resolved."
    >>> HfApi().change_discussion_status(
    ...     repo_id="username/repo_name",
    ...     discussion_num=34,
    ...     new_status="closed",
    ...     comment=comment,
    ... )
    # DiscussionStatusChange(id='deadbeef0000000', type='status-change', ...)
    ```

    Raises the following errors:

        - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
            if the HuggingFace API returned an error
        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            if some parameter value is invalid
        - [`~utils.RepositoryNotFoundError`]
            If the repository to download from cannot be found. This may be because it doesn't exist,
            or because it is set to `private` and you do not have access.
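A closed Discussion or Pull Request can be re-opened with the same call (a sketch; the repo id and number are placeholders):

```python
>>> HfApi().change_discussion_status(
...     repo_id="username/repo_name",
...     discussion_num=34,
...     new_status="open",
... )
```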
""" if new_status not in ["open", "closed"]: raise ValueError("Invalid status, valid statuses are: 'open' and 'closed'") body: Dict[str, str] = {"status": new_status} if comment and comment.strip(): body["comment"] = comment.strip() resp = self._post_discussion_changes( repo_id=repo_id, repo_type=repo_type, discussion_num=discussion_num, token=token, resource="status", body=body, ) return deserialize_event(resp.json()["newStatus"]) # type: ignore @validate_hf_hub_args def merge_pull_request( self, repo_id: str, discussion_num: int, *, token: Union[bool, str, None] = None, comment: Optional[str] = None, repo_type: Optional[str] = None, ): """Merges a Pull Request. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. discussion_num (`int`): The number of the Discussion or Pull Request . Must be a strictly positive integer. comment (`str`, *optional*): An optional comment to post with the status change. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`DiscussionStatusChange`]: the status change event Raises the following errors: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. """ self._post_discussion_changes( repo_id=repo_id, repo_type=repo_type, discussion_num=discussion_num, token=token, resource="merge", body={"comment": comment.strip()} if comment and comment.strip() else None, ) @validate_hf_hub_args def edit_discussion_comment( self, repo_id: str, discussion_num: int, comment_id: str, new_content: str, *, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, ) -> DiscussionComment: """Edits a comment on a Discussion / Pull Request. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. discussion_num (`int`): The number of the Discussion or Pull Request . Must be a strictly positive integer. comment_id (`str`): The ID of the comment to edit. new_content (`str`): The new content of the comment. Comments support markdown formatting. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. 
Returns: [`DiscussionComment`]: the edited comment Raises the following errors: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. """ resp = self._post_discussion_changes( repo_id=repo_id, repo_type=repo_type, discussion_num=discussion_num, token=token, resource=f"comment/{comment_id.lower()}/edit", body={"content": new_content}, ) return deserialize_event(resp.json()["updatedComment"]) # type: ignore @validate_hf_hub_args def hide_discussion_comment( self, repo_id: str, discussion_num: int, comment_id: str, *, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, ) -> DiscussionComment: """Hides a comment on a Discussion / Pull Request. Hidden comments' content cannot be retrieved anymore. Hiding a comment is irreversible. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. discussion_num (`int`): The number of the Discussion or Pull Request . Must be a strictly positive integer. comment_id (`str`): The ID of the comment to edit. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to a model. Default is `None`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`DiscussionComment`]: the hidden comment Raises the following errors: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. """ warnings.warn( "Hidden comments' content cannot be retrieved anymore. Hiding a comment is irreversible.", UserWarning, ) resp = self._post_discussion_changes( repo_id=repo_id, repo_type=repo_type, discussion_num=discussion_num, token=token, resource=f"comment/{comment_id.lower()}/hide", ) return deserialize_event(resp.json()["updatedComment"]) # type: ignore @validate_hf_hub_args def add_space_secret( self, repo_id: str, key: str, value: str, *, description: Optional[str] = None, token: Union[bool, str, None] = None, ) -> None: """Adds or updates a secret in a Space. Secrets allow to set secret keys or tokens to a Space without hardcoding them. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. Args: repo_id (`str`): ID of the repo to update. Example: `"bigcode/in-the-stack"`. key (`str`): Secret key. Example: `"GITHUB_API_KEY"` value (`str`): Secret value. Example: `"your_github_api_key"`. description (`str`, *optional*): Secret description. Example: `"Github API key to access the Github API"`. token (Union[bool, str, None], optional): A valid user access token (string). 
Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. """ payload = {"key": key, "value": value} if description is not None: payload["description"] = description r = get_session().post( f"{self.endpoint}/api/spaces/{repo_id}/secrets", headers=self._build_hf_headers(token=token), json=payload, ) hf_raise_for_status(r) @validate_hf_hub_args def delete_space_secret(self, repo_id: str, key: str, *, token: Union[bool, str, None] = None) -> None: """Deletes a secret from a Space. Secrets allow to set secret keys or tokens to a Space without hardcoding them. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. Args: repo_id (`str`): ID of the repo to update. Example: `"bigcode/in-the-stack"`. key (`str`): Secret key. Example: `"GITHUB_API_KEY"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. """ r = get_session().delete( f"{self.endpoint}/api/spaces/{repo_id}/secrets", headers=self._build_hf_headers(token=token), json={"key": key}, ) hf_raise_for_status(r) @validate_hf_hub_args def get_space_variables(self, repo_id: str, *, token: Union[bool, str, None] = None) -> Dict[str, SpaceVariable]: """Gets all variables from a Space. Variables allow to set environment variables to a Space without hardcoding them. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables Args: repo_id (`str`): ID of the repo to query. Example: `"bigcode/in-the-stack"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. """ r = get_session().get( f"{self.endpoint}/api/spaces/{repo_id}/variables", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(r) return {k: SpaceVariable(k, v) for k, v in r.json().items()} @validate_hf_hub_args def add_space_variable( self, repo_id: str, key: str, value: str, *, description: Optional[str] = None, token: Union[bool, str, None] = None, ) -> Dict[str, SpaceVariable]: """Adds or updates a variable in a Space. Variables allow to set environment variables to a Space without hardcoding them. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables Args: repo_id (`str`): ID of the repo to update. Example: `"bigcode/in-the-stack"`. key (`str`): Variable key. Example: `"MODEL_REPO_ID"` value (`str`): Variable value. Example: `"the_model_repo_id"`. description (`str`): Description of the variable. Example: `"Model Repo ID of the implemented model"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. 
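Example (a minimal sketch; the Space id, key and value are placeholders):

```python
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> api.add_space_variable(
...     repo_id="username/my-space",
...     key="MODEL_REPO_ID",
...     value="username/my-model",
...     description="Repo id of the model served by this Space",
... )
```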
""" payload = {"key": key, "value": value} if description is not None: payload["description"] = description r = get_session().post( f"{self.endpoint}/api/spaces/{repo_id}/variables", headers=self._build_hf_headers(token=token), json=payload, ) hf_raise_for_status(r) return {k: SpaceVariable(k, v) for k, v in r.json().items()} @validate_hf_hub_args def delete_space_variable( self, repo_id: str, key: str, *, token: Union[bool, str, None] = None ) -> Dict[str, SpaceVariable]: """Deletes a variable from a Space. Variables allow to set environment variables to a Space without hardcoding them. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables Args: repo_id (`str`): ID of the repo to update. Example: `"bigcode/in-the-stack"`. key (`str`): Variable key. Example: `"MODEL_REPO_ID"` token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. """ r = get_session().delete( f"{self.endpoint}/api/spaces/{repo_id}/variables", headers=self._build_hf_headers(token=token), json={"key": key}, ) hf_raise_for_status(r) return {k: SpaceVariable(k, v) for k, v in r.json().items()} @validate_hf_hub_args def get_space_runtime(self, repo_id: str, *, token: Union[bool, str, None] = None) -> SpaceRuntime: """Gets runtime information about a Space. Args: repo_id (`str`): ID of the repo to update. Example: `"bigcode/in-the-stack"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. """ r = get_session().get( f"{self.endpoint}/api/spaces/{repo_id}/runtime", headers=self._build_hf_headers(token=token) ) hf_raise_for_status(r) return SpaceRuntime(r.json()) @validate_hf_hub_args def request_space_hardware( self, repo_id: str, hardware: SpaceHardware, *, token: Union[bool, str, None] = None, sleep_time: Optional[int] = None, ) -> SpaceRuntime: """Request new hardware for a Space. Args: repo_id (`str`): ID of the repo to update. Example: `"bigcode/in-the-stack"`. hardware (`str` or [`SpaceHardware`]): Hardware on which to run the Space. Example: `"t4-medium"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. sleep_time (`int`, *optional*): Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure the sleep time (value is fixed to 48 hours of inactivity). See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. Returns: [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. It is also possible to request hardware directly when creating the Space repo! See [`create_repo`] for details. 
""" if sleep_time is not None and hardware == SpaceHardware.CPU_BASIC: warnings.warn( "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more" " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if" " you want to set a custom sleep time, you need to upgrade to a paid Hardware.", UserWarning, ) payload: Dict[str, Any] = {"flavor": hardware} if sleep_time is not None: payload["sleepTimeSeconds"] = sleep_time r = get_session().post( f"{self.endpoint}/api/spaces/{repo_id}/hardware", headers=self._build_hf_headers(token=token), json=payload, ) hf_raise_for_status(r) return SpaceRuntime(r.json()) @validate_hf_hub_args def set_space_sleep_time( self, repo_id: str, sleep_time: int, *, token: Union[bool, str, None] = None ) -> SpaceRuntime: """Set a custom sleep time for a Space running on upgraded hardware.. Your Space will go to sleep after X seconds of inactivity. You are not billed when your Space is in "sleep" mode. If a new visitor lands on your Space, it will "wake it up". Only upgraded hardware can have a configurable sleep time. To know more about the sleep stage, please refer to https://huggingface.co/docs/hub/spaces-gpus#sleep-time. Args: repo_id (`str`): ID of the repo to update. Example: `"bigcode/in-the-stack"`. sleep_time (`int`, *optional*): Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want your Space to pause (default behavior for upgraded hardware). For free hardware, you can't configure the sleep time (value is fixed to 48 hours of inactivity). See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. It is also possible to set a custom sleep time when requesting hardware with [`request_space_hardware`]. """ r = get_session().post( f"{self.endpoint}/api/spaces/{repo_id}/sleeptime", headers=self._build_hf_headers(token=token), json={"seconds": sleep_time}, ) hf_raise_for_status(r) runtime = SpaceRuntime(r.json()) hardware = runtime.requested_hardware or runtime.hardware if hardware == SpaceHardware.CPU_BASIC: warnings.warn( "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more" " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if" " you want to set a custom sleep time, you need to upgrade to a paid Hardware.", UserWarning, ) return runtime @validate_hf_hub_args def pause_space(self, repo_id: str, *, token: Union[bool, str, None] = None) -> SpaceRuntime: """Pause your Space. A paused Space stops executing until manually restarted by its owner. This is different from the sleeping state in which free Spaces go after 48h of inactivity. Paused time is not billed to your account, no matter the hardware you've selected. To restart your Space, use [`restart_space`] and go to your Space settings page. For more details, please visit [the docs](https://huggingface.co/docs/hub/spaces-gpus#pause). Args: repo_id (`str`): ID of the Space to pause. Example: `"Salesforce/BLIP2"`. token (Union[bool, str, None], optional): A valid user access token (string). 
Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`SpaceRuntime`]: Runtime information about your Space including `stage=PAUSED` and requested hardware. Raises: [`~utils.RepositoryNotFoundError`]: If your Space is not found (error 404). Most probably wrong repo_id or your space is private but you are not authenticated. [`~utils.HfHubHTTPError`]: 403 Forbidden: only the owner of a Space can pause it. If you want to manage a Space that you don't own, either ask the owner by opening a Discussion or duplicate the Space. [`~utils.BadRequestError`]: If your Space is a static Space. Static Spaces are always running and never billed. If you want to hide a static Space, you can set it to private. """ r = get_session().post( f"{self.endpoint}/api/spaces/{repo_id}/pause", headers=self._build_hf_headers(token=token) ) hf_raise_for_status(r) return SpaceRuntime(r.json()) @validate_hf_hub_args def restart_space( self, repo_id: str, *, token: Union[bool, str, None] = None, factory_reboot: bool = False ) -> SpaceRuntime: """Restart your Space. This is the only way to programmatically restart a Space if you've put it on Pause (see [`pause_space`]). You must be the owner of the Space to restart it. If you are using an upgraded hardware, your account will be billed as soon as the Space is restarted. You can trigger a restart no matter the current state of a Space. For more details, please visit [the docs](https://huggingface.co/docs/hub/spaces-gpus#pause). Args: repo_id (`str`): ID of the Space to restart. Example: `"Salesforce/BLIP2"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. factory_reboot (`bool`, *optional*): If `True`, the Space will be rebuilt from scratch without caching any requirements. Returns: [`SpaceRuntime`]: Runtime information about your Space. Raises: [`~utils.RepositoryNotFoundError`]: If your Space is not found (error 404). Most probably wrong repo_id or your space is private but you are not authenticated. [`~utils.HfHubHTTPError`]: 403 Forbidden: only the owner of a Space can restart it. If you want to restart a Space that you don't own, either ask the owner by opening a Discussion or duplicate the Space. [`~utils.BadRequestError`]: If your Space is a static Space. Static Spaces are always running and never billed. If you want to hide a static Space, you can set it to private. """ params = {} if factory_reboot: params["factory"] = "true" r = get_session().post( f"{self.endpoint}/api/spaces/{repo_id}/restart", headers=self._build_hf_headers(token=token), params=params ) hf_raise_for_status(r) return SpaceRuntime(r.json()) @validate_hf_hub_args def duplicate_space( self, from_id: str, to_id: Optional[str] = None, *, private: Optional[bool] = None, token: Union[bool, str, None] = None, exist_ok: bool = False, hardware: Optional[SpaceHardware] = None, storage: Optional[SpaceStorage] = None, sleep_time: Optional[int] = None, secrets: Optional[List[Dict[str, str]]] = None, variables: Optional[List[Dict[str, str]]] = None, ) -> RepoUrl: """Duplicate a Space. Programmatically duplicate a Space. 
The new Space will be created in your account and will be in the same state as the original Space (running or paused). You can duplicate a Space no matter the current state of a Space. Args: from_id (`str`): ID of the Space to duplicate. Example: `"pharma/CLIP-Interrogator"`. to_id (`str`, *optional*): ID of the new Space. Example: `"dog/CLIP-Interrogator"`. If not provided, the new Space will have the same name as the original Space, but in your account. private (`bool`, *optional*): Whether the new Space should be private or not. Defaults to the same privacy as the original Space. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. exist_ok (`bool`, *optional*, defaults to `False`): If `True`, do not raise an error if repo already exists. hardware (`SpaceHardware` or `str`, *optional*): Choice of Hardware. Example: `"t4-medium"`. See [`SpaceHardware`] for a complete list. storage (`SpaceStorage` or `str`, *optional*): Choice of persistent storage tier. Example: `"small"`. See [`SpaceStorage`] for a complete list. sleep_time (`int`, *optional*): Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure the sleep time (value is fixed to 48 hours of inactivity). See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. secrets (`List[Dict[str, str]]`, *optional*): A list of secret keys to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. variables (`List[Dict[str, str]]`, *optional*): A list of public environment variables to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables. Returns: [`RepoUrl`]: URL to the newly created repo. Value is a subclass of `str` containing attributes like `endpoint`, `repo_type` and `repo_id`. Raises: [`~utils.RepositoryNotFoundError`]: If one of `from_id` or `to_id` cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): If the HuggingFace API returned an error Example: ```python >>> from huggingface_hub import duplicate_space # Duplicate a Space to your account >>> duplicate_space("multimodalart/dreambooth-training") RepoUrl('https://huggingface.co/spaces/nateraw/dreambooth-training',...) # Can set custom destination id and visibility flag. >>> duplicate_space("multimodalart/dreambooth-training", to_id="my-dreambooth", private=True) RepoUrl('https://huggingface.co/spaces/nateraw/my-dreambooth',...) 
``` """ # Parse to_id if provided parsed_to_id = RepoUrl(to_id) if to_id is not None else None # Infer target repo_id to_namespace = ( # set namespace manually or default to username parsed_to_id.namespace if parsed_to_id is not None and parsed_to_id.namespace is not None else self.whoami(token)["name"] ) to_repo_name = parsed_to_id.repo_name if to_id is not None else RepoUrl(from_id).repo_name # type: ignore # repository must be a valid repo_id (namespace/repo_name). payload: Dict[str, Any] = {"repository": f"{to_namespace}/{to_repo_name}"} keys = ["private", "hardware", "storageTier", "sleepTimeSeconds", "secrets", "variables"] values = [private, hardware, storage, sleep_time, secrets, variables] payload.update({k: v for k, v in zip(keys, values) if v is not None}) if sleep_time is not None and hardware == SpaceHardware.CPU_BASIC: warnings.warn( "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more" " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if" " you want to set a custom sleep time, you need to upgrade to a paid Hardware.", UserWarning, ) r = get_session().post( f"{self.endpoint}/api/spaces/{from_id}/duplicate", headers=self._build_hf_headers(token=token), json=payload, ) try: hf_raise_for_status(r) except HTTPError as err: if exist_ok and err.response.status_code == 409: # Repo already exists and `exist_ok=True` pass else: raise return RepoUrl(r.json()["url"], endpoint=self.endpoint) @validate_hf_hub_args def request_space_storage( self, repo_id: str, storage: SpaceStorage, *, token: Union[bool, str, None] = None, ) -> SpaceRuntime: """Request persistent storage for a Space. Args: repo_id (`str`): ID of the Space to update. Example: `"open-llm-leaderboard/open_llm_leaderboard"`. storage (`str` or [`SpaceStorage`]): Storage tier. Either 'small', 'medium', or 'large'. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. It is not possible to decrease persistent storage after its granted. To do so, you must delete it via [`delete_space_storage`]. """ payload: Dict[str, SpaceStorage] = {"tier": storage} r = get_session().post( f"{self.endpoint}/api/spaces/{repo_id}/storage", headers=self._build_hf_headers(token=token), json=payload, ) hf_raise_for_status(r) return SpaceRuntime(r.json()) @validate_hf_hub_args def delete_space_storage( self, repo_id: str, *, token: Union[bool, str, None] = None, ) -> SpaceRuntime: """Delete persistent storage for a Space. Args: repo_id (`str`): ID of the Space to update. Example: `"open-llm-leaderboard/open_llm_leaderboard"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. Raises: [`BadRequestError`] If space has no persistent storage. 
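Example (a minimal sketch; the Space id is a placeholder):

```python
>>> from huggingface_hub import HfApi
>>> api = HfApi()
# Persistent storage cannot be decreased once granted: delete it, then request a smaller tier
>>> api.delete_space_storage(repo_id="username/my-space")
>>> api.request_space_storage(repo_id="username/my-space", storage="small")
```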
""" r = get_session().delete( f"{self.endpoint}/api/spaces/{repo_id}/storage", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(r) return SpaceRuntime(r.json()) ####################### # Inference Endpoints # ####################### def list_inference_endpoints( self, namespace: Optional[str] = None, *, token: Union[bool, str, None] = None ) -> List[InferenceEndpoint]: """Lists all inference endpoints for the given namespace. Args: namespace (`str`, *optional*): The namespace to list endpoints for. Defaults to the current user. Set to `"*"` to list all endpoints from all namespaces (i.e. personal namespace and all orgs the user belongs to). token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: List[`InferenceEndpoint`]: A list of all inference endpoints for the given namespace. Example: ```python >>> from huggingface_hub import HfApi >>> api = HfApi() >>> api.list_inference_endpoints() [InferenceEndpoint(name='my-endpoint', ...), ...] ``` """ # Special case: list all endpoints for all namespaces the user has access to if namespace == "*": user = self.whoami(token=token) # List personal endpoints first endpoints: List[InferenceEndpoint] = list_inference_endpoints(namespace=self._get_namespace(token=token)) # Then list endpoints for all orgs the user belongs to and ignore 401 errors (no billing or no access) for org in user.get("orgs", []): try: endpoints += list_inference_endpoints(namespace=org["name"], token=token) except HfHubHTTPError as error: if error.response.status_code == 401: # Either no billing or user don't have access) logger.debug("Cannot list Inference Endpoints for org '%s': %s", org["name"], error) pass return endpoints # Normal case: list endpoints for a specific namespace namespace = namespace or self._get_namespace(token=token) response = get_session().get( f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) return [ InferenceEndpoint.from_raw(endpoint, namespace=namespace, token=token) for endpoint in response.json()["items"] ] def create_inference_endpoint( self, name: str, *, repository: str, framework: str, accelerator: str, instance_size: str, instance_type: str, region: str, vendor: str, account_id: Optional[str] = None, min_replica: int = 0, max_replica: int = 1, scale_to_zero_timeout: int = 15, revision: Optional[str] = None, task: Optional[str] = None, custom_image: Optional[Dict] = None, env: Optional[Dict[str, str]] = None, secrets: Optional[Dict[str, str]] = None, type: InferenceEndpointType = InferenceEndpointType.PROTECTED, domain: Optional[str] = None, path: Optional[str] = None, cache_http_responses: Optional[bool] = None, tags: Optional[List[str]] = None, namespace: Optional[str] = None, token: Union[bool, str, None] = None, ) -> InferenceEndpoint: """Create a new Inference Endpoint. Args: name (`str`): The unique name for the new Inference Endpoint. repository (`str`): The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). framework (`str`): The machine learning framework used for the model (e.g. `"custom"`). accelerator (`str`): The hardware accelerator to be used for inference (e.g. `"cpu"`). 
instance_size (`str`): The size or type of the instance to be used for hosting the model (e.g. `"x4"`).
instance_type (`str`): The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`).
region (`str`): The cloud region in which the Inference Endpoint will be created (e.g. `"us-east-1"`).
vendor (`str`): The cloud provider or vendor where the Inference Endpoint will be hosted (e.g. `"aws"`).
account_id (`str`, *optional*): The account ID used to link a VPC to a private Inference Endpoint (if applicable).
min_replica (`int`, *optional*): The minimum number of replicas (instances) to keep running for the Inference Endpoint. Defaults to 0.
max_replica (`int`, *optional*): The maximum number of replicas (instances) to scale to for the Inference Endpoint. Defaults to 1.
scale_to_zero_timeout (`int`, *optional*): The duration in minutes before an inactive endpoint is scaled to zero. Defaults to 15.
revision (`str`, *optional*): The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
task (`str`, *optional*): The task on which to deploy the model (e.g. `"text-classification"`).
custom_image (`Dict`, *optional*): A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples).
env (`Dict[str, str]`, *optional*): Non-secret environment variables to inject in the container environment.
secrets (`Dict[str, str]`, *optional*): Secret values to inject in the container environment.
type ([`InferenceEndpointType`], *optional*): The type of the Inference Endpoint, which can be `"protected"` (default), `"public"` or `"private"`.
domain (`str`, *optional*): The custom domain for the Inference Endpoint deployment. If set up, the Inference Endpoint will be available at this domain (e.g. `"my-new-domain.cool-website.woof"`).
path (`str`, *optional*): The custom path to the deployed model, should start with a `/` (e.g. `"/models/google-bert/bert-base-uncased"`).
cache_http_responses (`bool`, *optional*): Whether to cache HTTP responses from the Inference Endpoint. Defaults to `False`.
tags (`List[str]`, *optional*): A list of tags to associate with the Inference Endpoint.
namespace (`str`, *optional*): The namespace where the Inference Endpoint will be created. Defaults to the current user's namespace.
token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`.
Returns:
[`InferenceEndpoint`]: information about the new Inference Endpoint.
Example:
```python
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> endpoint = api.create_inference_endpoint(
...     "my-endpoint-name",
...     repository="gpt2",
...     framework="pytorch",
...     task="text-generation",
...     accelerator="cpu",
...     vendor="aws",
...     region="us-east-1",
...     type="protected",
...     instance_size="x2",
...     instance_type="intel-icl",
... )
>>> endpoint
InferenceEndpoint(name='my-endpoint-name', status="pending",...)

# Run inference on the endpoint
>>> endpoint.client.text_generation(...)
"..."
```

```python
# Start an Inference Endpoint running Zephyr-7b-beta on TGI
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> endpoint = api.create_inference_endpoint(
... "aws-zephyr-7b-beta-0486", ...
repository="HuggingFaceH4/zephyr-7b-beta", ... framework="pytorch", ... task="text-generation", ... accelerator="gpu", ... vendor="aws", ... region="us-east-1", ... type="protected", ... instance_size="x1", ... instance_type="nvidia-a10g", ... env={ ... "MAX_BATCH_PREFILL_TOKENS": "2048", ... "MAX_INPUT_LENGTH": "1024", ... "MAX_TOTAL_TOKENS": "1512", ... "MODEL_ID": "/repository" ... }, ... custom_image={ ... "health_route": "/health", ... "url": "ghcr.io/huggingface/text-generation-inference:1.1.0", ... }, ... secrets={"MY_SECRET_KEY": "secret_value"}, ... tags=["dev", "text-generation"], ... ) ``` """ namespace = namespace or self._get_namespace(token=token) image = {"custom": custom_image} if custom_image is not None else {"huggingface": {}} payload: Dict = { "accountId": account_id, "compute": { "accelerator": accelerator, "instanceSize": instance_size, "instanceType": instance_type, "scaling": { "maxReplica": max_replica, "minReplica": min_replica, "scaleToZeroTimeout": scale_to_zero_timeout, }, }, "model": { "framework": framework, "repository": repository, "revision": revision, "task": task, "image": image, }, "name": name, "provider": { "region": region, "vendor": vendor, }, "type": type, } if env: payload["model"]["env"] = env if secrets: payload["model"]["secrets"] = secrets if domain is not None or path is not None: payload["route"] = {} if domain is not None: payload["route"]["domain"] = domain if path is not None: payload["route"]["path"] = path if cache_http_responses is not None: payload["cacheHttpResponses"] = cache_http_responses if tags is not None: payload["tags"] = tags response = get_session().post( f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}", headers=self._build_hf_headers(token=token), json=payload, ) hf_raise_for_status(response) return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) @experimental @validate_hf_hub_args def create_inference_endpoint_from_catalog( self, repo_id: str, *, name: Optional[str] = None, token: Union[bool, str, None] = None, namespace: Optional[str] = None, ) -> InferenceEndpoint: """Create a new Inference Endpoint from a model in the Hugging Face Inference Catalog. The goal of the Inference Catalog is to provide a curated list of models that are optimized for inference and for which default configurations have been tested. See https://endpoints.huggingface.co/catalog for a list of available models in the catalog. Args: repo_id (`str`): The ID of the model in the catalog to deploy as an Inference Endpoint. name (`str`, *optional*): The unique name for the new Inference Endpoint. If not provided, a random name will be generated. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). namespace (`str`, *optional*): The namespace where the Inference Endpoint will be created. Defaults to the current user's namespace. Returns: [`InferenceEndpoint`]: information about the new Inference Endpoint. `create_inference_endpoint_from_catalog` is experimental. Its API is subject to change in the future. Please provide feedback if you have any suggestions or requests. 
""" token = token or self.token or get_token() payload: Dict = { "namespace": namespace or self._get_namespace(token=token), "repoId": repo_id, } if name is not None: payload["endpointName"] = name response = get_session().post( f"{constants.INFERENCE_CATALOG_ENDPOINT}/deploy", headers=self._build_hf_headers(token=token), json=payload, ) hf_raise_for_status(response) data = response.json()["endpoint"] return InferenceEndpoint.from_raw(data, namespace=data["name"], token=token) @experimental @validate_hf_hub_args def list_inference_catalog(self, *, token: Union[bool, str, None] = None) -> List[str]: """List models available in the Hugging Face Inference Catalog. The goal of the Inference Catalog is to provide a curated list of models that are optimized for inference and for which default configurations have been tested. See https://endpoints.huggingface.co/catalog for a list of available models in the catalog. Use [`create_inference_endpoint_from_catalog`] to deploy a model from the catalog. Args: token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). Returns: List[`str`]: A list of model IDs available in the catalog. `list_inference_catalog` is experimental. Its API is subject to change in the future. Please provide feedback if you have any suggestions or requests. """ response = get_session().get( f"{constants.INFERENCE_CATALOG_ENDPOINT}/repo-list", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) return response.json()["models"] def get_inference_endpoint( self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None ) -> InferenceEndpoint: """Get information about an Inference Endpoint. Args: name (`str`): The name of the Inference Endpoint to retrieve information about. namespace (`str`, *optional*): The namespace in which the Inference Endpoint is located. Defaults to the current user. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`InferenceEndpoint`]: information about the requested Inference Endpoint. Example: ```python >>> from huggingface_hub import HfApi >>> api = HfApi() >>> endpoint = api.get_inference_endpoint("my-text-to-image") >>> endpoint InferenceEndpoint(name='my-text-to-image', ...) # Get status >>> endpoint.status 'running' >>> endpoint.url 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud' # Run inference >>> endpoint.client.text_to_image(...) 
``` """ namespace = namespace or self._get_namespace(token=token) response = get_session().get( f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) def update_inference_endpoint( self, name: str, *, # Compute update accelerator: Optional[str] = None, instance_size: Optional[str] = None, instance_type: Optional[str] = None, min_replica: Optional[int] = None, max_replica: Optional[int] = None, scale_to_zero_timeout: Optional[int] = None, # Model update repository: Optional[str] = None, framework: Optional[str] = None, revision: Optional[str] = None, task: Optional[str] = None, custom_image: Optional[Dict] = None, env: Optional[Dict[str, str]] = None, secrets: Optional[Dict[str, str]] = None, # Route update domain: Optional[str] = None, path: Optional[str] = None, # Other cache_http_responses: Optional[bool] = None, tags: Optional[List[str]] = None, namespace: Optional[str] = None, token: Union[bool, str, None] = None, ) -> InferenceEndpoint: """Update an Inference Endpoint. This method allows the update of either the compute configuration, the deployed model, the route, or any combination. All arguments are optional but at least one must be provided. For convenience, you can also update an Inference Endpoint using [`InferenceEndpoint.update`]. Args: name (`str`): The name of the Inference Endpoint to update. accelerator (`str`, *optional*): The hardware accelerator to be used for inference (e.g. `"cpu"`). instance_size (`str`, *optional*): The size or type of the instance to be used for hosting the model (e.g. `"x4"`). instance_type (`str`, *optional*): The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`). min_replica (`int`, *optional*): The minimum number of replicas (instances) to keep running for the Inference Endpoint. max_replica (`int`, *optional*): The maximum number of replicas (instances) to scale to for the Inference Endpoint. scale_to_zero_timeout (`int`, *optional*): The duration in minutes before an inactive endpoint is scaled to zero. repository (`str`, *optional*): The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). framework (`str`, *optional*): The machine learning framework used for the model (e.g. `"custom"`). revision (`str`, *optional*): The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). task (`str`, *optional*): The task on which to deploy the model (e.g. `"text-classification"`). custom_image (`Dict`, *optional*): A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). env (`Dict[str, str]`, *optional*): Non-secret environment variables to inject in the container environment secrets (`Dict[str, str]`, *optional*): Secret values to inject in the container environment. domain (`str`, *optional*): The custom domain for the Inference Endpoint deployment, if setup the inference endpoint will be available at this domain (e.g. `"my-new-domain.cool-website.woof"`). path (`str`, *optional*): The custom path to the deployed model, should start with a `/` (e.g. `"/models/google-bert/bert-base-uncased"`). cache_http_responses (`bool`, *optional*): Whether to cache HTTP responses from the Inference Endpoint. 
tags (`List[str]`, *optional*): A list of tags to associate with the Inference Endpoint. namespace (`str`, *optional*): The namespace where the Inference Endpoint will be updated. Defaults to the current user's namespace. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`InferenceEndpoint`]: information about the updated Inference Endpoint. """ namespace = namespace or self._get_namespace(token=token) # Populate only the fields that are not None payload: Dict = defaultdict(lambda: defaultdict(dict)) if accelerator is not None: payload["compute"]["accelerator"] = accelerator if instance_size is not None: payload["compute"]["instanceSize"] = instance_size if instance_type is not None: payload["compute"]["instanceType"] = instance_type if max_replica is not None: payload["compute"]["scaling"]["maxReplica"] = max_replica if min_replica is not None: payload["compute"]["scaling"]["minReplica"] = min_replica if scale_to_zero_timeout is not None: payload["compute"]["scaling"]["scaleToZeroTimeout"] = scale_to_zero_timeout if repository is not None: payload["model"]["repository"] = repository if framework is not None: payload["model"]["framework"] = framework if revision is not None: payload["model"]["revision"] = revision if task is not None: payload["model"]["task"] = task if custom_image is not None: payload["model"]["image"] = {"custom": custom_image} if env is not None: payload["model"]["env"] = env if secrets is not None: payload["model"]["secrets"] = secrets if domain is not None: payload["route"]["domain"] = domain if path is not None: payload["route"]["path"] = path if cache_http_responses is not None: payload["cacheHttpResponses"] = cache_http_responses if tags is not None: payload["tags"] = tags response = get_session().put( f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", headers=self._build_hf_headers(token=token), json=payload, ) hf_raise_for_status(response) return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) def delete_inference_endpoint( self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None ) -> None: """Delete an Inference Endpoint. This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable to pause it with [`pause_inference_endpoint`] or scale it to zero with [`scale_to_zero_inference_endpoint`]. For convenience, you can also delete an Inference Endpoint using [`InferenceEndpoint.delete`]. Args: name (`str`): The name of the Inference Endpoint to delete. namespace (`str`, *optional*): The namespace in which the Inference Endpoint is located. Defaults to the current user. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. 
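Example:
```python
# Illustrative only; the endpoint name is a placeholder. Deletion is not reversible.
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> api.delete_inference_endpoint("my-endpoint-name")
```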
""" namespace = namespace or self._get_namespace(token=token) response = get_session().delete( f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) def pause_inference_endpoint( self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None ) -> InferenceEndpoint: """Pause an Inference Endpoint. A paused Inference Endpoint will not be charged. It can be resumed at any time using [`resume_inference_endpoint`]. This is different than scaling the Inference Endpoint to zero with [`scale_to_zero_inference_endpoint`], which would be automatically restarted when a request is made to it. For convenience, you can also pause an Inference Endpoint using [`pause_inference_endpoint`]. Args: name (`str`): The name of the Inference Endpoint to pause. namespace (`str`, *optional*): The namespace in which the Inference Endpoint is located. Defaults to the current user. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`InferenceEndpoint`]: information about the paused Inference Endpoint. """ namespace = namespace or self._get_namespace(token=token) response = get_session().post( f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/pause", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) def resume_inference_endpoint( self, name: str, *, namespace: Optional[str] = None, running_ok: bool = True, token: Union[bool, str, None] = None, ) -> InferenceEndpoint: """Resume an Inference Endpoint. For convenience, you can also resume an Inference Endpoint using [`InferenceEndpoint.resume`]. Args: name (`str`): The name of the Inference Endpoint to resume. namespace (`str`, *optional*): The namespace in which the Inference Endpoint is located. Defaults to the current user. running_ok (`bool`, *optional*): If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to `True`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`InferenceEndpoint`]: information about the resumed Inference Endpoint. """ namespace = namespace or self._get_namespace(token=token) response = get_session().post( f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/resume", headers=self._build_hf_headers(token=token), ) try: hf_raise_for_status(response) except HfHubHTTPError as error: # If already running (and it's ok), then fetch current status and return if running_ok and error.response.status_code == 400 and "already running" in error.response.text: return self.get_inference_endpoint(name, namespace=namespace, token=token) # Otherwise, raise the error raise return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) def scale_to_zero_inference_endpoint( self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None ) -> InferenceEndpoint: """Scale Inference Endpoint to zero. 
An Inference Endpoint scaled to zero will not be charged. It will be resumed on the next request to it, with a cold start delay. This is different from pausing the Inference Endpoint with [`pause_inference_endpoint`], which would require a manual resume with [`resume_inference_endpoint`].
For convenience, you can also scale an Inference Endpoint to zero using [`InferenceEndpoint.scale_to_zero`].
Args:
name (`str`): The name of the Inference Endpoint to scale to zero.
namespace (`str`, *optional*): The namespace in which the Inference Endpoint is located. Defaults to the current user.
token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`.
Returns:
[`InferenceEndpoint`]: information about the scaled-to-zero Inference Endpoint.
"""
namespace = namespace or self._get_namespace(token=token)
response = get_session().post(
f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/scale-to-zero",
headers=self._build_hf_headers(token=token),
)
hf_raise_for_status(response)
return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
def _get_namespace(self, token: Union[bool, str, None] = None) -> str:
"""Get the default namespace for the current user."""
me = self.whoami(token=token)
if me["type"] == "user":
return me["name"]
else:
raise ValueError(
"Cannot determine default namespace. You must provide a 'namespace' as input or be logged in as a"
" user."
)
########################
# Collection Endpoints #
########################
@validate_hf_hub_args
def list_collections(
self, *, owner: Union[List[str], str, None] = None, item: Union[List[str], str, None] = None, sort: Optional[Literal["lastModified", "trending", "upvotes"]] = None, limit: Optional[int] = None, token: Union[bool, str, None] = None,
) -> Iterable[Collection]:
"""List collections on the Hugging Face Hub, given some filters.
When listing collections, the item list per collection is truncated to 4 items maximum. To retrieve all items from a collection, you must use [`get_collection`].
Args:
owner (`List[str]` or `str`, *optional*): Filter by owner's username.
item (`List[str]` or `str`, *optional*): Filter collections containing a particular item. Example: `"models/teknium/OpenHermes-2.5-Mistral-7B"`, `"datasets/squad"` or `"papers/2311.12983"`.
sort (`Literal["lastModified", "trending", "upvotes"]`, *optional*): Sort collections by last modified, trending or upvotes.
limit (`int`, *optional*): Maximum number of collections to be returned.
token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`.
Returns:
`Iterable[Collection]`: an iterable of [`Collection`] objects.
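Example:
```py
# Illustrative sketch; the owner and sort values are arbitrary examples.
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> for collection in api.list_collections(owner="TheBloke", sort="trending", limit=5):
...     print(collection.slug, collection.title)
```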
""" # Construct the API endpoint path = f"{self.endpoint}/api/collections" headers = self._build_hf_headers(token=token) params: Dict = {} if owner is not None: params.update({"owner": owner}) if item is not None: params.update({"item": item}) if sort is not None: params.update({"sort": sort}) if limit is not None: params.update({"limit": limit}) # Paginate over the results until limit is reached items = paginate(path, headers=headers, params=params) if limit is not None: items = islice(items, limit) # Do not iterate over all pages # Parse as Collection and return for position, collection_data in enumerate(items): yield Collection(position=position, **collection_data) def get_collection(self, collection_slug: str, *, token: Union[bool, str, None] = None) -> Collection: """Gets information about a Collection on the Hub. Args: collection_slug (`str`): Slug of the collection of the Hub. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`Collection`] Example: ```py >>> from huggingface_hub import get_collection >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") >>> collection.title 'Recent models' >>> len(collection.items) 37 >>> collection.items[0] CollectionItem( item_object_id='651446103cd773a050bf64c2', item_id='TheBloke/U-Amethyst-20B-AWQ', item_type='model', position=88, note=None ) ``` """ r = get_session().get( f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token) ) hf_raise_for_status(r) return Collection(**{**r.json(), "endpoint": self.endpoint}) def create_collection( self, title: str, *, namespace: Optional[str] = None, description: Optional[str] = None, private: bool = False, exists_ok: bool = False, token: Union[bool, str, None] = None, ) -> Collection: """Create a new Collection on the Hub. Args: title (`str`): Title of the collection to create. Example: `"Recent models"`. namespace (`str`, *optional*): Namespace of the collection to create (username or org). Will default to the owner name. description (`str`, *optional*): Description of the collection to create. private (`bool`, *optional*): Whether the collection should be private or not. Defaults to `False` (i.e. public collection). exists_ok (`bool`, *optional*): If `True`, do not raise an error if collection already exists. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`Collection`] Example: ```py >>> from huggingface_hub import create_collection >>> collection = create_collection( ... title="ICCV 2023", ... description="Portfolio of models, papers and demos I presented at ICCV 2023", ... 
)
>>> collection.slug
"username/iccv-2023-64f9a55bb3115b4f513ec026"
```
"""
if namespace is None:
namespace = self.whoami(token)["name"]
payload = {
"title": title,
"namespace": namespace,
"private": private,
}
if description is not None:
payload["description"] = description
r = get_session().post(
f"{self.endpoint}/api/collections", headers=self._build_hf_headers(token=token), json=payload
)
try:
hf_raise_for_status(r)
except HTTPError as err:
if exists_ok and err.response.status_code == 409:
# Collection already exists and `exists_ok=True`
slug = r.json()["slug"]
return self.get_collection(slug, token=token)
else:
raise
return Collection(**{**r.json(), "endpoint": self.endpoint})
def update_collection_metadata(
self, collection_slug: str, *, title: Optional[str] = None, description: Optional[str] = None, position: Optional[int] = None, private: Optional[bool] = None, theme: Optional[str] = None, token: Union[bool, str, None] = None,
) -> Collection:
"""Update metadata of a collection on the Hub.
All arguments are optional. Only provided metadata will be updated.
Args:
collection_slug (`str`): Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`.
title (`str`, *optional*): Title of the collection to update.
description (`str`, *optional*): Description of the collection to update.
position (`int`, *optional*): New position of the collection in the list of collections of the user.
private (`bool`, *optional*): Whether the collection should be private or not.
theme (`str`, *optional*): Theme of the collection on the Hub.
token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`.
Returns:
[`Collection`]
Example:
```py
>>> from huggingface_hub import update_collection_metadata
>>> collection = update_collection_metadata(
...     collection_slug="username/iccv-2023-64f9a55bb3115b4f513ec026",
...     title="ICCV Oct. 2023",
...     description="Portfolio of models, datasets, papers and demos I presented at ICCV Oct. 2023",
...     private=False,
...     theme="pink",
... )
>>> collection.slug
"username/iccv-oct-2023-64f9a55bb3115b4f513ec026"
# ^collection slug got updated but not the trailing ID
```
"""
payload = {
"position": position,
"private": private,
"theme": theme,
"title": title,
"description": description,
}
r = get_session().patch(
f"{self.endpoint}/api/collections/{collection_slug}",
headers=self._build_hf_headers(token=token),
# Only send not-none values to the API
json={key: value for key, value in payload.items() if value is not None},
)
hf_raise_for_status(r)
return Collection(**{**r.json()["data"], "endpoint": self.endpoint})
def delete_collection(
self, collection_slug: str, *, missing_ok: bool = False, token: Union[bool, str, None] = None
) -> None:
"""Delete a collection on the Hub.
Args:
collection_slug (`str`): Slug of the collection to delete. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`.
missing_ok (`bool`, *optional*): If `True`, do not raise an error if the collection doesn't exist.
token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`.
Example:
```py
>>> from huggingface_hub import delete_collection
>>> collection = delete_collection("username/useless-collection-64f9a55bb3115b4f513ec026", missing_ok=True)
```
This is a non-revertible action. A deleted collection cannot be restored.
"""
r = get_session().delete(
f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token)
)
try:
hf_raise_for_status(r)
except HTTPError as err:
if missing_ok and err.response.status_code == 404:
# Collection doesn't exist and `missing_ok=True`
return
else:
raise
def add_collection_item(
self, collection_slug: str, item_id: str, item_type: CollectionItemType_T, *, note: Optional[str] = None, exists_ok: bool = False, token: Union[bool, str, None] = None,
) -> Collection:
"""Add an item to a collection on the Hub.
Args:
collection_slug (`str`): Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`.
item_id (`str`): ID of the item to add to the collection. It can be the ID of a repo on the Hub (e.g. `"facebook/bart-large-mnli"`) or a paper id (e.g. `"2307.09288"`).
item_type (`str`): Type of the item to add. Can be one of `"model"`, `"dataset"`, `"space"` or `"paper"`.
note (`str`, *optional*): A note to attach to the item in the collection. The maximum size for a note is 500 characters.
exists_ok (`bool`, *optional*): If `True`, do not raise an error if the item already exists.
token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`.
Returns:
[`Collection`]
Raises:
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token.
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 if the item you try to add to the collection does not exist on the Hub.
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 409 if the item you try to add to the collection is already in the collection (and `exists_ok=False`).
Example:
```py
>>> from huggingface_hub import add_collection_item
>>> collection = add_collection_item(
...     collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab",
...     item_id="pierre-loic/climate-news-articles",
...     item_type="dataset"
... )
>>> collection.items[-1].item_id
"pierre-loic/climate-news-articles"
# ^item got added to the collection in last position

# Add item with a note
>>> add_collection_item(
...     collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab",
...     item_id="datasets/climate_fever",
...     item_type="dataset",
...     note="This dataset adopts the FEVER methodology that consists of 1,535 real-world claims regarding climate-change collected on the internet."
... )
(...)
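# Add a paper to the same collection (illustrative; paper id reused from the Args example above)
>>> add_collection_item(
...     collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab",
...     item_id="2307.09288",
...     item_type="paper",
... )
(...)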
``` """ payload: Dict[str, Any] = {"item": {"id": item_id, "type": item_type}} if note is not None: payload["note"] = note r = get_session().post( f"{self.endpoint}/api/collections/{collection_slug}/items", headers=self._build_hf_headers(token=token), json=payload, ) try: hf_raise_for_status(r) except HTTPError as err: if exists_ok and err.response.status_code == 409: # Item already exists and `exists_ok=True` return self.get_collection(collection_slug, token=token) else: raise return Collection(**{**r.json(), "endpoint": self.endpoint}) def update_collection_item( self, collection_slug: str, item_object_id: str, *, note: Optional[str] = None, position: Optional[int] = None, token: Union[bool, str, None] = None, ) -> None: """Update an item in a collection. Args: collection_slug (`str`): Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. item_object_id (`str`): ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id). It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0].item_object_id`. note (`str`, *optional*): A note to attach to the item in the collection. The maximum size for a note is 500 characters. position (`int`, *optional*): New position of the item in the collection. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Example: ```py >>> from huggingface_hub import get_collection, update_collection_item # Get collection first >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") # Update item based on its ID (add note + update position) >>> update_collection_item( ... collection_slug="TheBloke/recent-models-64f9a55bb3115b4f513ec026", ... item_object_id=collection.items[-1].item_object_id, ... note="Newly updated model!" ... position=0, ... ) ``` """ payload = {"position": position, "note": note} r = get_session().patch( f"{self.endpoint}/api/collections/{collection_slug}/items/{item_object_id}", headers=self._build_hf_headers(token=token), # Only send not-none values to the API json={key: value for key, value in payload.items() if value is not None}, ) hf_raise_for_status(r) def delete_collection_item( self, collection_slug: str, item_object_id: str, *, missing_ok: bool = False, token: Union[bool, str, None] = None, ) -> None: """Delete an item from a collection. Args: collection_slug (`str`): Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. item_object_id (`str`): ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id). It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0].item_object_id`. missing_ok (`bool`, *optional*): If `True`, do not raise an error if item doesn't exists. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. 
Example:
```py
>>> from huggingface_hub import get_collection, delete_collection_item

# Get collection first
>>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026")

# Delete item based on its ID
>>> delete_collection_item(
...     collection_slug="TheBloke/recent-models-64f9a55bb3115b4f513ec026",
...     item_object_id=collection.items[-1].item_object_id,
... )
```
"""
r = get_session().delete(
f"{self.endpoint}/api/collections/{collection_slug}/items/{item_object_id}",
headers=self._build_hf_headers(token=token),
)
try:
hf_raise_for_status(r)
except HTTPError as err:
if missing_ok and err.response.status_code == 404:
# Item already deleted and `missing_ok=True`
return
else:
raise
##########################
# Manage access requests #
##########################
@validate_hf_hub_args
def list_pending_access_requests(
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
) -> List[AccessRequest]:
"""
Get pending access requests for a given gated repo.
A pending request means the user has requested access to the repo but the request has not been processed yet. If the approval mode is automatic, this list should be empty. Pending requests can be accepted or rejected using [`accept_access_request`] and [`reject_access_request`].
For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.
Args:
repo_id (`str`): The id of the repo to get access requests for.
repo_type (`str`, *optional*): The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`. Defaults to `model`.
token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`.
Returns:
`List[AccessRequest]`: A list of [`AccessRequest`] objects. Each item contains a `username`, `email`, `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will be populated with the user's answers.
Raises:
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 400 if the repo is not gated.
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token.
Example:
```py
>>> from huggingface_hub import list_pending_access_requests, accept_access_request

# List pending requests
>>> requests = list_pending_access_requests("meta-llama/Llama-2-7b")
>>> len(requests)
411
>>> requests[0]
[
    AccessRequest(
        username='clem',
        fullname='Clem 🤗',
        email='***',
        timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc),
        status='pending',
        fields=None,
    ),
    ...
]

# Accept Clem's request
>>> accept_access_request("meta-llama/Llama-2-7b", "clem")
```
"""
return self._list_access_requests(repo_id, "pending", repo_type=repo_type, token=token)
@validate_hf_hub_args
def list_accepted_access_requests(
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
) -> List[AccessRequest]:
"""
Get accepted access requests for a given gated repo.
An accepted request means the user has requested access to the repo and the request has been accepted. The user can download any file of the repo.
If the approval mode is automatic, this list should contain all requests by default. Accepted requests can be cancelled or rejected at any time using [`cancel_access_request`] and [`reject_access_request`]. A cancelled request will go back to the pending list while a rejected request will go to the rejected list. In both cases, the user will lose access to the repo.
For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.
Args:
repo_id (`str`): The id of the repo to get access requests for.
repo_type (`str`, *optional*): The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`. Defaults to `model`.
token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`.
Returns:
`List[AccessRequest]`: A list of [`AccessRequest`] objects. Each item contains a `username`, `email`, `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will be populated with the user's answers.
Raises:
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 400 if the repo is not gated.
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token.
Example:
```py
>>> from huggingface_hub import list_accepted_access_requests

>>> requests = list_accepted_access_requests("meta-llama/Llama-2-7b")
>>> len(requests)
411
>>> requests[0]
[
    AccessRequest(
        username='clem',
        fullname='Clem 🤗',
        email='***',
        timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc),
        status='accepted',
        fields=None,
    ),
    ...
]
```
"""
return self._list_access_requests(repo_id, "accepted", repo_type=repo_type, token=token)
@validate_hf_hub_args
def list_rejected_access_requests(
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
) -> List[AccessRequest]:
"""
Get rejected access requests for a given gated repo.
A rejected request means the user has requested access to the repo and the request has been explicitly rejected by a repo owner (either you or another user from your organization). The user cannot download any file of the repo. Rejected requests can be accepted or cancelled at any time using [`accept_access_request`] and [`cancel_access_request`]. A cancelled request will go back to the pending list while an accepted request will go to the accepted list.
For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.
Args:
repo_id (`str`): The id of the repo to get access requests for.
repo_type (`str`, *optional*): The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`. Defaults to `model`.
token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`.
Returns:
`List[AccessRequest]`: A list of [`AccessRequest`] objects. Each item contains a `username`, `email`, `status` and `timestamp` attribute.
If the gated repo has a custom form, the `fields` attribute will be populated with user's answers. Raises: [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 400 if the repo is not gated. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token. Example: ```py >>> from huggingface_hub import list_rejected_access_requests >>> requests = list_rejected_access_requests("meta-llama/Llama-2-7b") >>> len(requests) 411 >>> requests[0] [ AccessRequest( username='clem', fullname='Clem 🤗', email='***', timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc), status='rejected', fields=None, ), ... ] ``` """ return self._list_access_requests(repo_id, "rejected", repo_type=repo_type, token=token) def _list_access_requests( self, repo_id: str, status: Literal["accepted", "rejected", "pending"], repo_type: Optional[str] = None, token: Union[bool, str, None] = None, ) -> List[AccessRequest]: if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: repo_type = constants.REPO_TYPE_MODEL response = get_session().get( f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/{status}", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) return [ AccessRequest( username=request["user"]["user"], fullname=request["user"]["fullname"], email=request["user"].get("email"), status=request["status"], timestamp=parse_datetime(request["timestamp"]), fields=request.get("fields"), # only if custom fields in form ) for request in response.json() ] @validate_hf_hub_args def cancel_access_request( self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None ) -> None: """ Cancel an access request from a user for a given gated repo. A cancelled request will go back to the pending list and the user will lose access to the repo. For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. Args: repo_id (`str`): The id of the repo to cancel access request for. user (`str`): The username of the user which access request should be cancelled. repo_type (`str`, *optional*): The type of the repo to cancel access request for. Must be one of `model`, `dataset` or `space`. Defaults to `model`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Raises: [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 400 if the repo is not gated. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 if the user does not exist on the Hub. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 if the user access request cannot be found. 
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 if the user access request is already in the pending list. """ self._handle_access_request(repo_id, user, "pending", repo_type=repo_type, token=token) @validate_hf_hub_args def accept_access_request( self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None ) -> None: """ Accept an access request from a user for a given gated repo. Once the request is accepted, the user will be able to download any file of the repo and access the community tab. If the approval mode is automatic, you don't have to accept requests manually. An accepted request can be cancelled or rejected at any time using [`cancel_access_request`] and [`reject_access_request`]. For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. Args: repo_id (`str`): The id of the repo to accept access request for. user (`str`): The username of the user which access request should be accepted. repo_type (`str`, *optional*): The type of the repo to accept access request for. Must be one of `model`, `dataset` or `space`. Defaults to `model`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Raises: [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 400 if the repo is not gated. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 if the user does not exist on the Hub. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 if the user access request cannot be found. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 if the user access request is already in the accepted list. """ self._handle_access_request(repo_id, user, "accepted", repo_type=repo_type, token=token) @validate_hf_hub_args def reject_access_request( self, repo_id: str, user: str, *, repo_type: Optional[str] = None, rejection_reason: Optional[str], token: Union[bool, str, None] = None, ) -> None: """ Reject an access request from a user for a given gated repo. A rejected request will go to the rejected list. The user cannot download any file of the repo. Rejected requests can be accepted or cancelled at any time using [`accept_access_request`] and [`cancel_access_request`]. A cancelled request will go back to the pending list while an accepted request will go to the accepted list. For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. Args: repo_id (`str`): The id of the repo to reject access request for. user (`str`): The username of the user which access request should be rejected. repo_type (`str`, *optional*): The type of the repo to reject access request for. Must be one of `model`, `dataset` or `space`. Defaults to `model`. rejection_reason (`str`, *optional*): Optional rejection reason that will be visible to the user (max 200 characters). token (Union[bool, str, None], optional): A valid user access token (string). 
Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`.
Raises:
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 400 if the repo is not gated.
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token.
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 if the user does not exist on the Hub.
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 if the user access request cannot be found.
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 if the user access request is already in the rejected list.
"""
self._handle_access_request(
repo_id, user, "rejected", repo_type=repo_type, rejection_reason=rejection_reason, token=token
)
@validate_hf_hub_args
def _handle_access_request(
self, repo_id: str, user: str, status: Literal["accepted", "rejected", "pending"], repo_type: Optional[str] = None, rejection_reason: Optional[str] = None, token: Union[bool, str, None] = None,
) -> None:
if repo_type not in constants.REPO_TYPES:
raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
if repo_type is None:
repo_type = constants.REPO_TYPE_MODEL
payload = {"user": user, "status": status}
if rejection_reason is not None:
if status != "rejected":
raise ValueError("`rejection_reason` can only be passed when rejecting an access request.")
payload["rejectionReason"] = rejection_reason
response = get_session().post(
f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/handle",
headers=self._build_hf_headers(token=token),
json=payload,
)
hf_raise_for_status(response)
@validate_hf_hub_args
def grant_access(
self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
) -> None:
"""
Grant access to a user for a given gated repo.
Granting access does not require the user to send an access request themselves. The user is automatically added to the accepted list, meaning they can download the repo's files. You can revoke the granted access at any time using [`cancel_access_request`] or [`reject_access_request`].
For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.
Args:
repo_id (`str`): The id of the repo to grant access to.
user (`str`): The username of the user to grant access to.
repo_type (`str`, *optional*): The type of the repo to grant access to. Must be one of `model`, `dataset` or `space`. Defaults to `model`.
token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`.
Raises:
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 400 if the repo is not gated.
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 400 if the user already has access to the repo.
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 403 if you only have read-only access to the repo.
This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 if the user does not exist on the Hub. """ if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: repo_type = constants.REPO_TYPE_MODEL response = get_session().post( f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/grant", headers=self._build_hf_headers(token=token), json={"user": user}, ) hf_raise_for_status(response) return response.json() ################### # Manage webhooks # ################### @validate_hf_hub_args def get_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> WebhookInfo: """Get a webhook by its id. Args: webhook_id (`str`): The unique identifier of the webhook to get. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`WebhookInfo`]: Info about the webhook. Example: ```python >>> from huggingface_hub import get_webhook >>> webhook = get_webhook("654bbbc16f2ec14d77f109cc") >>> print(webhook) WebhookInfo( id="654bbbc16f2ec14d77f109cc", watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", secret="my-secret", domains=["repo", "discussion"], disabled=False, ) ``` """ response = get_session().get( f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) webhook_data = response.json()["webhook"] watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] webhook = WebhookInfo( id=webhook_data["id"], url=webhook_data["url"], watched=watched_items, domains=webhook_data["domains"], secret=webhook_data.get("secret"), disabled=webhook_data["disabled"], ) return webhook @validate_hf_hub_args def list_webhooks(self, *, token: Union[bool, str, None] = None) -> List[WebhookInfo]: """List all configured webhooks. Args: token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `List[WebhookInfo]`: List of webhook info objects. 
Example: ```python >>> from huggingface_hub import list_webhooks >>> webhooks = list_webhooks() >>> len(webhooks) 2 >>> webhooks[0] WebhookInfo( id="654bbbc16f2ec14d77f109cc", watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", secret="my-secret", domains=["repo", "discussion"], disabled=False, ) ``` """ response = get_session().get( f"{constants.ENDPOINT}/api/settings/webhooks", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) webhooks_data = response.json() return [ WebhookInfo( id=webhook["id"], url=webhook["url"], watched=[WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook["watched"]], domains=webhook["domains"], secret=webhook.get("secret"), disabled=webhook["disabled"], ) for webhook in webhooks_data ] @validate_hf_hub_args def create_webhook( self, *, url: str, watched: List[Union[Dict, WebhookWatchedItem]], domains: Optional[List[constants.WEBHOOK_DOMAIN_T]] = None, secret: Optional[str] = None, token: Union[bool, str, None] = None, ) -> WebhookInfo: """Create a new webhook. Args: url (`str`): URL to send the payload to. watched (`List[WebhookWatchedItem]`): List of [`WebhookWatchedItem`] to be watched by the webhook. It can be users, orgs, models, datasets or spaces. Watched items can also be provided as plain dictionaries. domains (`List[Literal["repo", "discussion"]]`, optional): List of domains to watch. It can be "repo", "discussion" or both. secret (`str`, optional): A secret to sign the payload with. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`WebhookInfo`]: Info about the newly created webhook. Example: ```python >>> from huggingface_hub import create_webhook >>> payload = create_webhook( ... watched=[{"type": "user", "name": "julien-c"}, {"type": "org", "name": "HuggingFaceH4"}], ... url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", ... domains=["repo", "discussion"], ... secret="my-secret", ... 
) >>> print(payload) WebhookInfo( id="654bbbc16f2ec14d77f109cc", url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], domains=["repo", "discussion"], secret="my-secret", disabled=False, ) ``` """ watched_dicts = [asdict(item) if isinstance(item, WebhookWatchedItem) else item for item in watched] response = get_session().post( f"{constants.ENDPOINT}/api/settings/webhooks", json={"watched": watched_dicts, "url": url, "domains": domains, "secret": secret}, headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) webhook_data = response.json()["webhook"] watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] webhook = WebhookInfo( id=webhook_data["id"], url=webhook_data["url"], watched=watched_items, domains=webhook_data["domains"], secret=webhook_data.get("secret"), disabled=webhook_data["disabled"], ) return webhook @validate_hf_hub_args def update_webhook( self, webhook_id: str, *, url: Optional[str] = None, watched: Optional[List[Union[Dict, WebhookWatchedItem]]] = None, domains: Optional[List[constants.WEBHOOK_DOMAIN_T]] = None, secret: Optional[str] = None, token: Union[bool, str, None] = None, ) -> WebhookInfo: """Update an existing webhook. Args: webhook_id (`str`): The unique identifier of the webhook to be updated. url (`str`, optional): The URL to which the payload will be sent. watched (`List[WebhookWatchedItem]`, optional): List of items to watch. It can be users, orgs, models, datasets, or spaces. Refer to [`WebhookWatchedItem`] for more details. Watched items can also be provided as plain dictionaries. domains (`List[Literal["repo", "discussion"]]`, optional): The domains to watch. This can include "repo", "discussion", or both. secret (`str`, optional): A secret to sign the payload with, providing an additional layer of security. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`WebhookInfo`]: Info about the updated webhook. Example: ```python >>> from huggingface_hub import update_webhook >>> updated_payload = update_webhook( ... webhook_id="654bbbc16f2ec14d77f109cc", ... url="https://new.webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", ... watched=[{"type": "user", "name": "julien-c"}, {"type": "org", "name": "HuggingFaceH4"}], ... domains=["repo"], ... secret="my-secret", ... 
) >>> print(updated_payload) WebhookInfo( id="654bbbc16f2ec14d77f109cc", url="https://new.webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], domains=["repo"], secret="my-secret", disabled=False, ``` """ if watched is None: watched = [] watched_dicts = [asdict(item) if isinstance(item, WebhookWatchedItem) else item for item in watched] response = get_session().post( f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", json={"watched": watched_dicts, "url": url, "domains": domains, "secret": secret}, headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) webhook_data = response.json()["webhook"] watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] webhook = WebhookInfo( id=webhook_data["id"], url=webhook_data["url"], watched=watched_items, domains=webhook_data["domains"], secret=webhook_data.get("secret"), disabled=webhook_data["disabled"], ) return webhook @validate_hf_hub_args def enable_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> WebhookInfo: """Enable a webhook (makes it "active"). Args: webhook_id (`str`): The unique identifier of the webhook to enable. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`WebhookInfo`]: Info about the enabled webhook. Example: ```python >>> from huggingface_hub import enable_webhook >>> enabled_webhook = enable_webhook("654bbbc16f2ec14d77f109cc") >>> enabled_webhook WebhookInfo( id="654bbbc16f2ec14d77f109cc", url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], domains=["repo", "discussion"], secret="my-secret", disabled=False, ) ``` """ response = get_session().post( f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}/enable", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) webhook_data = response.json()["webhook"] watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] webhook = WebhookInfo( id=webhook_data["id"], url=webhook_data["url"], watched=watched_items, domains=webhook_data["domains"], secret=webhook_data.get("secret"), disabled=webhook_data["disabled"], ) return webhook @validate_hf_hub_args def disable_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> WebhookInfo: """Disable a webhook (makes it "disabled"). Args: webhook_id (`str`): The unique identifier of the webhook to disable. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: [`WebhookInfo`]: Info about the disabled webhook. 
Example: ```python >>> from huggingface_hub import disable_webhook >>> disabled_webhook = disable_webhook("654bbbc16f2ec14d77f109cc") >>> disabled_webhook WebhookInfo( id="654bbbc16f2ec14d77f109cc", url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548", watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")], domains=["repo", "discussion"], secret="my-secret", disabled=True, ) ``` """ response = get_session().post( f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}/disable", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) webhook_data = response.json()["webhook"] watched_items = [WebhookWatchedItem(type=item["type"], name=item["name"]) for item in webhook_data["watched"]] webhook = WebhookInfo( id=webhook_data["id"], url=webhook_data["url"], watched=watched_items, domains=webhook_data["domains"], secret=webhook_data.get("secret"), disabled=webhook_data["disabled"], ) return webhook @validate_hf_hub_args def delete_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> None: """Delete a webhook. Args: webhook_id (`str`): The unique identifier of the webhook to delete. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `None` Example: ```python >>> from huggingface_hub import delete_webhook >>> delete_webhook("654bbbc16f2ec14d77f109cc") ``` """ response = get_session().delete( f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) ############# # Internals # ############# def _build_hf_headers( self, token: Union[bool, str, None] = None, library_name: Optional[str] = None, library_version: Optional[str] = None, user_agent: Union[Dict, str, None] = None, ) -> Dict[str, str]: """ Alias for [`build_hf_headers`] that uses the token from [`HfApi`] client when `token` is not provided. """ if token is None: # Cannot do `token = token or self.token` as token can be `False`. token = self.token return build_hf_headers( token=token, library_name=library_name or self.library_name, library_version=library_version or self.library_version, user_agent=user_agent or self.user_agent, headers=self.headers, ) def _prepare_folder_deletions( self, repo_id: str, repo_type: Optional[str], revision: Optional[str], path_in_repo: str, delete_patterns: Optional[Union[List[str], str]], token: Union[bool, str, None] = None, ) -> List[CommitOperationDelete]: """Generate the list of Delete operations for a commit to delete files from a repo. List remote files and match them against the `delete_patterns` constraints. Returns a list of [`CommitOperationDelete`] with the matching items. Note: `.gitattributes` file is essential to make a repo work properly on the Hub. This file will always be kept even if it matches the `delete_patterns` constraints. 
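Example (an illustrative sketch; the repo id, path and patterns below are placeholders, not a real repo):

```python
>>> api = HfApi()
>>> # Build Delete operations for every remote "*.bin" file under "checkpoints/",
>>> # e.g. before re-uploading a cleaned-up version of that folder.
>>> deletions = api._prepare_folder_deletions(
...     repo_id="my-username/my-model",
...     repo_type="model",
...     revision=None,
...     path_in_repo="checkpoints",
...     delete_patterns="*.bin",
... )
>>> deletions
[CommitOperationDelete(path_in_repo='checkpoints/pytorch_model.bin', is_folder=False)]
```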
""" if delete_patterns is None: # If no delete patterns, no need to list and filter remote files return [] # List remote files filenames = self.list_repo_files(repo_id=repo_id, revision=revision, repo_type=repo_type, token=token) # Compute relative path in repo if path_in_repo and path_in_repo not in (".", "./"): path_in_repo = path_in_repo.strip("/") + "/" # harmonize relpath_to_abspath = { file[len(path_in_repo) :]: file for file in filenames if file.startswith(path_in_repo) } else: relpath_to_abspath = {file: file for file in filenames} # Apply filter on relative paths and return return [ CommitOperationDelete(path_in_repo=relpath_to_abspath[relpath], is_folder=False) for relpath in filter_repo_objects(relpath_to_abspath.keys(), allow_patterns=delete_patterns) if relpath_to_abspath[relpath] != ".gitattributes" ] def _prepare_upload_folder_additions( self, folder_path: Union[str, Path], path_in_repo: str, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, token: Union[bool, str, None] = None, ) -> List[CommitOperationAdd]: """Generate the list of Add operations for a commit to upload a folder. Files not matching the `allow_patterns` (allowlist) and `ignore_patterns` (denylist) constraints are discarded. """ folder_path = Path(folder_path).expanduser().resolve() if not folder_path.is_dir(): raise ValueError(f"Provided path: '{folder_path}' is not a directory") # List files from folder relpath_to_abspath = { path.relative_to(folder_path).as_posix(): path for path in sorted(folder_path.glob("**/*")) # sorted to be deterministic if path.is_file() } # Filter files # Patterns are applied on the path relative to `folder_path`. `path_in_repo` is prefixed after the filtering. filtered_repo_objects = list( filter_repo_objects( relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns ) ) prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else "" # If updating a README.md file, make sure the metadata format is valid # It's better to fail early than to fail after all the files have been hashed. if "README.md" in filtered_repo_objects: self._validate_yaml( content=relpath_to_abspath["README.md"].read_text(encoding="utf8"), repo_type=repo_type, token=token, ) if len(filtered_repo_objects) > 30: log = logger.warning if len(filtered_repo_objects) > 200 else logger.info log( "It seems you are trying to upload a large folder at once. This might take some time and then fail if " "the folder is too large. For such cases, it is recommended to upload in smaller batches or to use " "`HfApi().upload_large_folder(...)`/`huggingface-cli upload-large-folder` instead. For more details, " "check out https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#upload-a-large-folder." ) logger.info(f"Start hashing {len(filtered_repo_objects)} files.") operations = [ CommitOperationAdd( path_or_fileobj=relpath_to_abspath[relpath], # absolute path on disk path_in_repo=prefix + relpath, # "absolute" path in repo ) for relpath in filtered_repo_objects ] logger.info(f"Finished hashing {len(filtered_repo_objects)} files.") return operations def _validate_yaml(self, content: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None): """ Validate YAML from `README.md`, used before file hashing and upload. Args: content (`str`): Content of `README.md` to validate. repo_type (`str`, *optional*): The type of the repo to grant access to. 
Must be one of `model`, `dataset` or `space`. Defaults to `model`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Raises: - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if YAML is invalid """ repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL headers = self._build_hf_headers(token=token) response = get_session().post( f"{self.endpoint}/api/validate-yaml", json={"content": content, "repoType": repo_type}, headers=headers, ) # Handle warnings (example: empty metadata) response_content = response.json() message = "\n".join([f"- {warning.get('message')}" for warning in response_content.get("warnings", [])]) if message: warnings.warn(f"Warnings while validating metadata in README.md:\n{message}") # Raise on errors try: hf_raise_for_status(response) except BadRequestError as e: errors = response_content.get("errors", []) message = "\n".join([f"- {error.get('message')}" for error in errors]) raise ValueError(f"Invalid metadata in README.md.\n{message}") from e def get_user_overview(self, username: str, token: Union[bool, str, None] = None) -> User: """ Get an overview of a user on the Hub. Args: username (`str`): Username of the user to get an overview of. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `User`: A [`User`] object with the user's overview. Raises: [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 If the user does not exist on the Hub. """ r = get_session().get( f"{constants.ENDPOINT}/api/users/{username}/overview", headers=self._build_hf_headers(token=token) ) hf_raise_for_status(r) return User(**r.json()) def list_organization_members(self, organization: str, token: Union[bool, str, None] = None) -> Iterable[User]: """ List of members of an organization on the Hub. Args: organization (`str`): Name of the organization to get the members of. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `Iterable[User]`: A list of [`User`] objects with the members of the organization. Raises: [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 If the organization does not exist on the Hub. """ for member in paginate( path=f"{constants.ENDPOINT}/api/organizations/{organization}/members", params={}, headers=self._build_hf_headers(token=token), ): yield User(**member) def list_user_followers(self, username: str, token: Union[bool, str, None] = None) -> Iterable[User]: """ Get the list of followers of a user on the Hub. Args: username (`str`): Username of the user to get the followers of. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). 
To disable authentication, pass `False`. Returns: `Iterable[User]`: A list of [`User`] objects with the followers of the user. Raises: [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 If the user does not exist on the Hub. """ for follower in paginate( path=f"{constants.ENDPOINT}/api/users/{username}/followers", params={}, headers=self._build_hf_headers(token=token), ): yield User(**follower) def list_user_following(self, username: str, token: Union[bool, str, None] = None) -> Iterable[User]: """ Get the list of users followed by a user on the Hub. Args: username (`str`): Username of the user to get the users followed by. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `Iterable[User]`: A list of [`User`] objects with the users followed by the user. Raises: [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 If the user does not exist on the Hub. """ for followed_user in paginate( path=f"{constants.ENDPOINT}/api/users/{username}/following", params={}, headers=self._build_hf_headers(token=token), ): yield User(**followed_user) def list_papers( self, *, query: Optional[str] = None, token: Union[bool, str, None] = None, ) -> Iterable[PaperInfo]: """ List daily papers on the Hugging Face Hub given a search query. Args: query (`str`, *optional*): A search query string to find papers. If provided, returns papers that match the query. token (Union[bool, str, None], *optional*): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. Returns: `Iterable[PaperInfo]`: an iterable of [`huggingface_hub.hf_api.PaperInfo`] objects. Example: ```python >>> from huggingface_hub import HfApi >>> api = HfApi() # List all papers with "attention" in their title >>> api.list_papers(query="attention") ``` """ path = f"{self.endpoint}/api/papers/search" params = {} if query: params["q"] = query r = get_session().get( path, params=params, headers=self._build_hf_headers(token=token), ) hf_raise_for_status(r) for paper in r.json(): yield PaperInfo(**paper) def paper_info(self, id: str) -> PaperInfo: """ Get information for a paper on the Hub. Args: id (`str`, **optional**): ArXiv id of the paper. Returns: `PaperInfo`: A `PaperInfo` object. Raises: [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 If the paper does not exist on the Hub. """ path = f"{self.endpoint}/api/papers/{id}" r = get_session().get(path) hf_raise_for_status(r) return PaperInfo(**r.json()) def auth_check( self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None ) -> None: """ Check if the provided user token has access to a specific repository on the Hugging Face Hub. This method verifies whether the user, authenticated via the provided token, has access to the specified repository. If the repository is not found or if the user lacks the required permissions to access it, the method raises an appropriate exception. Args: repo_id (`str`): The repository to check for access. Format should be `"user/repo_name"`. Example: `"user/my-cool-model"`. 
repo_type (`str`, *optional*): The type of the repository. Should be one of `"model"`, `"dataset"`, or `"space"`. If not specified, the default is `"model"`. token `(Union[bool, str, None]`, *optional*): A valid user access token. If not provided, the locally saved token will be used, which is the recommended authentication method. Set to `False` to disable authentication. Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. Raises: [`~utils.RepositoryNotFoundError`]: Raised if the repository does not exist, is private, or the user does not have access. This can occur if the `repo_id` or `repo_type` is incorrect or if the repository is private but the user is not authenticated. [`~utils.GatedRepoError`]: Raised if the repository exists but is gated and the user is not authorized to access it. Example: Check if the user has access to a repository: ```python >>> from huggingface_hub import auth_check >>> from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError try: auth_check("user/my-cool-model") except GatedRepoError: # Handle gated repository error print("You do not have permission to access this gated repository.") except RepositoryNotFoundError: # Handle repository not found error print("The repository was not found or you do not have access.") ``` In this example: - If the user has access, the method completes successfully. - If the repository is gated or does not exist, appropriate exceptions are raised, allowing the user to handle them accordingly. """ headers = self._build_hf_headers(token=token) if repo_type is None: repo_type = constants.REPO_TYPE_MODEL if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/auth-check" r = get_session().get(path, headers=headers) hf_raise_for_status(r) def _parse_revision_from_pr_url(pr_url: str) -> str: """Safely parse revision number from a PR url. 
Example: ```py >>> _parse_revision_from_pr_url("https://huggingface.co/bigscience/bloom/discussions/2") "refs/pr/2" ``` """ re_match = re.match(_REGEX_DISCUSSION_URL, pr_url) if re_match is None: raise RuntimeError(f"Unexpected response from the hub, expected a Pull Request URL but got: '{pr_url}'") return f"refs/pr/{re_match[1]}" api = HfApi() whoami = api.whoami auth_check = api.auth_check get_token_permission = api.get_token_permission list_models = api.list_models model_info = api.model_info list_datasets = api.list_datasets dataset_info = api.dataset_info list_spaces = api.list_spaces space_info = api.space_info list_papers = api.list_papers paper_info = api.paper_info repo_exists = api.repo_exists revision_exists = api.revision_exists file_exists = api.file_exists repo_info = api.repo_info list_repo_files = api.list_repo_files list_repo_refs = api.list_repo_refs list_repo_commits = api.list_repo_commits list_repo_tree = api.list_repo_tree get_paths_info = api.get_paths_info get_model_tags = api.get_model_tags get_dataset_tags = api.get_dataset_tags create_commit = api.create_commit create_repo = api.create_repo delete_repo = api.delete_repo update_repo_visibility = api.update_repo_visibility update_repo_settings = api.update_repo_settings move_repo = api.move_repo upload_file = api.upload_file upload_folder = api.upload_folder delete_file = api.delete_file delete_folder = api.delete_folder delete_files = api.delete_files upload_large_folder = api.upload_large_folder preupload_lfs_files = api.preupload_lfs_files create_branch = api.create_branch delete_branch = api.delete_branch create_tag = api.create_tag delete_tag = api.delete_tag get_full_repo_name = api.get_full_repo_name # Danger-zone API super_squash_history = api.super_squash_history list_lfs_files = api.list_lfs_files permanently_delete_lfs_files = api.permanently_delete_lfs_files # Safetensors helpers get_safetensors_metadata = api.get_safetensors_metadata parse_safetensors_file_metadata = api.parse_safetensors_file_metadata # Background jobs run_as_future = api.run_as_future # Activity API list_liked_repos = api.list_liked_repos list_repo_likers = api.list_repo_likers unlike = api.unlike # Community API get_discussion_details = api.get_discussion_details get_repo_discussions = api.get_repo_discussions create_discussion = api.create_discussion create_pull_request = api.create_pull_request change_discussion_status = api.change_discussion_status comment_discussion = api.comment_discussion edit_discussion_comment = api.edit_discussion_comment rename_discussion = api.rename_discussion merge_pull_request = api.merge_pull_request # Space API add_space_secret = api.add_space_secret delete_space_secret = api.delete_space_secret get_space_variables = api.get_space_variables add_space_variable = api.add_space_variable delete_space_variable = api.delete_space_variable get_space_runtime = api.get_space_runtime request_space_hardware = api.request_space_hardware set_space_sleep_time = api.set_space_sleep_time pause_space = api.pause_space restart_space = api.restart_space duplicate_space = api.duplicate_space request_space_storage = api.request_space_storage delete_space_storage = api.delete_space_storage # Inference Endpoint API list_inference_endpoints = api.list_inference_endpoints create_inference_endpoint = api.create_inference_endpoint get_inference_endpoint = api.get_inference_endpoint update_inference_endpoint = api.update_inference_endpoint delete_inference_endpoint = api.delete_inference_endpoint pause_inference_endpoint = 
api.pause_inference_endpoint resume_inference_endpoint = api.resume_inference_endpoint scale_to_zero_inference_endpoint = api.scale_to_zero_inference_endpoint create_inference_endpoint_from_catalog = api.create_inference_endpoint_from_catalog list_inference_catalog = api.list_inference_catalog # Collections API get_collection = api.get_collection list_collections = api.list_collections create_collection = api.create_collection update_collection_metadata = api.update_collection_metadata delete_collection = api.delete_collection add_collection_item = api.add_collection_item update_collection_item = api.update_collection_item delete_collection_item = api.delete_collection_item delete_collection_item = api.delete_collection_item # Access requests API list_pending_access_requests = api.list_pending_access_requests list_accepted_access_requests = api.list_accepted_access_requests list_rejected_access_requests = api.list_rejected_access_requests cancel_access_request = api.cancel_access_request accept_access_request = api.accept_access_request reject_access_request = api.reject_access_request grant_access = api.grant_access # Webhooks API create_webhook = api.create_webhook disable_webhook = api.disable_webhook delete_webhook = api.delete_webhook enable_webhook = api.enable_webhook get_webhook = api.get_webhook list_webhooks = api.list_webhooks update_webhook = api.update_webhook # User API get_user_overview = api.get_user_overview list_organization_members = api.list_organization_members list_user_followers = api.list_user_followers list_user_following = api.list_user_following huggingface_hub-0.31.1/src/huggingface_hub/hf_file_system.py000066400000000000000000001346141500667546600241770ustar00rootroot00000000000000import os import re import tempfile from collections import deque from dataclasses import dataclass, field from datetime import datetime from itertools import chain from pathlib import Path from typing import Any, Dict, Iterator, List, NoReturn, Optional, Tuple, Union from urllib.parse import quote, unquote import fsspec from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback from fsspec.utils import isfilelike from requests import Response from . import constants from ._commit_api import CommitOperationCopy, CommitOperationDelete from .errors import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from .file_download import hf_hub_url, http_get from .hf_api import HfApi, LastCommitInfo, RepoFile from .utils import HFValidationError, hf_raise_for_status, http_backoff # Regex used to match special revisions with "/" in them (see #1710) SPECIAL_REFS_REVISION_REGEX = re.compile( r""" (^refs\/convert\/\w+) # `refs/convert/parquet` revisions | (^refs\/pr\/\d+) # PR revisions """, re.VERBOSE, ) @dataclass class HfFileSystemResolvedPath: """Data structure containing information about a resolved Hugging Face file system path.""" repo_type: str repo_id: str revision: str path_in_repo: str # The part placed after '@' in the initial path. It can be a quoted or unquoted refs revision. # Used to reconstruct the unresolved path to return to the user. 
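# Illustrative example (hypothetical path): resolving "datasets/my-user/my-dataset@refs%2Fconvert%2Fparquet/data.json"
# keeps the raw "refs%2Fconvert%2Fparquet" segment here, while `revision` holds the unquoted
# "refs/convert/parquet", so that `unresolve()` can rebuild exactly the path the caller passed in.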
_raw_revision: Optional[str] = field(default=None, repr=False) def unresolve(self) -> str: repo_path = constants.REPO_TYPES_URL_PREFIXES.get(self.repo_type, "") + self.repo_id if self._raw_revision: return f"{repo_path}@{self._raw_revision}/{self.path_in_repo}".rstrip("/") elif self.revision != constants.DEFAULT_REVISION: return f"{repo_path}@{safe_revision(self.revision)}/{self.path_in_repo}".rstrip("/") else: return f"{repo_path}/{self.path_in_repo}".rstrip("/") class HfFileSystem(fsspec.AbstractFileSystem): """ Access a remote Hugging Face Hub repository as if it were a local file system. [`HfFileSystem`] provides fsspec compatibility, which is useful for libraries that require it (e.g., reading Hugging Face datasets directly with `pandas`). However, it introduces additional overhead due to this compatibility layer. For better performance and reliability, it's recommended to use `HfApi` methods when possible. Args: token (`str` or `bool`, *optional*): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. endpoint (`str`, *optional*): Endpoint of the Hub. Defaults to `https://huggingface.co`. Usage: ```python >>> from huggingface_hub import HfFileSystem >>> fs = HfFileSystem() >>> # List files >>> fs.glob("my-username/my-model/*.bin") ['my-username/my-model/pytorch_model.bin'] >>> fs.ls("datasets/my-username/my-dataset", detail=False) ['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json'] >>> # Read/write files >>> with fs.open("my-username/my-model/pytorch_model.bin") as f: ... data = f.read() >>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f: ...
f.write(data) ``` """ root_marker = "" protocol = "hf" def __init__( self, *args, endpoint: Optional[str] = None, token: Union[bool, str, None] = None, **storage_options, ): super().__init__(*args, **storage_options) self.endpoint = endpoint or constants.ENDPOINT self.token = token self._api = HfApi(endpoint=endpoint, token=token) # Maps (repo_type, repo_id, revision) to a 2-tuple with: # * the 1st element indicating whether the repository and the revision exist # * the 2nd element being the exception raised if the repository or revision doesn't exist self._repo_and_revision_exists_cache: Dict[ Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]] ] = {} def _repo_and_revision_exist( self, repo_type: str, repo_id: str, revision: Optional[str] ) -> Tuple[bool, Optional[Exception]]: if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache: try: self._api.repo_info( repo_id, revision=revision, repo_type=repo_type, timeout=constants.HF_HUB_ETAG_TIMEOUT ) except (RepositoryNotFoundError, HFValidationError) as e: self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e except RevisionNotFoundError as e: self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None else: self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = True, None self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None return self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] def resolve_path(self, path: str, revision: Optional[str] = None) -> HfFileSystemResolvedPath: """ Resolve a Hugging Face file system path into its components. Args: path (`str`): Path to resolve. revision (`str`, *optional*): The revision of the repo to resolve. Defaults to the revision specified in the path. Returns: [`HfFileSystemResolvedPath`]: Resolved path information containing `repo_type`, `repo_id`, `revision` and `path_in_repo`. Raises: `ValueError`: If path contains conflicting revision information. `NotImplementedError`: If trying to list repositories. """ def _align_revision_in_path_with_revision( revision_in_path: Optional[str], revision: Optional[str] ) -> Optional[str]: if revision is not None: if revision_in_path is not None and revision_in_path != revision: raise ValueError( f'Revision specified in path ("{revision_in_path}") and in `revision` argument ("{revision}")' " are not the same."
) else: revision = revision_in_path return revision path = self._strip_protocol(path) if not path: # can't list repositories at root raise NotImplementedError("Access to repositories lists is not implemented.") elif path.split("/")[0] + "/" in constants.REPO_TYPES_URL_PREFIXES.values(): if "/" not in path: # can't list repositories at the repository type level raise NotImplementedError("Access to repositories lists is not implemented.") repo_type, path = path.split("/", 1) repo_type = constants.REPO_TYPES_MAPPING[repo_type] else: repo_type = constants.REPO_TYPE_MODEL if path.count("/") > 0: if "@" in path: repo_id, revision_in_path = path.split("@", 1) if "/" in revision_in_path: match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path) if match is not None and revision in (None, match.group()): # Handle `refs/convert/parquet` and PR revisions separately path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/") revision_in_path = match.group() else: revision_in_path, path_in_repo = revision_in_path.split("/", 1) else: path_in_repo = "" revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision) repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) if not repo_and_revision_exist: _raise_file_not_found(path, err) else: revision_in_path = None repo_id_with_namespace = "/".join(path.split("/")[:2]) path_in_repo_with_namespace = "/".join(path.split("/")[2:]) repo_id_without_namespace = path.split("/")[0] path_in_repo_without_namespace = "/".join(path.split("/")[1:]) repo_id = repo_id_with_namespace path_in_repo = path_in_repo_with_namespace repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) if not repo_and_revision_exist: if isinstance(err, (RepositoryNotFoundError, HFValidationError)): repo_id = repo_id_without_namespace path_in_repo = path_in_repo_without_namespace repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision) if not repo_and_revision_exist: _raise_file_not_found(path, err) else: _raise_file_not_found(path, err) else: repo_id = path path_in_repo = "" if "@" in path: repo_id, revision_in_path = path.split("@", 1) revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision) else: revision_in_path = None repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision) if not repo_and_revision_exist: raise NotImplementedError("Access to repositories lists is not implemented.") revision = revision if revision is not None else constants.DEFAULT_REVISION return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo, _raw_revision=revision_in_path) def invalidate_cache(self, path: Optional[str] = None) -> None: """ Clear the cache for a given path. For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.invalidate_cache). Args: path (`str`, *optional*): Path to clear from cache. If not provided, clear the entire cache. 
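Example (an illustrative sketch; the repo below is a placeholder):

```python
>>> fs = HfFileSystem()
>>> fs.ls("my-username/my-model", detail=False)  # populates the directory cache
>>> fs.invalidate_cache("my-username/my-model")  # drop the cached entries for this repo
>>> fs.ls("my-username/my-model", detail=False)  # fetches a fresh listing from the Hub
```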
""" if not path: self.dircache.clear() self._repo_and_revision_exists_cache.clear() else: resolved_path = self.resolve_path(path) path = resolved_path.unresolve() while path: self.dircache.pop(path, None) path = self._parent(path) # Only clear repo cache if path is to repo root if not resolved_path.path_in_repo: self._repo_and_revision_exists_cache.pop((resolved_path.repo_type, resolved_path.repo_id, None), None) self._repo_and_revision_exists_cache.pop( (resolved_path.repo_type, resolved_path.repo_id, resolved_path.revision), None ) def _open( self, path: str, mode: str = "rb", revision: Optional[str] = None, block_size: Optional[int] = None, **kwargs, ) -> "HfFileSystemFile": if "a" in mode: raise NotImplementedError("Appending to remote files is not yet supported.") if block_size == 0: return HfFileSystemStreamFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs) else: return HfFileSystemFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs) def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None: resolved_path = self.resolve_path(path, revision=revision) self._api.delete_file( path_in_repo=resolved_path.path_in_repo, repo_id=resolved_path.repo_id, token=self.token, repo_type=resolved_path.repo_type, revision=resolved_path.revision, commit_message=kwargs.get("commit_message"), commit_description=kwargs.get("commit_description"), ) self.invalidate_cache(path=resolved_path.unresolve()) def rm( self, path: str, recursive: bool = False, maxdepth: Optional[int] = None, revision: Optional[str] = None, **kwargs, ) -> None: """ Delete files from a repository. For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.rm). Note: When possible, use `HfApi.delete_file()` for better performance. Args: path (`str`): Path to delete. recursive (`bool`, *optional*): If True, delete directory and all its contents. Defaults to False. maxdepth (`int`, *optional*): Maximum number of subdirectories to visit when deleting recursively. revision (`str`, *optional*): The git revision to delete from. """ resolved_path = self.resolve_path(path, revision=revision) paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=revision) paths_in_repo = [self.resolve_path(path).path_in_repo for path in paths if not self.isdir(path)] operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo] commit_message = f"Delete {path} " commit_message += "recursively " if recursive else "" commit_message += f"up to depth {maxdepth} " if maxdepth is not None else "" # TODO: use `commit_description` to list all the deleted paths? self._api.create_commit( repo_id=resolved_path.repo_id, repo_type=resolved_path.repo_type, token=self.token, operations=operations, revision=resolved_path.revision, commit_message=kwargs.get("commit_message", commit_message), commit_description=kwargs.get("commit_description"), ) self.invalidate_cache(path=resolved_path.unresolve()) def ls( self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs ) -> List[Union[str, Dict[str, Any]]]: """ List the contents of a directory. For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.ls). Note: When possible, use `HfApi.list_repo_tree()` for better performance. Args: path (`str`): Path to the directory. 
detail (`bool`, *optional*): If True, returns a list of dictionaries containing file information. If False, returns a list of file paths. Defaults to True. refresh (`bool`, *optional*): If True, bypass the cache and fetch the latest data. Defaults to False. revision (`str`, *optional*): The git revision to list from. Returns: `List[Union[str, Dict[str, Any]]]`: List of file paths (if detail=False) or list of file information dictionaries (if detail=True). """ resolved_path = self.resolve_path(path, revision=revision) path = resolved_path.unresolve() kwargs = {"expand_info": detail, **kwargs} try: out = self._ls_tree(path, refresh=refresh, revision=revision, **kwargs) except EntryNotFoundError: # Path could be a file if not resolved_path.path_in_repo: _raise_file_not_found(path, None) out = self._ls_tree(self._parent(path), refresh=refresh, revision=revision, **kwargs) out = [o for o in out if o["name"] == path] if len(out) == 0: _raise_file_not_found(path, None) return out if detail else [o["name"] for o in out] def _ls_tree( self, path: str, recursive: bool = False, refresh: bool = False, revision: Optional[str] = None, expand_info: bool = True, ): resolved_path = self.resolve_path(path, revision=revision) path = resolved_path.unresolve() root_path = HfFileSystemResolvedPath( resolved_path.repo_type, resolved_path.repo_id, resolved_path.revision, path_in_repo="", _raw_revision=resolved_path._raw_revision, ).unresolve() out = [] if path in self.dircache and not refresh: cached_path_infos = self.dircache[path] out.extend(cached_path_infos) dirs_not_in_dircache = [] if recursive: # Use BFS to traverse the cache and build the "recursive "output # (The Hub uses a so-called "tree first" strategy for the tree endpoint but we sort the output to follow the spec so the result is (eventually) the same) dirs_to_visit = deque( [path_info for path_info in cached_path_infos if path_info["type"] == "directory"] ) while dirs_to_visit: dir_info = dirs_to_visit.popleft() if dir_info["name"] not in self.dircache: dirs_not_in_dircache.append(dir_info["name"]) else: cached_path_infos = self.dircache[dir_info["name"]] out.extend(cached_path_infos) dirs_to_visit.extend( [path_info for path_info in cached_path_infos if path_info["type"] == "directory"] ) dirs_not_expanded = [] if expand_info: # Check if there are directories with non-expanded entries dirs_not_expanded = [self._parent(o["name"]) for o in out if o["last_commit"] is None] if (recursive and dirs_not_in_dircache) or (expand_info and dirs_not_expanded): # If the dircache is incomplete, find the common path of the missing and non-expanded entries # and extend the output with the result of `_ls_tree(common_path, recursive=True)` common_prefix = os.path.commonprefix(dirs_not_in_dircache + dirs_not_expanded) # Get the parent directory if the common prefix itself is not a directory common_path = ( common_prefix.rstrip("/") if common_prefix.endswith("/") or common_prefix == root_path or common_prefix in chain(dirs_not_in_dircache, dirs_not_expanded) else self._parent(common_prefix) ) out = [o for o in out if not o["name"].startswith(common_path + "/")] for cached_path in self.dircache: if cached_path.startswith(common_path + "/"): self.dircache.pop(cached_path, None) self.dircache.pop(common_path, None) out.extend( self._ls_tree( common_path, recursive=recursive, refresh=True, revision=revision, expand_info=expand_info, ) ) else: tree = self._api.list_repo_tree( resolved_path.repo_id, resolved_path.path_in_repo, recursive=recursive, expand=expand_info, 
revision=resolved_path.revision, repo_type=resolved_path.repo_type, ) for path_info in tree: if isinstance(path_info, RepoFile): cache_path_info = { "name": root_path + "/" + path_info.path, "size": path_info.size, "type": "file", "blob_id": path_info.blob_id, "lfs": path_info.lfs, "last_commit": path_info.last_commit, "security": path_info.security, } else: cache_path_info = { "name": root_path + "/" + path_info.path, "size": 0, "type": "directory", "tree_id": path_info.tree_id, "last_commit": path_info.last_commit, } parent_path = self._parent(cache_path_info["name"]) self.dircache.setdefault(parent_path, []).append(cache_path_info) out.append(cache_path_info) return out def walk(self, path: str, *args, **kwargs) -> Iterator[Tuple[str, List[str], List[str]]]: """ Return all files below the given path. For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.walk). Args: path (`str`): Root path to list files from. Returns: `Iterator[Tuple[str, List[str], List[str]]]`: An iterator of (path, list of directory names, list of file names) tuples. """ # Set expand_info=False by default to get a x10 speed boost kwargs = {"expand_info": kwargs.get("detail", False), **kwargs} path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve() yield from super().walk(path, *args, **kwargs) def glob(self, path: str, **kwargs) -> List[str]: """ Find files by glob-matching. For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.glob). Args: path (`str`): Path pattern to match. Returns: `List[str]`: List of paths matching the pattern. """ # Set expand_info=False by default to get a x10 speed boost kwargs = {"expand_info": kwargs.get("detail", False), **kwargs} path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve() return super().glob(path, **kwargs) def find( self, path: str, maxdepth: Optional[int] = None, withdirs: bool = False, detail: bool = False, refresh: bool = False, revision: Optional[str] = None, **kwargs, ) -> Union[List[str], Dict[str, Dict[str, Any]]]: """ List all files below path. For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.find). Args: path (`str`): Root path to list files from. maxdepth (`int`, *optional*): Maximum depth to descend into subdirectories. withdirs (`bool`, *optional*): Include directory paths in the output. Defaults to False. detail (`bool`, *optional*): If True, returns a dict mapping paths to file information. Defaults to False. refresh (`bool`, *optional*): If True, bypass the cache and fetch the latest data. Defaults to False. revision (`str`, *optional*): The git revision to list from. Returns: `Union[List[str], Dict[str, Dict[str, Any]]]`: List of paths or dict of file information. 
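Example (an illustrative sketch; the repo and file names are placeholders):

```python
>>> fs = HfFileSystem()
>>> fs.find("my-username/my-model")
['my-username/my-model/.gitattributes', 'my-username/my-model/config.json', 'my-username/my-model/pytorch_model.bin']
```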
""" if maxdepth: return super().find( path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, refresh=refresh, revision=revision, **kwargs ) resolved_path = self.resolve_path(path, revision=revision) path = resolved_path.unresolve() kwargs = {"expand_info": detail, **kwargs} try: out = self._ls_tree(path, recursive=True, refresh=refresh, revision=resolved_path.revision, **kwargs) except EntryNotFoundError: # Path could be a file if self.info(path, revision=revision, **kwargs)["type"] == "file": out = {path: {}} else: out = {} else: if not withdirs: out = [o for o in out if o["type"] != "directory"] else: # If `withdirs=True`, include the directory itself to be consistent with the spec path_info = self.info(path, revision=resolved_path.revision, **kwargs) out = [path_info] + out if path_info["type"] == "directory" else out out = {o["name"]: o for o in out} names = sorted(out) if not detail: return names else: return {name: out[name] for name in names} def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None: """ Copy a file within or between repositories. Note: When possible, use `HfApi.upload_file()` for better performance. Args: path1 (`str`): Source path to copy from. path2 (`str`): Destination path to copy to. revision (`str`, *optional*): The git revision to copy from. """ resolved_path1 = self.resolve_path(path1, revision=revision) resolved_path2 = self.resolve_path(path2, revision=revision) same_repo = ( resolved_path1.repo_type == resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id ) if same_repo: commit_message = f"Copy {path1} to {path2}" self._api.create_commit( repo_id=resolved_path1.repo_id, repo_type=resolved_path1.repo_type, revision=resolved_path2.revision, commit_message=kwargs.get("commit_message", commit_message), commit_description=kwargs.get("commit_description", ""), operations=[ CommitOperationCopy( src_path_in_repo=resolved_path1.path_in_repo, path_in_repo=resolved_path2.path_in_repo, src_revision=resolved_path1.revision, ) ], ) else: with self.open(path1, "rb", revision=resolved_path1.revision) as f: content = f.read() commit_message = f"Copy {path1} to {path2}" self._api.upload_file( path_or_fileobj=content, path_in_repo=resolved_path2.path_in_repo, repo_id=resolved_path2.repo_id, token=self.token, repo_type=resolved_path2.repo_type, revision=resolved_path2.revision, commit_message=kwargs.get("commit_message", commit_message), commit_description=kwargs.get("commit_description"), ) self.invalidate_cache(path=resolved_path1.unresolve()) self.invalidate_cache(path=resolved_path2.unresolve()) def modified(self, path: str, **kwargs) -> datetime: """ Get the last modified time of a file. For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.modified). Args: path (`str`): Path to the file. Returns: `datetime`: Last commit date of the file. """ info = self.info(path, **kwargs) return info["last_commit"]["date"] def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> Dict[str, Any]: """ Get information about a file or directory. For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.info). Note: When possible, use `HfApi.get_paths_info()` or `HfApi.repo_info()` for better performance. Args: path (`str`): Path to get info for. refresh (`bool`, *optional*): If True, bypass the cache and fetch the latest data. 
Defaults to False. revision (`str`, *optional*): The git revision to get info from. Returns: `Dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.). """ resolved_path = self.resolve_path(path, revision=revision) path = resolved_path.unresolve() expand_info = kwargs.get( "expand_info", True ) # don't expose it as a parameter in the public API to follow the spec if not resolved_path.path_in_repo: # Path is the root directory out = { "name": path, "size": 0, "type": "directory", } if expand_info: last_commit = self._api.list_repo_commits( resolved_path.repo_id, repo_type=resolved_path.repo_type, revision=resolved_path.revision )[-1] out = { **out, "tree_id": None, # TODO: tree_id of the root directory? "last_commit": LastCommitInfo( oid=last_commit.commit_id, title=last_commit.title, date=last_commit.created_at ), } else: out = None parent_path = self._parent(path) if not expand_info and parent_path not in self.dircache: # Fill the cache with cheap call self.ls(parent_path, expand_info=False) if parent_path in self.dircache: # Check if the path is in the cache out1 = [o for o in self.dircache[parent_path] if o["name"] == path] if not out1: _raise_file_not_found(path, None) out = out1[0] if refresh or out is None or (expand_info and out and out["last_commit"] is None): paths_info = self._api.get_paths_info( resolved_path.repo_id, resolved_path.path_in_repo, expand=expand_info, revision=resolved_path.revision, repo_type=resolved_path.repo_type, ) if not paths_info: _raise_file_not_found(path, None) path_info = paths_info[0] root_path = HfFileSystemResolvedPath( resolved_path.repo_type, resolved_path.repo_id, resolved_path.revision, path_in_repo="", _raw_revision=resolved_path._raw_revision, ).unresolve() if isinstance(path_info, RepoFile): out = { "name": root_path + "/" + path_info.path, "size": path_info.size, "type": "file", "blob_id": path_info.blob_id, "lfs": path_info.lfs, "last_commit": path_info.last_commit, "security": path_info.security, } else: out = { "name": root_path + "/" + path_info.path, "size": 0, "type": "directory", "tree_id": path_info.tree_id, "last_commit": path_info.last_commit, } if not expand_info: out = {k: out[k] for k in ["name", "size", "type"]} assert out is not None return out def exists(self, path, **kwargs): """ Check if a file exists. For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.exists). Note: When possible, use `HfApi.file_exists()` for better performance. Args: path (`str`): Path to check. Returns: `bool`: True if file exists, False otherwise. """ try: if kwargs.get("refresh", False): self.invalidate_cache(path) self.info(path, **{**kwargs, "expand_info": False}) return True except: # noqa: E722 return False def isdir(self, path): """ Check if a path is a directory. For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.isdir). Args: path (`str`): Path to check. Returns: `bool`: True if path is a directory, False otherwise. """ try: return self.info(path, expand_info=False)["type"] == "directory" except OSError: return False def isfile(self, path): """ Check if a path is a file. For more details, refer to [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.isfile). Args: path (`str`): Path to check. Returns: `bool`: True if path is a file, False otherwise. 
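Example (an illustrative sketch; the paths are placeholders):

```python
>>> fs = HfFileSystem()
>>> fs.isfile("my-username/my-model/pytorch_model.bin")
True
>>> fs.isfile("my-username/my-model")  # a directory, not a file
False
```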
""" try: return self.info(path, expand_info=False)["type"] == "file" except: # noqa: E722 return False def url(self, path: str) -> str: """ Get the HTTP URL of the given path. Args: path (`str`): Path to get URL for. Returns: `str`: HTTP URL to access the file or directory on the Hub. """ resolved_path = self.resolve_path(path) url = hf_hub_url( resolved_path.repo_id, resolved_path.path_in_repo, repo_type=resolved_path.repo_type, revision=resolved_path.revision, endpoint=self.endpoint, ) if self.isdir(path): url = url.replace("/resolve/", "/tree/", 1) return url def get_file(self, rpath, lpath, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs) -> None: """ Copy single remote file to local. Note: When possible, use `HfApi.hf_hub_download()` for better performance. Args: rpath (`str`): Remote path to download from. lpath (`str`): Local path to download to. callback (`Callback`, *optional*): Optional callback to track download progress. Defaults to no callback. outfile (`IO`, *optional*): Optional file-like object to write to. If provided, `lpath` is ignored. """ revision = kwargs.get("revision") unhandled_kwargs = set(kwargs.keys()) - {"revision"} if not isinstance(callback, (NoOpCallback, TqdmCallback)) or len(unhandled_kwargs) > 0: # for now, let's not handle custom callbacks # and let's not handle custom kwargs return super().get_file(rpath, lpath, callback=callback, outfile=outfile, **kwargs) # Taken from https://github.com/fsspec/filesystem_spec/blob/47b445ae4c284a82dd15e0287b1ffc410e8fc470/fsspec/spec.py#L883 if isfilelike(lpath): outfile = lpath elif self.isdir(rpath): os.makedirs(lpath, exist_ok=True) return None if isinstance(lpath, (str, Path)): # otherwise, let's assume it's a file-like object os.makedirs(os.path.dirname(lpath), exist_ok=True) # Open file if not already open close_file = False if outfile is None: outfile = open(lpath, "wb") close_file = True initial_pos = outfile.tell() # Custom implementation of `get_file` to use `http_get`. resolve_remote_path = self.resolve_path(rpath, revision=revision) expected_size = self.info(rpath, revision=revision)["size"] callback.set_size(expected_size) try: http_get( url=hf_hub_url( repo_id=resolve_remote_path.repo_id, revision=resolve_remote_path.revision, filename=resolve_remote_path.path_in_repo, repo_type=resolve_remote_path.repo_type, endpoint=self.endpoint, ), temp_file=outfile, displayed_filename=rpath, expected_size=expected_size, resume_size=0, headers=self._api._build_hf_headers(), _tqdm_bar=callback.tqdm if isinstance(callback, TqdmCallback) else None, ) outfile.seek(initial_pos) finally: # Close file only if we opened it ourselves if close_file: outfile.close() @property def transaction(self): """A context within which files are committed together upon exit Requires the file class to implement `.commit()` and `.discard()` for the normal and exception cases. 
""" # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L231 # See https://github.com/huggingface/huggingface_hub/issues/1733 raise NotImplementedError("Transactional commits are not supported.") def start_transaction(self): """Begin write transaction for deferring files, non-context version""" # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L241 # See https://github.com/huggingface/huggingface_hub/issues/1733 raise NotImplementedError("Transactional commits are not supported.") class HfFileSystemFile(fsspec.spec.AbstractBufferedFile): def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs): try: self.resolved_path = fs.resolve_path(path, revision=revision) except FileNotFoundError as e: if "w" in kwargs.get("mode", ""): raise FileNotFoundError( f"{e}.\nMake sure the repository and revision exist before writing data." ) from e raise # avoid an unnecessary .info() call with expensive expand_info=True to instantiate .details if kwargs.get("mode", "rb") == "rb": self.details = fs.info(self.resolved_path.unresolve(), expand_info=False) super().__init__(fs, self.resolved_path.unresolve(), **kwargs) self.fs: HfFileSystem def __del__(self): if not hasattr(self, "resolved_path"): # Means that the constructor failed. Nothing to do. return return super().__del__() def _fetch_range(self, start: int, end: int) -> bytes: headers = { "range": f"bytes={start}-{end - 1}", **self.fs._api._build_hf_headers(), } url = hf_hub_url( repo_id=self.resolved_path.repo_id, revision=self.resolved_path.revision, filename=self.resolved_path.path_in_repo, repo_type=self.resolved_path.repo_type, endpoint=self.fs.endpoint, ) r = http_backoff( "GET", url, headers=headers, retry_on_status_codes=(500, 502, 503, 504), timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, ) hf_raise_for_status(r) return r.content def _initiate_upload(self) -> None: self.temp_file = tempfile.NamedTemporaryFile(prefix="hffs-", delete=False) def _upload_chunk(self, final: bool = False) -> None: self.buffer.seek(0) block = self.buffer.read() self.temp_file.write(block) if final: self.temp_file.close() self.fs._api.upload_file( path_or_fileobj=self.temp_file.name, path_in_repo=self.resolved_path.path_in_repo, repo_id=self.resolved_path.repo_id, token=self.fs.token, repo_type=self.resolved_path.repo_type, revision=self.resolved_path.revision, commit_message=self.kwargs.get("commit_message"), commit_description=self.kwargs.get("commit_description"), ) os.remove(self.temp_file.name) self.fs.invalidate_cache( path=self.resolved_path.unresolve(), ) def read(self, length=-1): """Read remote file. If `length` is not provided or is -1, the entire file is downloaded and read. On POSIX systems and if `hf_transfer` is not enabled, the file is loaded in memory directly. Otherwise, the file is downloaded to a temporary file and read from there. 
""" if self.mode == "rb" and (length is None or length == -1) and self.loc == 0: with self.fs.open(self.path, "rb", block_size=0) as f: # block_size=0 enables fast streaming return f.read() return super().read(length) def url(self) -> str: return self.fs.url(self.path) class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile): def __init__( self, fs: HfFileSystem, path: str, mode: str = "rb", revision: Optional[str] = None, block_size: int = 0, cache_type: str = "none", **kwargs, ): if block_size != 0: raise ValueError(f"HfFileSystemStreamFile only supports block_size=0 but got {block_size}") if cache_type != "none": raise ValueError(f"HfFileSystemStreamFile only supports cache_type='none' but got {cache_type}") if "w" in mode: raise ValueError(f"HfFileSystemStreamFile only supports reading but got mode='{mode}'") try: self.resolved_path = fs.resolve_path(path, revision=revision) except FileNotFoundError as e: if "w" in kwargs.get("mode", ""): raise FileNotFoundError( f"{e}.\nMake sure the repository and revision exist before writing data." ) from e # avoid an unnecessary .info() call to instantiate .details self.details = {"name": self.resolved_path.unresolve(), "size": None} super().__init__( fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs ) self.response: Optional[Response] = None self.fs: HfFileSystem def seek(self, loc: int, whence: int = 0): if loc == 0 and whence == 1: return if loc == self.loc and whence == 0: return raise ValueError("Cannot seek streaming HF file") def read(self, length: int = -1): read_args = (length,) if length >= 0 else () if self.response is None or self.response.raw.isclosed(): url = hf_hub_url( repo_id=self.resolved_path.repo_id, revision=self.resolved_path.revision, filename=self.resolved_path.path_in_repo, repo_type=self.resolved_path.repo_type, endpoint=self.fs.endpoint, ) self.response = http_backoff( "GET", url, headers=self.fs._api._build_hf_headers(), retry_on_status_codes=(500, 502, 503, 504), stream=True, timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, ) hf_raise_for_status(self.response) try: out = self.response.raw.read(*read_args) except Exception: self.response.close() # Retry by recreating the connection url = hf_hub_url( repo_id=self.resolved_path.repo_id, revision=self.resolved_path.revision, filename=self.resolved_path.path_in_repo, repo_type=self.resolved_path.repo_type, endpoint=self.fs.endpoint, ) self.response = http_backoff( "GET", url, headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()}, retry_on_status_codes=(500, 502, 503, 504), stream=True, timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, ) hf_raise_for_status(self.response) try: out = self.response.raw.read(*read_args) except Exception: self.response.close() raise self.loc += len(out) return out def url(self) -> str: return self.fs.url(self.path) def __del__(self): if not hasattr(self, "resolved_path"): # Means that the constructor failed. Nothing to do. 
return return super().__del__() def __reduce__(self): return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name) def safe_revision(revision: str) -> str: return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision) def safe_quote(s: str) -> str: return quote(s, safe="") def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn: msg = path if isinstance(err, RepositoryNotFoundError): msg = f"{path} (repository not found)" elif isinstance(err, RevisionNotFoundError): msg = f"{path} (revision not found)" elif isinstance(err, HFValidationError): msg = f"{path} (invalid repository id)" raise FileNotFoundError(msg) from err def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str): return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type) huggingface_hub-0.31.1/src/huggingface_hub/hub_mixin.py000066400000000000000000001122561500667546600231570ustar00rootroot00000000000000import inspect import json import os from dataclasses import Field, asdict, dataclass, is_dataclass from pathlib import Path from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union import packaging.version from . import constants from .errors import EntryNotFoundError, HfHubHTTPError from .file_download import hf_hub_download from .hf_api import HfApi from .repocard import ModelCard, ModelCardData from .utils import ( SoftTemporaryDirectory, is_jsonable, is_safetensors_available, is_simple_optional_type, is_torch_available, logging, unwrap_simple_optional_type, validate_hf_hub_args, ) if is_torch_available(): import torch # type: ignore if is_safetensors_available(): import safetensors from safetensors.torch import load_model as load_model_as_safetensor from safetensors.torch import save_model as save_model_as_safetensor logger = logging.get_logger(__name__) # Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349 class DataclassInstance(Protocol): __dataclass_fields__: ClassVar[Dict[str, Field]] # Generic variable that is either ModelHubMixin or a subclass thereof T = TypeVar("T", bound="ModelHubMixin") # Generic variable to represent an args type ARGS_T = TypeVar("ARGS_T") ENCODER_T = Callable[[ARGS_T], Any] DECODER_T = Callable[[Any], ARGS_T] CODER_T = Tuple[ENCODER_T, DECODER_T] DEFAULT_MODEL_CARD = """ --- # For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 # Doc / guide: https://huggingface.co/docs/hub/model-cards {{ card_data }} --- This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration: - Code: {{ repo_url | default("[More Information Needed]", true) }} - Paper: {{ paper_url | default("[More Information Needed]", true) }} - Docs: {{ docs_url | default("[More Information Needed]", true) }} """ @dataclass class MixinInfo: model_card_template: str model_card_data: ModelCardData docs_url: Optional[str] = None paper_url: Optional[str] = None repo_url: Optional[str] = None class ModelHubMixin: """ A generic mixin to integrate ANY machine learning framework with the Hub. To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. 
[`PyTorchModelHubMixin`] is a good example of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions. When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to `__init__` but to the class definition itself. This is useful to define metadata about the library integrating [`ModelHubMixin`]. For more details on how to integrate the mixin with your library, checkout the [integration guide](../guides/integrations). Args: repo_url (`str`, *optional*): URL of the library repository. Used to generate model card. paper_url (`str`, *optional*): URL of the library paper. Used to generate model card. docs_url (`str`, *optional*): URL of the library documentation. Used to generate model card. model_card_template (`str`, *optional*): Template of the model card. Used to generate model card. Defaults to a generic template. language (`str` or `List[str]`, *optional*): Language supported by the library. Used to generate model card. library_name (`str`, *optional*): Name of the library integrating ModelHubMixin. Used to generate model card. license (`str`, *optional*): License of the library integrating ModelHubMixin. Used to generate model card. E.g: "apache-2.0" license_name (`str`, *optional*): Name of the library integrating ModelHubMixin. Used to generate model card. Only used if `license` is set to `other`. E.g: "coqui-public-model-license". license_link (`str`, *optional*): URL to the license of the library integrating ModelHubMixin. Used to generate model card. Only used if `license` is set to `other` and `license_name` is set. E.g: "https://coqui.ai/cpml". pipeline_tag (`str`, *optional*): Tag of the pipeline. Used to generate model card. E.g. "text-classification". tags (`List[str]`, *optional*): Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"] coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*): Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc. Example: ```python >>> from huggingface_hub import ModelHubMixin # Inherit from ModelHubMixin >>> class MyCustomModel( ... ModelHubMixin, ... library_name="my-library", ... tags=["computer-vision"], ... repo_url="https://github.com/huggingface/my-cool-library", ... paper_url="https://arxiv.org/abs/2304.12244", ... docs_url="https://huggingface.co/docs/my-cool-library", ... # ^ optional metadata to generate model card ... ): ... def __init__(self, size: int = 512, device: str = "cpu"): ... # define how to initialize your model ... super().__init__() ... ... ... ... def _save_pretrained(self, save_directory: Path) -> None: ... # define how to serialize your model ... ... ... ... @classmethod ... def from_pretrained( ... cls: Type[T], ... pretrained_model_name_or_path: Union[str, Path], ... *, ... force_download: bool = False, ... resume_download: Optional[bool] = None, ... proxies: Optional[Dict] = None, ... token: Optional[Union[str, bool]] = None, ... cache_dir: Optional[Union[str, Path]] = None, ... local_files_only: bool = False, ... revision: Optional[str] = None, ... **model_kwargs, ... ) -> T: ... # define how to deserialize your model ... ... 
>>> model = MyCustomModel(size=256, device="gpu") # Save model weights to local directory >>> model.save_pretrained("my-awesome-model") # Push model weights to the Hub >>> model.push_to_hub("my-awesome-model") # Download and initialize weights from the Hub >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model") >>> reloaded_model.size 256 # Model card has been correctly populated >>> from huggingface_hub import ModelCard >>> card = ModelCard.load("username/my-awesome-model") >>> card.data.tags ["x-custom-tag", "pytorch_model_hub_mixin", "model_hub_mixin"] >>> card.data.library_name "my-library" ``` """ _hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None # ^ optional config attribute automatically set in `from_pretrained` _hub_mixin_info: MixinInfo # ^ information about the library integrating ModelHubMixin (used to generate model card) _hub_mixin_inject_config: bool # whether `_from_pretrained` expects `config` or not _hub_mixin_init_parameters: Dict[str, inspect.Parameter] # __init__ parameters _hub_mixin_jsonable_default_values: Dict[str, Any] # default values for __init__ parameters _hub_mixin_jsonable_custom_types: Tuple[Type, ...] # custom types that can be encoded/decoded _hub_mixin_coders: Dict[Type, CODER_T] # encoders/decoders for custom types # ^ internal values to handle config def __init_subclass__( cls, *, # Generic info for model card repo_url: Optional[str] = None, paper_url: Optional[str] = None, docs_url: Optional[str] = None, # Model card template model_card_template: str = DEFAULT_MODEL_CARD, # Model card metadata language: Optional[List[str]] = None, library_name: Optional[str] = None, license: Optional[str] = None, license_name: Optional[str] = None, license_link: Optional[str] = None, pipeline_tag: Optional[str] = None, tags: Optional[List[str]] = None, # How to encode/decode arguments with custom type into a JSON config? coders: Optional[ Dict[Type, CODER_T] # Key is a type. # Value is a tuple (encoder, decoder). 
# Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))} ] = None, ) -> None: """Inspect __init__ signature only once when subclassing + handle modelcard.""" super().__init_subclass__() # Will be reused when creating modelcard tags = tags or [] tags.append("model_hub_mixin") # Initialize MixinInfo if not existent info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData()) # If parent class has a MixinInfo, inherit from it as a copy if hasattr(cls, "_hub_mixin_info"): # Inherit model card template from parent class if not explicitly set if model_card_template == DEFAULT_MODEL_CARD: info.model_card_template = cls._hub_mixin_info.model_card_template # Inherit from parent model card data info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict()) # Inherit other info info.docs_url = cls._hub_mixin_info.docs_url info.paper_url = cls._hub_mixin_info.paper_url info.repo_url = cls._hub_mixin_info.repo_url cls._hub_mixin_info = info # Update MixinInfo with metadata if model_card_template is not None and model_card_template != DEFAULT_MODEL_CARD: info.model_card_template = model_card_template if repo_url is not None: info.repo_url = repo_url if paper_url is not None: info.paper_url = paper_url if docs_url is not None: info.docs_url = docs_url if language is not None: info.model_card_data.language = language if library_name is not None: info.model_card_data.library_name = library_name if license is not None: info.model_card_data.license = license if license_name is not None: info.model_card_data.license_name = license_name if license_link is not None: info.model_card_data.license_link = license_link if pipeline_tag is not None: info.model_card_data.pipeline_tag = pipeline_tag if tags is not None: if info.model_card_data.tags is not None: info.model_card_data.tags.extend(tags) else: info.model_card_data.tags = tags info.model_card_data.tags = sorted(set(info.model_card_data.tags)) # Handle encoders/decoders for args cls._hub_mixin_coders = coders or {} cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys()) # Inspect __init__ signature to handle config cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters) cls._hub_mixin_jsonable_default_values = { param.name: cls._encode_arg(param.default) for param in cls._hub_mixin_init_parameters.values() if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default) } cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters def __new__(cls: Type[T], *args, **kwargs) -> T: """Create a new instance of the class and handle config. 3 cases: - If `self._hub_mixin_config` is already set, do nothing. - If `config` is passed as a dataclass, set it as `self._hub_mixin_config`. - Otherwise, build `self._hub_mixin_config` from default values and passed values. 
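        Example of the third case (illustrative, reusing the `MyCustomModel(size: int = 512, device: str = "cpu")`
        subclass defined in the class docstring):

        ```python
        >>> model = MyCustomModel(size=256)
        >>> model._hub_mixin_config  # passed values merged over jsonable default values
        {'size': 256, 'device': 'cpu'}
        ```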
""" instance = super().__new__(cls) # If `config` is already set, return early if instance._hub_mixin_config is not None: return instance # Infer passed values passed_values = { **{ key: value for key, value in zip( # [1:] to skip `self` parameter list(cls._hub_mixin_init_parameters)[1:], args, ) }, **kwargs, } # If config passed as dataclass => set it and return early if is_dataclass(passed_values.get("config")): instance._hub_mixin_config = passed_values["config"] return instance # Otherwise, build config from default + passed values init_config = { # default values **cls._hub_mixin_jsonable_default_values, # passed values **{ key: cls._encode_arg(value) # Encode custom types as jsonable value for key, value in passed_values.items() if instance._is_jsonable(value) # Only if jsonable or we have a custom encoder }, } passed_config = init_config.pop("config", {}) # Populate `init_config` with provided config if isinstance(passed_config, dict): init_config.update(passed_config) # Set `config` attribute and return if init_config != {}: instance._hub_mixin_config = init_config return instance @classmethod def _is_jsonable(cls, value: Any) -> bool: """Check if a value is JSON serializable.""" if is_dataclass(value): return True if isinstance(value, cls._hub_mixin_jsonable_custom_types): return True return is_jsonable(value) @classmethod def _encode_arg(cls, arg: Any) -> Any: """Encode an argument into a JSON serializable format.""" if is_dataclass(arg): return asdict(arg) for type_, (encoder, _) in cls._hub_mixin_coders.items(): if isinstance(arg, type_): if arg is None: return None return encoder(arg) return arg @classmethod def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]: """Decode a JSON serializable value into an argument.""" if is_simple_optional_type(expected_type): if value is None: return None expected_type = unwrap_simple_optional_type(expected_type) # Dataclass => handle it if is_dataclass(expected_type): return _load_dataclass(expected_type, value) # type: ignore[return-value] # Otherwise => check custom decoders for type_, (_, decoder) in cls._hub_mixin_coders.items(): if inspect.isclass(expected_type) and issubclass(expected_type, type_): return decoder(value) # Otherwise => don't decode return value def save_pretrained( self, save_directory: Union[str, Path], *, config: Optional[Union[dict, DataclassInstance]] = None, repo_id: Optional[str] = None, push_to_hub: bool = False, model_card_kwargs: Optional[Dict[str, Any]] = None, **push_to_hub_kwargs, ) -> Optional[str]: """ Save weights in local directory. Args: save_directory (`str` or `Path`): Path to directory in which the model weights and configuration will be saved. config (`dict` or `DataclassInstance`, *optional*): Model configuration specified as a key/value dictionary or a dataclass instance. push_to_hub (`bool`, *optional*, defaults to `False`): Whether or not to push your model to the Huggingface Hub after saving it. repo_id (`str`, *optional*): ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if not provided. model_card_kwargs (`Dict[str, Any]`, *optional*): Additional arguments passed to the model card template to customize the model card. push_to_hub_kwargs: Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method. Returns: `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise. 
""" save_directory = Path(save_directory) save_directory.mkdir(parents=True, exist_ok=True) # Remove config.json if already exists. After `_save_pretrained` we don't want to overwrite config.json # as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite # an existing config.json if it was not saved by `_save_pretrained`. config_path = save_directory / constants.CONFIG_NAME config_path.unlink(missing_ok=True) # save model weights/files (framework-specific) self._save_pretrained(save_directory) # save config (if provided and if not serialized yet in `_save_pretrained`) if config is None: config = self._hub_mixin_config if config is not None: if is_dataclass(config): config = asdict(config) # type: ignore[arg-type] if not config_path.exists(): config_str = json.dumps(config, sort_keys=True, indent=2) config_path.write_text(config_str) # save model card model_card_path = save_directory / "README.md" model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {} if not model_card_path.exists(): # do not overwrite if already exists self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md") # push to the Hub if required if push_to_hub: kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input if config is not None: # kwarg for `push_to_hub` kwargs["config"] = config if repo_id is None: repo_id = save_directory.name # Defaults to `save_directory` name return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs) return None def _save_pretrained(self, save_directory: Path) -> None: """ Overwrite this method in subclass to define how to save your model. Check out our [integration guide](../guides/integrations) for instructions. Args: save_directory (`str` or `Path`): Path to directory in which the model weights and configuration will be saved. """ raise NotImplementedError @classmethod @validate_hf_hub_args def from_pretrained( cls: Type[T], pretrained_model_name_or_path: Union[str, Path], *, force_download: bool = False, resume_download: Optional[bool] = None, proxies: Optional[Dict] = None, token: Optional[Union[str, bool]] = None, cache_dir: Optional[Union[str, Path]] = None, local_files_only: bool = False, revision: Optional[str] = None, **model_kwargs, ) -> T: """ Download a model from the Huggingface Hub and instantiate it. Args: pretrained_model_name_or_path (`str`, `Path`): - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`. - Or a path to a `directory` containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`. revision (`str`, *optional*): Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the latest commit on `main` branch. force_download (`bool`, *optional*, defaults to `False`): Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding the existing cache. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request. token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. By default, it will use the token cached when running `huggingface-cli login`. cache_dir (`str`, `Path`, *optional*): Path to the folder where cached files are stored. 
local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the local cached file if it exists. model_kwargs (`Dict`, *optional*): Additional kwargs to pass to the model during initialization. """ model_id = str(pretrained_model_name_or_path) config_file: Optional[str] = None if os.path.isdir(model_id): if constants.CONFIG_NAME in os.listdir(model_id): config_file = os.path.join(model_id, constants.CONFIG_NAME) else: logger.warning(f"{constants.CONFIG_NAME} not found in {Path(model_id).resolve()}") else: try: config_file = hf_hub_download( repo_id=model_id, filename=constants.CONFIG_NAME, revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) except HfHubHTTPError as e: logger.info(f"{constants.CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}") # Read config config = None if config_file is not None: with open(config_file, "r", encoding="utf-8") as f: config = json.load(f) # Decode custom types in config for key, value in config.items(): if key in cls._hub_mixin_init_parameters: expected_type = cls._hub_mixin_init_parameters[key].annotation if expected_type is not inspect.Parameter.empty: config[key] = cls._decode_arg(expected_type, value) # Populate model_kwargs from config for param in cls._hub_mixin_init_parameters.values(): if param.name not in model_kwargs and param.name in config: model_kwargs[param.name] = config[param.name] # Check if `config` argument was passed at init if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs: # Decode `config` argument if it was passed config_annotation = cls._hub_mixin_init_parameters["config"].annotation config = cls._decode_arg(config_annotation, config) # Forward config to model initialization model_kwargs["config"] = config # Inject config if `**kwargs` are expected if is_dataclass(cls): for key in cls.__dataclass_fields__: if key not in model_kwargs and key in config: model_kwargs[key] = config[key] elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()): for key, value in config.items(): if key not in model_kwargs: model_kwargs[key] = value # Finally, also inject if `_from_pretrained` expects it if cls._hub_mixin_inject_config and "config" not in model_kwargs: model_kwargs["config"] = config instance = cls._from_pretrained( model_id=str(model_id), revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, token=token, **model_kwargs, ) # Implicitly set the config as instance attribute if not already set by the class # This way `config` will be available when calling `save_pretrained` or `push_to_hub`. if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})): instance._hub_mixin_config = config return instance @classmethod def _from_pretrained( cls: Type[T], *, model_id: str, revision: Optional[str], cache_dir: Optional[Union[str, Path]], force_download: bool, proxies: Optional[Dict], resume_download: Optional[bool], local_files_only: bool, token: Optional[Union[str, bool]], **model_kwargs, ) -> T: """Overwrite this method in subclass to define how to load your model from pretrained. Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most args taken as input can be directly passed to those 2 methods. 
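        A minimal sketch of an override (illustrative only; the `"model.bin"` filename and the weight-loading
        step are placeholders for your framework-specific logic, not part of the library):

        ```python
        @classmethod
        def _from_pretrained(cls, *, model_id, revision, cache_dir, force_download, proxies,
                             resume_download, local_files_only, token, **model_kwargs):
            # Resolve the weights file, either from the local cache or by downloading it from the Hub
            weights_path = hf_hub_download(
                repo_id=model_id,
                filename="model.bin",  # placeholder filename
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                token=token,
            )
            model = cls(**model_kwargs)
            # Framework-specific: load the weights stored at `weights_path` into `model`
            return model
        ```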
If needed, you can add more arguments to this method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location` parameter to set on which device the model should be loaded. Check out our [integration guide](../guides/integrations) for more instructions. Args: model_id (`str`): ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`). revision (`str`, *optional*): Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the latest commit on `main` branch. force_download (`bool`, *optional*, defaults to `False`): Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding the existing cache. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`). token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. By default, it will use the token cached when running `huggingface-cli login`. cache_dir (`str`, `Path`, *optional*): Path to the folder where cached files are stored. local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the local cached file if it exists. model_kwargs: Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method. """ raise NotImplementedError @validate_hf_hub_args def push_to_hub( self, repo_id: str, *, config: Optional[Union[dict, DataclassInstance]] = None, commit_message: str = "Push model using huggingface_hub.", private: Optional[bool] = None, token: Optional[str] = None, branch: Optional[str] = None, create_pr: Optional[bool] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, delete_patterns: Optional[Union[List[str], str]] = None, model_card_kwargs: Optional[Dict[str, Any]] = None, ) -> str: """ Upload model checkpoint to the Hub. Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more details. Args: repo_id (`str`): ID of the repository to push to (example: `"username/my-model"`). config (`dict` or `DataclassInstance`, *optional*): Model configuration specified as a key/value dictionary or a dataclass instance. commit_message (`str`, *optional*): Message to commit while pushing. private (`bool`, *optional*): Whether the repository created should be private. If `None` (default), the repo will be public unless the organization's default is private. token (`str`, *optional*): The token to use as HTTP bearer authorization for remote files. By default, it will use the token cached when running `huggingface-cli login`. branch (`str`, *optional*): The git branch on which to push the model. This defaults to `"main"`. create_pr (`boolean`, *optional*): Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`. allow_patterns (`List[str]` or `str`, *optional*): If provided, only files matching at least one pattern are pushed. ignore_patterns (`List[str]` or `str`, *optional*): If provided, files matching any of the patterns are not pushed. delete_patterns (`List[str]` or `str`, *optional*): If provided, remote files matching any of the patterns will be deleted from the repo. 
model_card_kwargs (`Dict[str, Any]`, *optional*): Additional arguments passed to the model card template to customize the model card. Returns: The url of the commit of your model in the given repository. """ api = HfApi(token=token) repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id # Push the files to the repo in a single commit with SoftTemporaryDirectory() as tmp: saved_path = Path(tmp) / repo_id self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs) return api.upload_folder( repo_id=repo_id, repo_type="model", folder_path=saved_path, commit_message=commit_message, revision=branch, create_pr=create_pr, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, delete_patterns=delete_patterns, ) def generate_model_card(self, *args, **kwargs) -> ModelCard: card = ModelCard.from_template( card_data=self._hub_mixin_info.model_card_data, template_str=self._hub_mixin_info.model_card_template, repo_url=self._hub_mixin_info.repo_url, paper_url=self._hub_mixin_info.paper_url, docs_url=self._hub_mixin_info.docs_url, **kwargs, ) return card class PyTorchModelHubMixin(ModelHubMixin): """ Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`. See [`ModelHubMixin`] for more details on how to use the mixin. Example: ```python >>> import torch >>> import torch.nn as nn >>> from huggingface_hub import PyTorchModelHubMixin >>> class MyModel( ... nn.Module, ... PyTorchModelHubMixin, ... library_name="keras-nlp", ... repo_url="https://github.com/keras-team/keras-nlp", ... paper_url="https://arxiv.org/abs/2304.12244", ... docs_url="https://keras.io/keras_nlp/", ... # ^ optional metadata to generate model card ... ): ... def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4): ... super().__init__() ... self.param = nn.Parameter(torch.rand(hidden_size, vocab_size)) ... self.linear = nn.Linear(output_size, vocab_size) ... def forward(self, x): ... 
return self.linear(x + self.param) >>> model = MyModel(hidden_size=256) # Save model weights to local directory >>> model.save_pretrained("my-awesome-model") # Push model weights to the Hub >>> model.push_to_hub("my-awesome-model") # Download and initialize weights from the Hub >>> model = MyModel.from_pretrained("username/my-awesome-model") >>> model.hidden_size 256 ``` """ def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None: tags = tags or [] tags.append("pytorch_model_hub_mixin") kwargs["tags"] = tags return super().__init_subclass__(*args, **kwargs) def _save_pretrained(self, save_directory: Path) -> None: """Save weights from a Pytorch model to a local directory.""" model_to_save = self.module if hasattr(self, "module") else self # type: ignore save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE)) @classmethod def _from_pretrained( cls, *, model_id: str, revision: Optional[str], cache_dir: Optional[Union[str, Path]], force_download: bool, proxies: Optional[Dict], resume_download: Optional[bool], local_files_only: bool, token: Union[str, bool, None], map_location: str = "cpu", strict: bool = False, **model_kwargs, ): """Load Pytorch pretrained weights and return the loaded model.""" model = cls(**model_kwargs) if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, constants.SAFETENSORS_SINGLE_FILE) return cls._load_as_safetensor(model, model_file, map_location, strict) else: try: model_file = hf_hub_download( repo_id=model_id, filename=constants.SAFETENSORS_SINGLE_FILE, revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) return cls._load_as_safetensor(model, model_file, map_location, strict) except EntryNotFoundError: model_file = hf_hub_download( repo_id=model_id, filename=constants.PYTORCH_WEIGHTS_NAME, revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) return cls._load_as_pickle(model, model_file, map_location, strict) @classmethod def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: state_dict = torch.load(model_file, map_location=torch.device(map_location), weights_only=True) model.load_state_dict(state_dict, strict=strict) # type: ignore model.eval() # type: ignore return model @classmethod def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"): # type: ignore [attr-defined] load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type] if map_location != "cpu": logger.warning( "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors." " This means that the model is loaded on 'cpu' first and then copied to the device." " This leads to a slower loading time." " Please update safetensors to version 0.4.3 or above for improved performance." ) model.to(map_location) # type: ignore [attr-defined] else: safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) # type: ignore [arg-type] return model def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance: """Load a dataclass instance from a dictionary. 
Fields not expected by the dataclass are ignored. """ return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__}) huggingface_hub-0.31.1/src/huggingface_hub/inference/000077500000000000000000000000001500667546600225525ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/inference/__init__.py000066400000000000000000000000001500667546600246510ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/inference/_client.py000066400000000000000000004731521500667546600245550ustar00rootroot00000000000000# coding=utf-8 # Copyright 2023-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Related resources: # https://huggingface.co/tasks # https://huggingface.co/docs/huggingface.js/inference/README # https://github.com/huggingface/huggingface.js/tree/main/packages/inference/src # https://github.com/huggingface/text-generation-inference/tree/main/clients/python # https://github.com/huggingface/text-generation-inference/blob/main/clients/python/text_generation/client.py # https://huggingface.slack.com/archives/C03E4DQ9LAJ/p1680169099087869 # https://github.com/huggingface/unity-api#tasks # # Some TODO: # - add all tasks # # NOTE: the philosophy of this client is "let's make it as easy as possible to use it, even if less optimized". Some # examples of how it translates: # - Timeout / Server unavailable is handled by the client in a single "timeout" parameter. # - Files can be provided as bytes, file paths, or URLs and the client will try to "guess" the type. # - Images are parsed as PIL.Image for easier manipulation. # - Provides a "recommended model" for each task => suboptimal but user-wise quicker to get a first script running. # - Only the main parameters are publicly exposed. Power users can always read the docs for more options. 
import base64 import logging import re import warnings from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union, overload from requests import HTTPError from huggingface_hub import constants from huggingface_hub.errors import BadRequestError, InferenceTimeoutError from huggingface_hub.inference._common import ( TASKS_EXPECTING_IMAGES, ContentT, ModelStatus, RequestParameters, _b64_encode, _b64_to_image, _bytes_to_dict, _bytes_to_image, _bytes_to_list, _get_unsupported_text_generation_kwargs, _import_numpy, _open_as_binary, _set_unsupported_text_generation_kwargs, _stream_chat_completion_response, _stream_text_generation_response, raise_text_generation_error, ) from huggingface_hub.inference._generated.types import ( AudioClassificationOutputElement, AudioClassificationOutputTransform, AudioToAudioOutputElement, AutomaticSpeechRecognitionOutput, ChatCompletionInputGrammarType, ChatCompletionInputStreamOptions, ChatCompletionInputTool, ChatCompletionInputToolChoiceClass, ChatCompletionInputToolChoiceEnum, ChatCompletionOutput, ChatCompletionStreamOutput, DocumentQuestionAnsweringOutputElement, FillMaskOutputElement, ImageClassificationOutputElement, ImageClassificationOutputTransform, ImageSegmentationOutputElement, ImageSegmentationSubtask, ImageToImageTargetSize, ImageToTextOutput, ObjectDetectionOutputElement, Padding, QuestionAnsweringOutputElement, SummarizationOutput, SummarizationTruncationStrategy, TableQuestionAnsweringOutputElement, TextClassificationOutputElement, TextClassificationOutputTransform, TextGenerationInputGrammarType, TextGenerationOutput, TextGenerationStreamOutput, TextToSpeechEarlyStoppingEnum, TokenClassificationAggregationStrategy, TokenClassificationOutputElement, TranslationOutput, TranslationTruncationStrategy, VisualQuestionAnsweringOutputElement, ZeroShotClassificationOutputElement, ZeroShotImageClassificationOutputElement, ) from huggingface_hub.inference._providers import PROVIDER_T, get_provider_helper from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status from huggingface_hub.utils._auth import get_token from huggingface_hub.utils._deprecation import _deprecate_method if TYPE_CHECKING: import numpy as np from PIL.Image import Image logger = logging.getLogger(__name__) MODEL_KWARGS_NOT_USED_REGEX = re.compile(r"The following `model_kwargs` are not used by the model: \[(.*?)\]") class InferenceClient: """ Initialize a new Inference Client. [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used seamlessly with either the (free) Inference API, self-hosted Inference Endpoints, or third-party Inference Providers. Args: model (`str`, `optional`): The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct` or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is automatically selected for the task. Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2 arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) documentation for details). When passing a URL as `model`, the client will not append any suffix path to it. provider (`str`, *optional*): Name of the provider to use for inference. 
Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. If model is a URL or `base_url` is passed, then `provider` is not used. token (`str`, *optional*): Hugging Face token. Will default to the locally saved token if not provided. Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2 arguments are mutually exclusive and have the exact same behavior. timeout (`float`, `optional`): The maximum number of seconds to wait for a response from the server. Defaults to None, meaning it will loop until the server is available. headers (`Dict[str, str]`, `optional`): Additional headers to send to the server. By default only the authorization and user-agent headers are sent. Values in this dictionary will override the default values. bill_to (`str`, `optional`): The billing account to use for the requests. By default the requests are billed on the user's account. Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub. cookies (`Dict[str, str]`, `optional`): Additional cookies to send to the server. proxies (`Any`, `optional`): Proxies to use for the request. base_url (`str`, `optional`): Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`] follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None. api_key (`str`, `optional`): Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClient`] follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None. """ def __init__( self, model: Optional[str] = None, *, provider: Union[Literal["auto"], PROVIDER_T, None] = None, token: Optional[str] = None, timeout: Optional[float] = None, headers: Optional[Dict[str, str]] = None, cookies: Optional[Dict[str, str]] = None, proxies: Optional[Any] = None, bill_to: Optional[str] = None, # OpenAI compatibility base_url: Optional[str] = None, api_key: Optional[str] = None, ) -> None: if model is not None and base_url is not None: raise ValueError( "Received both `model` and `base_url` arguments. Please provide only one of them." " `base_url` is an alias for `model` to make the API compatible with OpenAI's client." " If using `base_url` for chat completion, the `/chat/completions` suffix path will be appended to the base url." " When passing a URL as `model`, the client will not append any suffix path to it." ) if token is not None and api_key is not None: raise ValueError( "Received both `token` and `api_key` arguments. Please provide only one of them." " `api_key` is an alias for `token` to make the API compatible with OpenAI's client." " It has the exact same behavior as `token`." ) token = token if token is not None else api_key if isinstance(token, bool): # Legacy behavior: previously is was possible to pass `token=False` to disable authentication. This is not # supported anymore as authentication is required. Better to explicitly raise here rather than risking # sending the locally saved token without the user knowing about it. 
if token is False: raise ValueError( "Cannot use `token=False` to disable authentication as authentication is required to run Inference." ) warnings.warn( "Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. " "Please use `token=None` instead (default).", DeprecationWarning, ) token = get_token() self.model: Optional[str] = base_url or model self.token: Optional[str] = token self.headers = {**headers} if headers is not None else {} if bill_to is not None: if ( constants.HUGGINGFACE_HEADER_X_BILL_TO in self.headers and self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] != bill_to ): warnings.warn( f"Overriding existing '{self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO]}' value in headers with '{bill_to}'.", UserWarning, ) self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] = bill_to if token is not None and not token.startswith("hf_"): warnings.warn( "You've provided an external provider's API key, so requests will be billed directly by the provider. " "The `bill_to` parameter is only applicable for Hugging Face billing and will be ignored.", UserWarning, ) # Configure provider self.provider = provider self.cookies = cookies self.timeout = timeout self.proxies = proxies def __repr__(self): return f"" @overload def _inner_post( # type: ignore[misc] self, request_parameters: RequestParameters, *, stream: Literal[False] = ... ) -> bytes: ... @overload def _inner_post( # type: ignore[misc] self, request_parameters: RequestParameters, *, stream: Literal[True] = ... ) -> Iterable[bytes]: ... @overload def _inner_post( self, request_parameters: RequestParameters, *, stream: bool = False ) -> Union[bytes, Iterable[bytes]]: ... def _inner_post( self, request_parameters: RequestParameters, *, stream: bool = False ) -> Union[bytes, Iterable[bytes]]: """Make a request to the inference server.""" # TODO: this should be handled in provider helpers directly if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers: request_parameters.headers["Accept"] = "image/png" with _open_as_binary(request_parameters.data) as data_as_binary: try: response = get_session().post( request_parameters.url, json=request_parameters.json, data=data_as_binary, headers=request_parameters.headers, cookies=self.cookies, timeout=self.timeout, stream=stream, proxies=self.proxies, ) except TimeoutError as error: # Convert any `TimeoutError` to a `InferenceTimeoutError` raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore try: hf_raise_for_status(response) return response.iter_lines() if stream else response.content except HTTPError as error: if error.response.status_code == 422 and request_parameters.task != "unknown": msg = str(error.args[0]) if len(error.response.text) > 0: msg += f"\n{error.response.text}\n" error.args = (msg,) + error.args[1:] raise def audio_classification( self, audio: ContentT, *, model: Optional[str] = None, top_k: Optional[int] = None, function_to_apply: Optional["AudioClassificationOutputTransform"] = None, ) -> List[AudioClassificationOutputElement]: """ Perform audio classification on the provided audio content. Args: audio (Union[str, Path, bytes, BinaryIO]): The audio content to classify. It can be raw audio bytes, a local audio file, or a URL pointing to an audio file. model (`str`, *optional*): The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. 
If not provided, the default recommended model for audio classification will be used. top_k (`int`, *optional*): When specified, limits the output to the top K most probable classes. function_to_apply (`"AudioClassificationOutputTransform"`, *optional*): The function to apply to the model outputs in order to retrieve the scores. Returns: `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.audio_classification("audio.flac") [ AudioClassificationOutputElement(score=0.4976358711719513, label='hap'), AudioClassificationOutputElement(score=0.3677836060523987, label='neu'), ... ] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="audio-classification", model=model_id) request_parameters = provider_helper.prepare_request( inputs=audio, parameters={"function_to_apply": function_to_apply, "top_k": top_k}, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return AudioClassificationOutputElement.parse_obj_as_list(response) def audio_to_audio( self, audio: ContentT, *, model: Optional[str] = None, ) -> List[AudioToAudioOutputElement]: """ Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation). Args: audio (Union[str, Path, bytes, BinaryIO]): The audio content for the model. It can be raw audio bytes, a local audio file, or a URL pointing to an audio file. model (`str`, *optional*): The model can be any model which takes an audio file and returns another audio file. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for audio_to_audio will be used. Returns: `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob. Raises: `InferenceTimeoutError`: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> audio_output = client.audio_to_audio("audio.flac") >>> for i, item in enumerate(audio_output): >>> with open(f"output_{i}.flac", "wb") as f: f.write(item.blob) ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="audio-to-audio", model=model_id) request_parameters = provider_helper.prepare_request( inputs=audio, parameters={}, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) audio_output = AudioToAudioOutputElement.parse_obj_as_list(response) for item in audio_output: item.blob = base64.b64decode(item.blob) return audio_output def automatic_speech_recognition( self, audio: ContentT, *, model: Optional[str] = None, extra_body: Optional[Dict] = None, ) -> AutomaticSpeechRecognitionOutput: """ Perform automatic speech recognition (ASR or audio-to-text) on the given audio content. Args: audio (Union[str, Path, bytes, BinaryIO]): The content to transcribe. 
It can be raw audio bytes, local audio file, or a URL to an audio file. model (`str`, *optional*): The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for ASR will be used. extra_body (`Dict`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.automatic_speech_recognition("hello_world.flac").text "hello world" ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition", model=model_id) request_parameters = provider_helper.prepare_request( inputs=audio, parameters={**(extra_body or {})}, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response) @overload def chat_completion( # type: ignore self, messages: List[Dict], *, model: Optional[str] = None, stream: Literal[False] = False, frequency_penalty: Optional[float] = None, logit_bias: Optional[List[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, stop: Optional[List[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, tools: Optional[List[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, ) -> ChatCompletionOutput: ... @overload def chat_completion( # type: ignore self, messages: List[Dict], *, model: Optional[str] = None, stream: Literal[True] = True, frequency_penalty: Optional[float] = None, logit_bias: Optional[List[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, stop: Optional[List[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, tools: Optional[List[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, ) -> Iterable[ChatCompletionStreamOutput]: ... 
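    # The `chat_completion` overloads above and below tie the return type to the `stream` flag:
    # `stream=False` (default) returns a single `ChatCompletionOutput`, `stream=True` returns an
    # `Iterable[ChatCompletionStreamOutput]`, and the `stream: bool` overload covers the case where
    # the flag is only known at runtime.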
@overload def chat_completion( self, messages: List[Dict], *, model: Optional[str] = None, stream: bool = False, frequency_penalty: Optional[float] = None, logit_bias: Optional[List[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, stop: Optional[List[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, tools: Optional[List[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ... def chat_completion( self, messages: List[Dict], *, model: Optional[str] = None, stream: bool = False, # Parameters from ChatCompletionInput (handled manually) frequency_penalty: Optional[float] = None, logit_bias: Optional[List[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, stop: Optional[List[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, tools: Optional[List[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: """ A method for completing conversations using a specified language model. The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client. Inputs and outputs are strictly the same and using either syntax will yield the same results. Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility) for more details about OpenAI's compatibility. You can pass provider-specific parameters to the model by using the `extra_body` argument. Args: messages (List of [`ChatCompletionInputMessage`]): Conversation history consisting of roles and content pairs. model (`str`, *optional*): The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used. See https://huggingface.co/tasks/text-generation for more details. If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`]. frequency_penalty (`float`, *optional*): Penalizes new tokens based on their existing frequency in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0. logit_bias (`List[float]`, *optional*): Adjusts the likelihood of specific tokens appearing in the generated output. logprobs (`bool`, *optional*): Whether to return log probabilities of the output tokens or not. 
If true, returns the log probabilities of each output token returned in the content of message. max_tokens (`int`, *optional*): Maximum number of tokens allowed in the response. Defaults to 100. n (`int`, *optional*): The number of completions to generate for each prompt. presence_penalty (`float`, *optional*): Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. response_format ([`ChatCompletionInputGrammarType`], *optional*): Grammar constraints. Can be either a JSONSchema or a regex. seed (Optional[`int`], *optional*): Seed for reproducible control flow. Defaults to None. stop (`List[str]`, *optional*): Up to four strings which trigger the end of the response. Defaults to None. stream (`bool`, *optional*): Enable realtime streaming of responses. Defaults to False. stream_options ([`ChatCompletionInputStreamOptions`], *optional*): Options for streaming completions. temperature (`float`, *optional*): Controls randomness of the generations. Lower values ensure less random completions. Range: [0, 2]. Defaults to 1.0. top_logprobs (`int`, *optional*): An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. top_p (`float`, *optional*): Fraction of the most likely next words to sample from. Must be between 0 and 1. Defaults to 1.0. tool_choice ([`ChatCompletionInputToolChoiceClass`] or [`ChatCompletionInputToolChoiceEnum`], *optional*): The tool to use for the completion. Defaults to "auto". tool_prompt (`str`, *optional*): A prompt to be appended before the tools. tools (List of [`ChatCompletionInputTool`], *optional*): A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. extra_body (`Dict`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: [`ChatCompletionOutput`] or Iterable of [`ChatCompletionStreamOutput`]: Generated text returned from the server: - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default). - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`]. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. 
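        As a quick orientation before the fuller examples below, here is a minimal sketch combining a few of the parameters above (the model ID and the output string are illustrative):

        ```py
        >>> from huggingface_hub import InferenceClient
        >>> client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
        >>> client.chat_completion(
        ...     [{"role": "user", "content": "Name one planet."}],
        ...     max_tokens=20,
        ...     temperature=0.0,
        ...     seed=42,
        ...     stop=["."],
        ... ).choices[0].message.content
        'Mars'
        ```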
Example: ```py >>> from huggingface_hub import InferenceClient >>> messages = [{"role": "user", "content": "What is the capital of France?"}] >>> client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") >>> client.chat_completion(messages, max_tokens=100) ChatCompletionOutput( choices=[ ChatCompletionOutputComplete( finish_reason='eos_token', index=0, message=ChatCompletionOutputMessage( role='assistant', content='The capital of France is Paris.', name=None, tool_calls=None ), logprobs=None ) ], created=1719907176, id='', model='meta-llama/Meta-Llama-3-8B-Instruct', object='text_completion', system_fingerprint='2.0.4-sha-f426a33', usage=ChatCompletionOutputUsage( completion_tokens=8, prompt_tokens=17, total_tokens=25 ) ) ``` Example using streaming: ```py >>> from huggingface_hub import InferenceClient >>> messages = [{"role": "user", "content": "What is the capital of France?"}] >>> client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") >>> for token in client.chat_completion(messages, max_tokens=10, stream=True): ... print(token) ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504) ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504) (...) ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504) ``` Example using OpenAI's syntax: ```py # instead of `from openai import OpenAI` from huggingface_hub import InferenceClient # instead of `client = OpenAI(...)` client = InferenceClient( base_url=..., api_key=..., ) output = client.chat.completions.create( model="meta-llama/Meta-Llama-3-8B-Instruct", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Count to 10"}, ], stream=True, max_tokens=1024, ) for chunk in output: print(chunk.choices[0].delta.content) ``` Example using a third-party provider directly with extra (provider-specific) parameters. Usage will be billed on your Together AI account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="together", # Use Together AI provider ... api_key="", # Pass your Together API key directly ... ) >>> client.chat_completion( ... model="meta-llama/Meta-Llama-3-8B-Instruct", ... messages=[{"role": "user", "content": "What is the capital of France?"}], ... extra_body={"safety_model": "Meta-Llama/Llama-Guard-7b"}, ... ) ``` Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="sambanova", # Use Sambanova provider ... api_key="hf_...", # Pass your HF token ... ) >>> client.chat_completion( ... model="meta-llama/Meta-Llama-3-8B-Instruct", ... messages=[{"role": "user", "content": "What is the capital of France?"}], ... ) ``` Example using Image + Text as input: ```py >>> from huggingface_hub import InferenceClient # provide a remote URL >>> image_url ="https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" # or a base64-encoded image >>> image_path = "/path/to/image.jpeg" >>> with open(image_path, "rb") as f: ... 
base64_image = base64.b64encode(f.read()).decode("utf-8") >>> image_url = f"data:image/jpeg;base64,{base64_image}" >>> client = InferenceClient("meta-llama/Llama-3.2-11B-Vision-Instruct") >>> output = client.chat.completions.create( ... messages=[ ... { ... "role": "user", ... "content": [ ... { ... "type": "image_url", ... "image_url": {"url": image_url}, ... }, ... { ... "type": "text", ... "text": "Describe this image in one sentence.", ... }, ... ], ... }, ... ], ... ) >>> output The image depicts the iconic Statue of Liberty situated in New York Harbor, New York, on a clear day. ``` Example using tools: ```py >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") >>> messages = [ ... { ... "role": "system", ... "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.", ... }, ... { ... "role": "user", ... "content": "What's the weather like the next 3 days in San Francisco, CA?", ... }, ... ] >>> tools = [ ... { ... "type": "function", ... "function": { ... "name": "get_current_weather", ... "description": "Get the current weather", ... "parameters": { ... "type": "object", ... "properties": { ... "location": { ... "type": "string", ... "description": "The city and state, e.g. San Francisco, CA", ... }, ... "format": { ... "type": "string", ... "enum": ["celsius", "fahrenheit"], ... "description": "The temperature unit to use. Infer this from the users location.", ... }, ... }, ... "required": ["location", "format"], ... }, ... }, ... }, ... { ... "type": "function", ... "function": { ... "name": "get_n_day_weather_forecast", ... "description": "Get an N-day weather forecast", ... "parameters": { ... "type": "object", ... "properties": { ... "location": { ... "type": "string", ... "description": "The city and state, e.g. San Francisco, CA", ... }, ... "format": { ... "type": "string", ... "enum": ["celsius", "fahrenheit"], ... "description": "The temperature unit to use. Infer this from the users location.", ... }, ... "num_days": { ... "type": "integer", ... "description": "The number of days to forecast", ... }, ... }, ... "required": ["location", "format", "num_days"], ... }, ... }, ... }, ... ] >>> response = client.chat_completion( ... model="meta-llama/Meta-Llama-3-70B-Instruct", ... messages=messages, ... tools=tools, ... tool_choice="auto", ... max_tokens=500, ... ) >>> response.choices[0].message.tool_calls[0].function ChatCompletionOutputFunctionDefinition( arguments={ 'location': 'San Francisco, CA', 'format': 'fahrenheit', 'num_days': 3 }, name='get_n_day_weather_forecast', description=None ) ``` Example using response_format: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") >>> messages = [ ... { ... "role": "user", ... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?", ... }, ... ] >>> response_format = { ... "type": "json", ... "value": { ... "properties": { ... "location": {"type": "string"}, ... "activity": {"type": "string"}, ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, ... "animals": {"type": "array", "items": {"type": "string"}}, ... }, ... "required": ["location", "activity", "animals_seen", "animals"], ... }, ... } >>> response = client.chat_completion( ... messages=messages, ... response_format=response_format, ... max_tokens=500, ... 
) >>> response.choices[0].message.content '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}' ``` """ # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently. # `self.model` takes precedence over 'model' argument for building URL. # `model` takes precedence for payload value. model_id_or_url = self.model or model payload_model = model or self.model # Get the provider helper provider_helper = get_provider_helper( self.provider, task="conversational", model=model_id_or_url if model_id_or_url is not None and model_id_or_url.startswith(("http://", "https://")) else payload_model, ) # Prepare the payload parameters = { "model": payload_model, "frequency_penalty": frequency_penalty, "logit_bias": logit_bias, "logprobs": logprobs, "max_tokens": max_tokens, "n": n, "presence_penalty": presence_penalty, "response_format": response_format, "seed": seed, "stop": stop, "temperature": temperature, "tool_choice": tool_choice, "tool_prompt": tool_prompt, "tools": tools, "top_logprobs": top_logprobs, "top_p": top_p, "stream": stream, "stream_options": stream_options, **(extra_body or {}), } request_parameters = provider_helper.prepare_request( inputs=messages, parameters=parameters, headers=self.headers, model=model_id_or_url, api_key=self.token, ) data = self._inner_post(request_parameters, stream=stream) if stream: return _stream_chat_completion_response(data) # type: ignore[arg-type] return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type] def document_question_answering( self, image: ContentT, question: str, *, model: Optional[str] = None, doc_stride: Optional[int] = None, handle_impossible_answer: Optional[bool] = None, lang: Optional[str] = None, max_answer_len: Optional[int] = None, max_question_len: Optional[int] = None, max_seq_len: Optional[int] = None, top_k: Optional[int] = None, word_boxes: Optional[List[Union[List[float], str]]] = None, ) -> List[DocumentQuestionAnsweringOutputElement]: """ Answer questions on document images. Args: image (`Union[str, Path, bytes, BinaryIO]`): The input image for the context. It can be raw bytes, an image file, or a URL to an online image. question (`str`): Question to be answered. model (`str`, *optional*): The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used. Defaults to None. doc_stride (`int`, *optional*): If the words in the document are too long to fit with the question for the model, it will be split in several chunks with some overlap. This argument controls the size of that overlap. handle_impossible_answer (`bool`, *optional*): Whether to accept impossible as an answer lang (`str`, *optional*): Language to use while running OCR. Defaults to english. max_answer_len (`int`, *optional*): The maximum length of predicted answers (e.g., only answers with a shorter length are considered). max_question_len (`int`, *optional*): The maximum length of the question after tokenization. It will be truncated if needed. max_seq_len (`int`, *optional*): The maximum length of the total sentence (context + question) in tokens of each chunk passed to the model. The context will be split in several chunks (using doc_stride as overlap) if needed. 
top_k (`int`, *optional*): The number of answers to return (will be chosen by order of likelihood). Can return less than top_k answers if there are not enough options available within the context. word_boxes (`List[Union[List[float], str`, *optional*): A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR step and use the provided bounding boxes instead. Returns: `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?") [DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16)] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="document-question-answering", model=model_id) inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} request_parameters = provider_helper.prepare_request( inputs=inputs, parameters={ "doc_stride": doc_stride, "handle_impossible_answer": handle_impossible_answer, "lang": lang, "max_answer_len": max_answer_len, "max_question_len": max_question_len, "max_seq_len": max_seq_len, "top_k": top_k, "word_boxes": word_boxes, }, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response) def feature_extraction( self, text: str, *, normalize: Optional[bool] = None, prompt_name: Optional[str] = None, truncate: Optional[bool] = None, truncation_direction: Optional[Literal["Left", "Right"]] = None, model: Optional[str] = None, ) -> "np.ndarray": """ Generate embeddings for a given text. Args: text (`str`): The text to embed. model (`str`, *optional*): The model to use for the feature extraction task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended feature extraction model will be used. Defaults to None. normalize (`bool`, *optional*): Whether to normalize the embeddings or not. Only available on server powered by Text-Embedding-Inference. prompt_name (`str`, *optional*): The name of the prompt that should be used by for encoding. If not set, no prompt will be applied. Must be a key in the `Sentence Transformers` configuration `prompts` dictionary. For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",...}, then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" because the prompt text will be prepended before any text to encode. truncate (`bool`, *optional*): Whether to truncate the embeddings or not. Only available on server powered by Text-Embedding-Inference. truncation_direction (`Literal["Left", "Right"]`, *optional*): Which side of the input should be truncated when `truncate=True` is passed. Returns: `np.ndarray`: The embedding representing the input text as a float32 numpy array. 
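        For sentence-level comparisons it is common to pool the per-token embeddings before computing a similarity score; here is a minimal sketch (the mean-pooling step and the final value are illustrative client-side choices, not something the server performs):

        ```py
        >>> import numpy as np
        >>> from huggingface_hub import InferenceClient
        >>> client = InferenceClient()
        >>> a = client.feature_extraction("I like cats.").mean(axis=0)
        >>> b = client.feature_extraction("I love felines.").mean(axis=0)
        >>> float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))  # cosine similarity
        0.87
        ```

        For a server-side comparison of sentences, see [`InferenceClient.sentence_similarity`].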
Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.feature_extraction("Hi, who are you?") array([[ 2.424802 , 2.93384 , 1.1750331 , ..., 1.240499, -0.13776633, -0.7889173 ], [-0.42943227, -0.6364878 , -1.693462 , ..., 0.41978157, -2.4336355 , 0.6162071 ], ..., [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32) ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="feature-extraction", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={ "normalize": normalize, "prompt_name": prompt_name, "truncate": truncate, "truncation_direction": truncation_direction, }, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) np = _import_numpy() return np.array(provider_helper.get_response(response), dtype="float32") def fill_mask( self, text: str, *, model: Optional[str] = None, targets: Optional[List[str]] = None, top_k: Optional[int] = None, ) -> List[FillMaskOutputElement]: """ Fill in a hole with a missing word (token to be precise). Args: text (`str`): a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask). model (`str`, *optional*): The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used. targets (`List[str`, *optional*): When passed, the model will limit the scores to the passed targets instead of looking up in the whole vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first resulting token will be used (with a warning, and that might be slower). top_k (`int`, *optional*): When passed, overrides the number of predictions to return. Returns: `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated probability, token reference, and completed text. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. 
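        Beyond the basic call shown in the example below, `targets` and `top_k` can be combined to restrict and rank the candidates; here is a minimal sketch (the mask token and score are illustrative, check the model card for the exact mask token):

        ```py
        >>> from huggingface_hub import InferenceClient
        >>> client = InferenceClient()
        >>> client.fill_mask("The goal of life is <mask>.", targets=[" happiness", " money"], top_k=1)
        [FillMaskOutputElement(score=0.27, token=11098, token_str=' happiness', sequence='The goal of life is happiness.')]
        ```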
Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.fill_mask("The goal of life is .") [ FillMaskOutputElement(score=0.06897063553333282, token=11098, token_str=' happiness', sequence='The goal of life is happiness.'), FillMaskOutputElement(score=0.06554922461509705, token=45075, token_str=' immortality', sequence='The goal of life is immortality.') ] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="fill-mask", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={"targets": targets, "top_k": top_k}, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return FillMaskOutputElement.parse_obj_as_list(response) def image_classification( self, image: ContentT, *, model: Optional[str] = None, function_to_apply: Optional["ImageClassificationOutputTransform"] = None, top_k: Optional[int] = None, ) -> List[ImageClassificationOutputElement]: """ Perform image classification on the given image using the specified model. Args: image (`Union[str, Path, bytes, BinaryIO]`): The image to classify. It can be raw bytes, an image file, or a URL to an online image. model (`str`, *optional*): The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used. function_to_apply (`"ImageClassificationOutputTransform"`, *optional*): The function to apply to the model outputs in order to retrieve the scores. top_k (`int`, *optional*): When specified, limits the output to the top K most probable classes. Returns: `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="image-classification", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={"function_to_apply": function_to_apply, "top_k": top_k}, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return ImageClassificationOutputElement.parse_obj_as_list(response) def image_segmentation( self, image: ContentT, *, model: Optional[str] = None, mask_threshold: Optional[float] = None, overlap_mask_area_threshold: Optional[float] = None, subtask: Optional["ImageSegmentationSubtask"] = None, threshold: Optional[float] = None, ) -> List[ImageSegmentationOutputElement]: """ Perform image segmentation on the given image using the specified model. You must have `PIL` installed if you want to work with images (`pip install Pillow`). Args: image (`Union[str, Path, bytes, BinaryIO]`): The image to segment. It can be raw bytes, an image file, or a URL to an online image. model (`str`, *optional*): The model to use for image segmentation. 
Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used. mask_threshold (`float`, *optional*): Threshold to use when turning the predicted masks into binary values. overlap_mask_area_threshold (`float`, *optional*): Mask overlap threshold to eliminate small, disconnected segments. subtask (`"ImageSegmentationSubtask"`, *optional*): Segmentation task to be performed, depending on model capabilities. threshold (`float`, *optional*): Probability threshold to filter out predicted masks. Returns: `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.image_segmentation("cat.jpg") [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=), ...] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="image-segmentation", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={ "mask_threshold": mask_threshold, "overlap_mask_area_threshold": overlap_mask_area_threshold, "subtask": subtask, "threshold": threshold, }, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) output = ImageSegmentationOutputElement.parse_obj_as_list(response) for item in output: item.mask = _b64_to_image(item.mask) # type: ignore [assignment] return output def image_to_image( self, image: ContentT, prompt: Optional[str] = None, *, negative_prompt: Optional[str] = None, num_inference_steps: Optional[int] = None, guidance_scale: Optional[float] = None, model: Optional[str] = None, target_size: Optional[ImageToImageTargetSize] = None, **kwargs, ) -> "Image": """ Perform image-to-image translation using a specified model. You must have `PIL` installed if you want to work with images (`pip install Pillow`). Args: image (`Union[str, Path, bytes, BinaryIO]`): The input image for translation. It can be raw bytes, an image file, or a URL to an online image. prompt (`str`, *optional*): The text prompt to guide the image generation. negative_prompt (`str`, *optional*): One prompt to guide what NOT to include in image generation. num_inference_steps (`int`, *optional*): For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*): For diffusion models. A higher guidance scale value encourages the model to generate images closely linked to the text prompt at the expense of lower image quality. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. target_size (`ImageToImageTargetSize`, *optional*): The size in pixel of the output image. Returns: `Image`: The translated image. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. 
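        The diffusion-specific parameters can be combined in a single call; here is a small sketch (the values are illustrative, not tuned recommendations):

        ```py
        >>> from huggingface_hub import InferenceClient
        >>> client = InferenceClient()
        >>> image = client.image_to_image(
        ...     "cat.jpg",
        ...     prompt="turn the cat into a tiger",
        ...     negative_prompt="blurry, low quality",
        ...     num_inference_steps=30,
        ...     guidance_scale=7.5,
        ... )
        >>> image.save("tiger.jpg")
        ```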
Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> image = client.image_to_image("cat.jpg", prompt="turn the cat into a tiger") >>> image.save("tiger.jpg") ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={ "prompt": prompt, "negative_prompt": negative_prompt, "target_size": target_size, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, **kwargs, }, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return _bytes_to_image(response) def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput: """ Takes an input image and return text. Models can have very different outputs depending on your use case (image captioning, optical character recognition (OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities. Args: image (`Union[str, Path, bytes, BinaryIO]`): The input image to caption. It can be raw bytes, an image file, or a URL to an online image.. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. Returns: [`ImageToTextOutput`]: The generated text. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.image_to_text("cat.jpg") 'a cat standing in a grassy field ' >>> client.image_to_text("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") 'a dog laying on the grass next to a flower pot ' ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="image-to-text", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={}, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) output = ImageToTextOutput.parse_obj(response) return output[0] if isinstance(output, list) else output def object_detection( self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None ) -> List[ObjectDetectionOutputElement]: """ Perform object detection on the given image using the specified model. You must have `PIL` installed if you want to work with images (`pip install Pillow`). Args: image (`Union[str, Path, bytes, BinaryIO]`): The image to detect objects on. It can be raw bytes, an image file, or a URL to an online image. model (`str`, *optional*): The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used. threshold (`float`, *optional*): The probability necessary to make a prediction. Returns: `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. 
`HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. `ValueError`: If the request output is not a List. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.object_detection("people.jpg") [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="object-detection", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={"threshold": threshold}, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return ObjectDetectionOutputElement.parse_obj_as_list(response) def question_answering( self, question: str, context: str, *, model: Optional[str] = None, align_to_words: Optional[bool] = None, doc_stride: Optional[int] = None, handle_impossible_answer: Optional[bool] = None, max_answer_len: Optional[int] = None, max_question_len: Optional[int] = None, max_seq_len: Optional[int] = None, top_k: Optional[int] = None, ) -> Union[QuestionAnsweringOutputElement, List[QuestionAnsweringOutputElement]]: """ Retrieve the answer to a question from a given text. Args: question (`str`): Question to be answered. context (`str`): The context of the question. model (`str`, *optional*): The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. align_to_words (`bool`, *optional*): Attempts to align the answer to real words. Improves quality on space-separated languages. Might hurt on non-space-separated languages (like Japanese or Chinese). doc_stride (`int`, *optional*): If the context is too long to fit with the question for the model, it will be split in several chunks with some overlap. This argument controls the size of that overlap. handle_impossible_answer (`bool`, *optional*): Whether to accept impossible as an answer. max_answer_len (`int`, *optional*): The maximum length of predicted answers (e.g., only answers with a shorter length are considered). max_question_len (`int`, *optional*): The maximum length of the question after tokenization. It will be truncated if needed. max_seq_len (`int`, *optional*): The maximum length of the total sentence (context + question) in tokens of each chunk passed to the model. The context will be split in several chunks (using `doc_stride` as overlap) if needed. top_k (`int`, *optional*): The number of answers to return (will be chosen by order of likelihood). Note that fewer than `top_k` answers are returned if there are not enough options available within the context. Returns: Union[`QuestionAnsweringOutputElement`, List[`QuestionAnsweringOutputElement`]]: When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`. When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503.
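        When `top_k` is greater than 1 the method returns a list instead of a single element; here is a minimal sketch (the scores are illustrative):

        ```py
        >>> from huggingface_hub import InferenceClient
        >>> client = InferenceClient()
        >>> client.question_answering(
        ...     question="Where do I live?",
        ...     context="My name is Clara and I live in Berkeley.",
        ...     top_k=2,
        ... )
        [QuestionAnsweringOutputElement(answer='Berkeley', end=39, score=0.97, start=31), QuestionAnsweringOutputElement(answer='in Berkeley', end=39, score=0.02, start=28)]
        ```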
Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.") QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11) ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="question-answering", model=model_id) request_parameters = provider_helper.prepare_request( inputs=None, parameters={ "align_to_words": align_to_words, "doc_stride": doc_stride, "handle_impossible_answer": handle_impossible_answer, "max_answer_len": max_answer_len, "max_question_len": max_question_len, "max_seq_len": max_seq_len, "top_k": top_k, }, extra_payload={"question": question, "context": context}, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) # Parse the response as a single `QuestionAnsweringOutputElement` when top_k is 1 or not provided, or a list of `QuestionAnsweringOutputElement` to ensure backward compatibility. output = QuestionAnsweringOutputElement.parse_obj(response) return output def sentence_similarity( self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None ) -> List[float]: """ Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings. Args: sentence (`str`): The main sentence to compare to others. other_sentences (`List[str]`): The list of sentences to compare to. model (`str`, *optional*): The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended sentence similarity model will be used. Defaults to None. Returns: `List[float]`: The embedding representing the input text. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.sentence_similarity( ... "Machine learning is so easy.", ... other_sentences=[ ... "Deep learning is so straightforward.", ... "This is so difficult, like rocket science.", ... "I can't believe how much I struggled with this.", ... ], ... ) [0.7785726189613342, 0.45876261591911316, 0.2906220555305481] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="sentence-similarity", model=model_id) request_parameters = provider_helper.prepare_request( inputs={"source_sentence": sentence, "sentences": other_sentences}, parameters={}, extra_payload={}, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return _bytes_to_list(response) def summarization( self, text: str, *, model: Optional[str] = None, clean_up_tokenization_spaces: Optional[bool] = None, generate_parameters: Optional[Dict[str, Any]] = None, truncation: Optional["SummarizationTruncationStrategy"] = None, ) -> SummarizationOutput: """ Generate a summary of a given text using a specified model. Args: text (`str`): The input text to summarize. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for summarization will be used. 
clean_up_tokenization_spaces (`bool`, *optional*): Whether to clean up the potential extra spaces in the text output. generate_parameters (`Dict[str, Any]`, *optional*): Additional parametrization of the text generation algorithm. truncation (`"SummarizationTruncationStrategy"`, *optional*): The truncation strategy to use. Returns: [`SummarizationOutput`]: The generated summary text. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.summarization("The Eiffel tower...") SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....") ``` """ parameters = { "clean_up_tokenization_spaces": clean_up_tokenization_spaces, "generate_parameters": generate_parameters, "truncation": truncation, } model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="summarization", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters=parameters, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return SummarizationOutput.parse_obj_as_list(response)[0] def table_question_answering( self, table: Dict[str, Any], query: str, *, model: Optional[str] = None, padding: Optional["Padding"] = None, sequential: Optional[bool] = None, truncation: Optional[bool] = None, ) -> TableQuestionAnsweringOutputElement: """ Retrieve the answer to a question from information given in a table. Args: table (`str`): A table of data represented as a dict of lists where entries are headers and the lists are all the values, all lists must have the same size. query (`str`): The query in plain text that you want to ask the table. model (`str`): The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. padding (`"Padding"`, *optional*): Activates and controls padding. sequential (`bool`, *optional*): Whether to do inference sequentially or as a batch. Batching is faster, but models like SQA require the inference to be done sequentially to extract relations within sequences, given their conversational nature. truncation (`bool`, *optional*): Activates and controls truncation. Returns: [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> query = "How many stars does the transformers repository have?" 
>>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]} >>> client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq") TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE') ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="table-question-answering", model=model_id) request_parameters = provider_helper.prepare_request( inputs=None, parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation}, extra_payload={"query": query, "table": table}, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response) def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]: """ Classifying a target category (a group) based on a set of attributes. Args: table (`Dict[str, Any]`): Set of attributes to classify. model (`str`, *optional*): The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended tabular classification model will be used. Defaults to None. Returns: `List`: a list of labels, one per row in the initial table. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> table = { ... "fixed_acidity": ["7.4", "7.8", "10.3"], ... "volatile_acidity": ["0.7", "0.88", "0.32"], ... "citric_acid": ["0", "0", "0.45"], ... "residual_sugar": ["1.9", "2.6", "6.4"], ... "chlorides": ["0.076", "0.098", "0.073"], ... "free_sulfur_dioxide": ["11", "25", "5"], ... "total_sulfur_dioxide": ["34", "67", "13"], ... "density": ["0.9978", "0.9968", "0.9976"], ... "pH": ["3.51", "3.2", "3.23"], ... "sulphates": ["0.56", "0.68", "0.82"], ... "alcohol": ["9.4", "9.8", "12.6"], ... } >>> client.tabular_classification(table=table, model="julien-c/wine-quality") ["5", "5", "5"] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="tabular-classification", model=model_id) request_parameters = provider_helper.prepare_request( inputs=None, extra_payload={"table": table}, parameters={}, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return _bytes_to_list(response) def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]: """ Predicting a numerical target value given a set of attributes/features in a table. Args: table (`Dict[str, Any]`): Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical. model (`str`, *optional*): The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended tabular regression model will be used. Defaults to None. Returns: `List`: a list of predicted numerical target values. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. 
Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> table = { ... "Height": ["11.52", "12.48", "12.3778"], ... "Length1": ["23.2", "24", "23.9"], ... "Length2": ["25.4", "26.3", "26.5"], ... "Length3": ["30", "31.2", "31.1"], ... "Species": ["Bream", "Bream", "Bream"], ... "Width": ["4.02", "4.3056", "4.6961"], ... } >>> client.tabular_regression(table, model="scikit-learn/Fish-Weight") [110, 120, 130] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="tabular-regression", model=model_id) request_parameters = provider_helper.prepare_request( inputs=None, parameters={}, extra_payload={"table": table}, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return _bytes_to_list(response) def text_classification( self, text: str, *, model: Optional[str] = None, top_k: Optional[int] = None, function_to_apply: Optional["TextClassificationOutputTransform"] = None, ) -> List[TextClassificationOutputElement]: """ Perform text classification (e.g. sentiment-analysis) on the given text. Args: text (`str`): A string to be classified. model (`str`, *optional*): The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. Defaults to None. top_k (`int`, *optional*): When specified, limits the output to the top K most probable classes. function_to_apply (`"TextClassificationOutputTransform"`, *optional*): The function to apply to the model outputs in order to retrieve the scores. Returns: `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. 
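        To keep only the single best label, pass `top_k=1`; here is a minimal sketch (the score is illustrative):

        ```py
        >>> from huggingface_hub import InferenceClient
        >>> client = InferenceClient()
        >>> client.text_classification("I like you", top_k=1)
        [TextClassificationOutputElement(label='POSITIVE', score=0.9999)]
        ```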
Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.text_classification("I like you") [ TextClassificationOutputElement(label='POSITIVE', score=0.9998695850372314), TextClassificationOutputElement(label='NEGATIVE', score=0.0001304351753788069), ] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="text-classification", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={ "function_to_apply": function_to_apply, "top_k": top_k, }, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value] @overload def text_generation( # type: ignore self, prompt: str, *, details: Literal[False] = ..., stream: Literal[False] = ..., model: Optional[str] = None, # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) adapter_id: Optional[str] = None, best_of: Optional[int] = None, decoder_input_details: Optional[bool] = None, do_sample: Optional[bool] = False, # Manual default value frequency_penalty: Optional[float] = None, grammar: Optional[TextGenerationInputGrammarType] = None, max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, stop: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, ) -> str: ... @overload def text_generation( # type: ignore self, prompt: str, *, details: Literal[True] = ..., stream: Literal[False] = ..., model: Optional[str] = None, # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) adapter_id: Optional[str] = None, best_of: Optional[int] = None, decoder_input_details: Optional[bool] = None, do_sample: Optional[bool] = False, # Manual default value frequency_penalty: Optional[float] = None, grammar: Optional[TextGenerationInputGrammarType] = None, max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, stop: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, ) -> TextGenerationOutput: ... 
@overload def text_generation( # type: ignore self, prompt: str, *, details: Literal[False] = ..., stream: Literal[True] = ..., model: Optional[str] = None, # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) adapter_id: Optional[str] = None, best_of: Optional[int] = None, decoder_input_details: Optional[bool] = None, do_sample: Optional[bool] = False, # Manual default value frequency_penalty: Optional[float] = None, grammar: Optional[TextGenerationInputGrammarType] = None, max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, stop: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, ) -> Iterable[str]: ... @overload def text_generation( # type: ignore self, prompt: str, *, details: Literal[True] = ..., stream: Literal[True] = ..., model: Optional[str] = None, # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) adapter_id: Optional[str] = None, best_of: Optional[int] = None, decoder_input_details: Optional[bool] = None, do_sample: Optional[bool] = False, # Manual default value frequency_penalty: Optional[float] = None, grammar: Optional[TextGenerationInputGrammarType] = None, max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, stop: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, ) -> Iterable[TextGenerationStreamOutput]: ... @overload def text_generation( self, prompt: str, *, details: Literal[True] = ..., stream: bool = ..., model: Optional[str] = None, # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) adapter_id: Optional[str] = None, best_of: Optional[int] = None, decoder_input_details: Optional[bool] = None, do_sample: Optional[bool] = False, # Manual default value frequency_penalty: Optional[float] = None, grammar: Optional[TextGenerationInputGrammarType] = None, max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, stop: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, ) -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]: ... 
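    # The overloads above only refine static typing: `details` and `stream` together select
    # the runtime return type, as summarized in this rough sketch (the prompt is illustrative):
    #
    #     client.text_generation("Hi")                              # -> str
    #     client.text_generation("Hi", details=True)                # -> TextGenerationOutput
    #     client.text_generation("Hi", stream=True)                 # -> Iterable[str]
    #     client.text_generation("Hi", details=True, stream=True)   # -> Iterable[TextGenerationStreamOutput]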
def text_generation( self, prompt: str, *, details: bool = False, stream: bool = False, model: Optional[str] = None, # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) adapter_id: Optional[str] = None, best_of: Optional[int] = None, decoder_input_details: Optional[bool] = None, do_sample: Optional[bool] = False, # Manual default value frequency_penalty: Optional[float] = None, grammar: Optional[TextGenerationInputGrammarType] = None, max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, stop: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, ) -> Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]: """ Given a prompt, generate the following text. If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method. It accepts a list of messages instead of a single text prompt and handles the chat templating for you. Args: prompt (`str`): Input text. details (`bool`, *optional*): By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens, probabilities, seed, finish reason, etc.). Only available for models running on with the `text-generation-inference` backend. stream (`bool`, *optional*): By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of tokens to be returned. Only available for models running on with the `text-generation-inference` backend. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. adapter_id (`str`, *optional*): Lora adapter id. best_of (`int`, *optional*): Generate best_of sequences and return the one if the highest token logprobs. decoder_input_details (`bool`, *optional*): Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken into account. Defaults to `False`. do_sample (`bool`, *optional*): Activate logits sampling frequency_penalty (`float`, *optional*): Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. grammar ([`TextGenerationInputGrammarType`], *optional*): Grammar constraints. Can be either a JSONSchema or a regex. max_new_tokens (`int`, *optional*): Maximum number of generated tokens. Defaults to 100. repetition_penalty (`float`, *optional*): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. return_full_text (`bool`, *optional*): Whether to prepend the prompt to the generated text seed (`int`, *optional*): Random sampling seed stop (`List[str]`, *optional*): Stop generating tokens if a member of `stop` is generated. stop_sequences (`List[str]`, *optional*): Deprecated argument. Use `stop` instead. temperature (`float`, *optional*): The value used to module the logits distribution. 
top_n_tokens (`int`, *optional*): Return information about the `top_n_tokens` most likely tokens at each generation step, instead of just the sampled token. top_k (`int`, *optional*): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`, *optional*): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. truncate (`int`, *optional*): Truncate input tokens to the given size. typical_p (`float`, *optional*): Typical Decoding mass. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information. watermark (`bool`, *optional*): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Returns: `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`: Generated text returned from the server: - if `stream=False` and `details=False`, the generated text is returned as a `str` (default) - if `stream=True` and `details=False`, the generated text is returned token by token as an `Iterable[str]` - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`] - if `details=True` and `stream=True`, the generated text is returned token by token as an iterable of [`~huggingface_hub.TextGenerationStreamOutput`] Raises: `ValidationError`: If input values are not valid. No HTTP call is made to the server. [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() # Case 1: generate text >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12) '100% open source and built to be easy to use.' # Case 2: iterate over the generated tokens. Useful for large generation. >>> for token in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True): ... print(token) 100 % open source and built to be easy to use . # Case 3: get more details about the generation process. >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True) TextGenerationOutput( generated_text='100% open source and built to be easy to use.', details=TextGenerationDetails( finish_reason='length', generated_tokens=12, seed=None, prefill=[ TextGenerationPrefillOutputToken(id=487, text='The', logprob=None), TextGenerationPrefillOutputToken(id=53789, text=' hugging', logprob=-13.171875), (...) TextGenerationPrefillOutputToken(id=204, text=' ', logprob=-7.0390625) ], tokens=[ TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), TokenElement(id=16, text='%', logprob=-0.0463562, special=False), (...) TokenElement(id=25, text='.', logprob=-0.5703125, special=False) ], best_of_sequences=None ) ) # Case 4: iterate over the generated tokens with more details. # Last object is more complete, containing the full generated text and the finish reason. >>> for details in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True): ... print(details) ...
TextGenerationStreamOutput(token=TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement( id=25, text='.', logprob=-0.5703125, special=False), generated_text='100% open source and built to be easy to use.', details=TextGenerationStreamOutputStreamDetails(finish_reason='length', generated_tokens=12, seed=None) ) # Case 5: generate constrained output using grammar >>> response = client.text_generation( ... prompt="I saw a puppy a cat and a raccoon during my bike ride in the park", ... model="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", ... max_new_tokens=100, ... repetition_penalty=1.3, ... grammar={ ... "type": "json", ... "value": { ... "properties": { ... "location": {"type": "string"}, ... "activity": {"type": "string"}, ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, ... "animals": {"type": "array", "items": {"type": "string"}}, ... }, ... "required": ["location", "activity", "animals_seen", "animals"], ... }, ... }, ... ) >>> json.loads(response) { "activity": "bike riding", "animals": ["puppy", "cat", "raccoon"], "animals_seen": 3, "location": "park" } ``` """ if decoder_input_details and not details: warnings.warn( "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that" " the output from the server will be truncated." ) decoder_input_details = False if stop_sequences is not None: warnings.warn( "`stop_sequences` is a deprecated argument for `text_generation` task" " and will be removed in version '0.28.0'. 
Use `stop` instead.", FutureWarning, ) if stop is None: stop = stop_sequences # use deprecated arg if provided # Build payload parameters = { "adapter_id": adapter_id, "best_of": best_of, "decoder_input_details": decoder_input_details, "details": details, "do_sample": do_sample, "frequency_penalty": frequency_penalty, "grammar": grammar, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "return_full_text": return_full_text, "seed": seed, "stop": stop if stop is not None else [], "temperature": temperature, "top_k": top_k, "top_n_tokens": top_n_tokens, "top_p": top_p, "truncate": truncate, "typical_p": typical_p, "watermark": watermark, } # Remove some parameters if not a TGI server unsupported_kwargs = _get_unsupported_text_generation_kwargs(model) if len(unsupported_kwargs) > 0: # The server does not support some parameters # => means it is not a TGI server # => remove unsupported parameters and warn the user ignored_parameters = [] for key in unsupported_kwargs: if parameters.get(key): ignored_parameters.append(key) parameters.pop(key, None) if len(ignored_parameters) > 0: warnings.warn( "API endpoint/model for text-generation is not served via TGI. Ignoring following parameters:" f" {', '.join(ignored_parameters)}.", UserWarning, ) if details: warnings.warn( "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will" " be ignored meaning only the generated text will be returned.", UserWarning, ) details = False if stream: raise ValueError( "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream." " Please pass `stream=False` as input." ) model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="text-generation", model=model_id) request_parameters = provider_helper.prepare_request( inputs=prompt, parameters=parameters, extra_payload={"stream": stream}, headers=self.headers, model=model_id, api_key=self.token, ) # Handle errors separately for more precise error messages try: bytes_output = self._inner_post(request_parameters, stream=stream) except HTTPError as e: match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e)) if isinstance(e, BadRequestError) and match: unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")] _set_unsupported_text_generation_kwargs(model, unused_params) return self.text_generation( # type: ignore prompt=prompt, details=details, stream=stream, model=model_id, adapter_id=adapter_id, best_of=best_of, decoder_input_details=decoder_input_details, do_sample=do_sample, frequency_penalty=frequency_penalty, grammar=grammar, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, return_full_text=return_full_text, seed=seed, stop=stop, temperature=temperature, top_k=top_k, top_n_tokens=top_n_tokens, top_p=top_p, truncate=truncate, typical_p=typical_p, watermark=watermark, ) raise_text_generation_error(e) # Parse output if stream: return _stream_text_generation_response(bytes_output, details) # type: ignore data = _bytes_to_dict(bytes_output) # type: ignore[arg-type] # Data can be a single element (dict) or an iterable of dicts where we select the first element of. 
if isinstance(data, list): data = data[0] response = provider_helper.get_response(data, request_parameters) return TextGenerationOutput.parse_obj_as_instance(response) if details else response["generated_text"] def text_to_image( self, prompt: str, *, negative_prompt: Optional[str] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: Optional[int] = None, guidance_scale: Optional[float] = None, model: Optional[str] = None, scheduler: Optional[str] = None, seed: Optional[int] = None, extra_body: Optional[Dict[str, Any]] = None, ) -> "Image": """ Generate an image based on a given text using a specified model. You must have `PIL` installed if you want to work with images (`pip install Pillow`). You can pass provider-specific parameters to the model by using the `extra_body` argument. Args: prompt (`str`): The prompt to generate an image from. negative_prompt (`str`, *optional*): One prompt to guide what NOT to include in image generation. height (`int`, *optional*): The height in pixels of the output image width (`int`, *optional*): The width in pixels of the output image num_inference_steps (`int`, *optional*): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*): A higher guidance scale value encourages the model to generate images closely linked to the text prompt, but values too high may cause saturation and other artifacts. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text-to-image model will be used. Defaults to None. scheduler (`str`, *optional*): Override the scheduler with a compatible one. seed (`int`, *optional*): Seed for the random number generator. extra_body (`Dict[str, Any]`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: `Image`: The generated image. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> image = client.text_to_image("An astronaut riding a horse on the moon.") >>> image.save("astronaut.png") >>> image = client.text_to_image( ... "An astronaut riding a horse on the moon.", ... negative_prompt="low resolution, blurry", ... model="stabilityai/stable-diffusion-2-1", ... ) >>> image.save("better_astronaut.png") ``` Example using a third-party provider directly. Usage will be billed on your fal.ai account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="fal-ai", # Use fal.ai provider ... api_key="fal-ai-api-key", # Pass your fal.ai API key ... ) >>> image = client.text_to_image( ... "A majestic lion in a fantasy forest", ... model="black-forest-labs/FLUX.1-schnell", ... ) >>> image.save("lion.png") ``` Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="replicate", # Use replicate provider ... api_key="hf_...", # Pass your HF token ... ) >>> image = client.text_to_image( ... "An astronaut riding a horse on the moon.", ... 
model="black-forest-labs/FLUX.1-dev", ... ) >>> image.save("astronaut.png") ``` Example using Replicate provider with extra parameters ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="replicate", # Use replicate provider ... api_key="hf_...", # Pass your HF token ... ) >>> image = client.text_to_image( ... "An astronaut riding a horse on the moon.", ... model="black-forest-labs/FLUX.1-schnell", ... extra_body={"output_quality": 100}, ... ) >>> image.save("astronaut.png") ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id) request_parameters = provider_helper.prepare_request( inputs=prompt, parameters={ "negative_prompt": negative_prompt, "height": height, "width": width, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, "scheduler": scheduler, "seed": seed, **(extra_body or {}), }, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) response = provider_helper.get_response(response) return _bytes_to_image(response) def text_to_video( self, prompt: str, *, model: Optional[str] = None, guidance_scale: Optional[float] = None, negative_prompt: Optional[List[str]] = None, num_frames: Optional[float] = None, num_inference_steps: Optional[int] = None, seed: Optional[int] = None, extra_body: Optional[Dict[str, Any]] = None, ) -> bytes: """ Generate a video based on a given text. You can pass provider-specific parameters to the model by using the `extra_body` argument. Args: prompt (`str`): The prompt to generate a video from. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text-to-video model will be used. Defaults to None. guidance_scale (`float`, *optional*): A higher guidance scale value encourages the model to generate videos closely linked to the text prompt, but values too high may cause saturation and other artifacts. negative_prompt (`List[str]`, *optional*): One or several prompt to guide what NOT to include in video generation. num_frames (`float`, *optional*): The num_frames parameter determines how many video frames are generated. num_inference_steps (`int`, *optional*): The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. seed (`int`, *optional*): Seed for the random number generator. extra_body (`Dict[str, Any]`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: `bytes`: The generated video. Example: Example using a third-party provider directly. Usage will be billed on your fal.ai account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="fal-ai", # Using fal.ai provider ... api_key="fal-ai-api-key", # Pass your fal.ai API key ... ) >>> video = client.text_to_video( ... "A majestic lion running in a fantasy forest", ... model="tencent/HunyuanVideo", ... ) >>> with open("lion.mp4", "wb") as file: ... file.write(video) ``` Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="replicate", # Using replicate provider ... api_key="hf_...", # Pass your HF token ... 
) >>> video = client.text_to_video( ... "A cat running in a park", ... model="genmo/mochi-1-preview", ... ) >>> with open("cat.mp4", "wb") as file: ... file.write(video) ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id) request_parameters = provider_helper.prepare_request( inputs=prompt, parameters={ "guidance_scale": guidance_scale, "negative_prompt": negative_prompt, "num_frames": num_frames, "num_inference_steps": num_inference_steps, "seed": seed, **(extra_body or {}), }, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) response = provider_helper.get_response(response, request_parameters) return response def text_to_speech( self, text: str, *, model: Optional[str] = None, do_sample: Optional[bool] = None, early_stopping: Optional[Union[bool, "TextToSpeechEarlyStoppingEnum"]] = None, epsilon_cutoff: Optional[float] = None, eta_cutoff: Optional[float] = None, max_length: Optional[int] = None, max_new_tokens: Optional[int] = None, min_length: Optional[int] = None, min_new_tokens: Optional[int] = None, num_beam_groups: Optional[int] = None, num_beams: Optional[int] = None, penalty_alpha: Optional[float] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, typical_p: Optional[float] = None, use_cache: Optional[bool] = None, extra_body: Optional[Dict[str, Any]] = None, ) -> bytes: """ Synthesize an audio of a voice pronouncing a given text. You can pass provider-specific parameters to the model by using the `extra_body` argument. Args: text (`str`): The text to synthesize. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text-to-speech model will be used. Defaults to None. do_sample (`bool`, *optional*): Whether to use sampling instead of greedy decoding when generating new tokens. early_stopping (`Union[bool, "TextToSpeechEarlyStoppingEnum"]`, *optional*): Controls the stopping condition for beam-based methods. epsilon_cutoff (`float`, *optional*): If set to float strictly between 0 and 1, only tokens with a conditional probability greater than epsilon_cutoff will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details. eta_cutoff (`float`, *optional*): Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between 0 and 1, a token is only considered if it is greater than either eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details. max_length (`int`, *optional*): The maximum length (in tokens) of the generated text, including the input. max_new_tokens (`int`, *optional*): The maximum number of tokens to generate. Takes precedence over max_length. min_length (`int`, *optional*): The minimum length (in tokens) of the generated text, including the input. min_new_tokens (`int`, *optional*): The minimum number of tokens to generate. 
Takes precedence over min_length. num_beam_groups (`int`, *optional*): Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. num_beams (`int`, *optional*): Number of beams to use for beam search. penalty_alpha (`float`, *optional*): The value balances the model confidence and the degeneration penalty in contrastive search decoding. temperature (`float`, *optional*): The value used to modulate the next token probabilities. top_k (`int`, *optional*): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`, *optional*): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. typical_p (`float`, *optional*): Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to typical_p or higher are kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. use_cache (`bool`, *optional*): Whether the model should use the past last key/values attentions to speed up decoding extra_body (`Dict[str, Any]`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: `bytes`: The generated audio. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from pathlib import Path >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> audio = client.text_to_speech("Hello world") >>> Path("hello_world.flac").write_bytes(audio) ``` Example using a third-party provider directly. Usage will be billed on your Replicate account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="replicate", ... api_key="your-replicate-api-key", # Pass your Replicate API key directly ... ) >>> audio = client.text_to_speech( ... text="Hello world", ... model="OuteAI/OuteTTS-0.3-500M", ... ) >>> Path("hello_world.flac").write_bytes(audio) ``` Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="replicate", ... api_key="hf_...", # Pass your HF token ... ) >>> audio =client.text_to_speech( ... text="Hello world", ... model="OuteAI/OuteTTS-0.3-500M", ... ) >>> Path("hello_world.flac").write_bytes(audio) ``` Example using Replicate provider with extra parameters ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="replicate", # Use replicate provider ... api_key="hf_...", # Pass your HF token ... ) >>> audio = client.text_to_speech( ... "Hello, my name is Kororo, an awesome text-to-speech model.", ... model="hexgrad/Kokoro-82M", ... extra_body={"voice": "af_nicole"}, ... ) >>> Path("hello.flac").write_bytes(audio) ``` Example music-gen using "YuE-s1-7B-anneal-en-cot" on fal.ai ```py >>> from huggingface_hub import InferenceClient >>> lyrics = ''' ... [verse] ... In the town where I was born ... 
Lived a man who sailed to sea ... And he told us of his life ... In the land of submarines ... So we sailed on to the sun ... 'Til we found a sea of green ... And we lived beneath the waves ... In our yellow submarine ... [chorus] ... We all live in a yellow submarine ... Yellow submarine, yellow submarine ... We all live in a yellow submarine ... Yellow submarine, yellow submarine ... ''' >>> genres = "pavarotti-style tenor voice" >>> client = InferenceClient( ... provider="fal-ai", ... model="m-a-p/YuE-s1-7B-anneal-en-cot", ... api_key=..., ... ) >>> audio = client.text_to_speech(lyrics, extra_body={"genres": genres}) >>> with open("output.mp3", "wb") as f: ... f.write(audio) ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="text-to-speech", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={ "do_sample": do_sample, "early_stopping": early_stopping, "epsilon_cutoff": epsilon_cutoff, "eta_cutoff": eta_cutoff, "max_length": max_length, "max_new_tokens": max_new_tokens, "min_length": min_length, "min_new_tokens": min_new_tokens, "num_beam_groups": num_beam_groups, "num_beams": num_beams, "penalty_alpha": penalty_alpha, "temperature": temperature, "top_k": top_k, "top_p": top_p, "typical_p": typical_p, "use_cache": use_cache, **(extra_body or {}), }, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) response = provider_helper.get_response(response) return response def token_classification( self, text: str, *, model: Optional[str] = None, aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None, ignore_labels: Optional[List[str]] = None, stride: Optional[int] = None, ) -> List[TokenClassificationOutputElement]: """ Perform token classification on the given text. Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. Args: text (`str`): A string to be classified. model (`str`, *optional*): The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used. Defaults to None. aggregation_strategy (`"TokenClassificationAggregationStrategy"`, *optional*): The strategy used to fuse tokens based on model predictions ignore_labels (`List[str`, *optional*): A list of labels to ignore stride (`int`, *optional*): The number of overlapping tokens between chunks when splitting the input text. Returns: `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. 
Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica") [ TokenClassificationOutputElement( entity_group='PER', score=0.9971321225166321, word='Sarah Jessica Parker', start=11, end=31, ), TokenClassificationOutputElement( entity_group='PER', score=0.9773476123809814, word='Jessica', start=52, end=59, ) ] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="token-classification", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={ "aggregation_strategy": aggregation_strategy, "ignore_labels": ignore_labels, "stride": stride, }, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return TokenClassificationOutputElement.parse_obj_as_list(response) def translation( self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None, clean_up_tokenization_spaces: Optional[bool] = None, truncation: Optional["TranslationTruncationStrategy"] = None, generate_parameters: Optional[Dict[str, Any]] = None, ) -> TranslationOutput: """ Convert text from one language to another. Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for your specific use case. Source and target languages usually depend on the model. However, it is possible to specify source and target languages for certain models. If you are working with one of these models, you can use `src_lang` and `tgt_lang` arguments to pass the relevant information. Args: text (`str`): A string to be translated. model (`str`, *optional*): The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended translation model will be used. Defaults to None. src_lang (`str`, *optional*): The source language of the text. Required for models that can translate from multiple languages. tgt_lang (`str`, *optional*): Target language to translate to. Required for models that can translate to multiple languages. clean_up_tokenization_spaces (`bool`, *optional*): Whether to clean up the potential extra spaces in the text output. truncation (`"TranslationTruncationStrategy"`, *optional*): The truncation strategy to use. generate_parameters (`Dict[str, Any]`, *optional*): Additional parametrization of the text generation algorithm. Returns: [`TranslationOutput`]: The generated translated text. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. `ValueError`: If only one of the `src_lang` and `tgt_lang` arguments are provided. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.translation("My name is Wolfgang and I live in Berlin") 'Mein Name ist Wolfgang und ich lebe in Berlin.' 
>>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr") TranslationOutput(translation_text='Je m'appelle Wolfgang et je vis à Berlin.') ``` Specifying languages: ```py >>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX") "Mon nom est Sarah Jessica Parker mais vous pouvez m'appeler Jessica" ``` """ # Throw error if only one of `src_lang` and `tgt_lang` was given if src_lang is not None and tgt_lang is None: raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.") if src_lang is None and tgt_lang is not None: raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.") model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="translation", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={ "src_lang": src_lang, "tgt_lang": tgt_lang, "clean_up_tokenization_spaces": clean_up_tokenization_spaces, "truncation": truncation, "generate_parameters": generate_parameters, }, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return TranslationOutput.parse_obj_as_list(response)[0] def visual_question_answering( self, image: ContentT, question: str, *, model: Optional[str] = None, top_k: Optional[int] = None, ) -> List[VisualQuestionAnsweringOutputElement]: """ Answering open-ended questions based on an image. Args: image (`Union[str, Path, bytes, BinaryIO]`): The input image for the context. It can be raw bytes, an image file, or a URL to an online image. question (`str`): Question to be answered. model (`str`, *optional*): The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used. Defaults to None. top_k (`int`, *optional*): The number of answers to return (will be chosen by order of likelihood). Note that we return less than topk answers if there are not enough options available within the context. Returns: `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. Raises: `InferenceTimeoutError`: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.visual_question_answering( ... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg", ... question="What is the animal doing?" ... 
) [ VisualQuestionAnsweringOutputElement(score=0.778609573841095, answer='laying down'), VisualQuestionAnsweringOutputElement(score=0.6957435607910156, answer='sitting'), ] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="visual-question-answering", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={"top_k": top_k}, headers=self.headers, model=model_id, api_key=self.token, extra_payload={"question": question, "image": _b64_encode(image)}, ) response = self._inner_post(request_parameters) return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response) def zero_shot_classification( self, text: str, candidate_labels: List[str], *, multi_label: Optional[bool] = False, hypothesis_template: Optional[str] = None, model: Optional[str] = None, ) -> List[ZeroShotClassificationOutputElement]: """ Provide as input a text and a set of candidate labels to classify the input text. Args: text (`str`): The input text to classify. candidate_labels (`List[str]`): The set of possible class labels to classify the text into. labels (`List[str]`, *optional*): (deprecated) List of strings. Each string is the verbalization of a possible label for the input text. multi_label (`bool`, *optional*): Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of the label likelihoods for each sequence is 1. If true, the labels are considered independent and probabilities are normalized for each candidate. hypothesis_template (`str`, *optional*): The sentence used in conjunction with `candidate_labels` to attempt the text classification by replacing the placeholder with the candidate labels. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot classification model will be used. Returns: `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example with `multi_label=False`: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> text = ( ... "A new model offers an explanation for how the Galilean satellites formed around the solar system's" ... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling" ... " mysteries when he went for a run up a hill in Nice, France." ... 
) >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"] >>> client.zero_shot_classification(text, labels) [ ZeroShotClassificationOutputElement(label='scientific discovery', score=0.7961668968200684), ZeroShotClassificationOutputElement(label='space & cosmos', score=0.18570658564567566), ZeroShotClassificationOutputElement(label='microbiology', score=0.00730885099619627), ZeroShotClassificationOutputElement(label='archeology', score=0.006258360575884581), ZeroShotClassificationOutputElement(label='robots', score=0.004559356719255447), ] >>> client.zero_shot_classification(text, labels, multi_label=True) [ ZeroShotClassificationOutputElement(label='scientific discovery', score=0.9829297661781311), ZeroShotClassificationOutputElement(label='space & cosmos', score=0.755190908908844), ZeroShotClassificationOutputElement(label='microbiology', score=0.0005462635890580714), ZeroShotClassificationOutputElement(label='archeology', score=0.00047131875180639327), ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354), ] ``` Example with `multi_label=True` and a custom `hypothesis_template`: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.zero_shot_classification( ... text="I really like our dinner and I'm very happy. I don't like the weather though.", ... labels=["positive", "negative", "pessimistic", "optimistic"], ... multi_label=True, ... hypothesis_template="This text is {} towards the weather" ... ) [ ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467), ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134), ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062), ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363) ] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="zero-shot-classification", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={ "candidate_labels": candidate_labels, "multi_label": multi_label, "hypothesis_template": hypothesis_template, }, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) output = _bytes_to_dict(response) return [ ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score}) for label, score in zip(output["labels"], output["scores"]) ] def zero_shot_image_classification( self, image: ContentT, candidate_labels: List[str], *, model: Optional[str] = None, hypothesis_template: Optional[str] = None, # deprecated argument labels: List[str] = None, # type: ignore ) -> List[ZeroShotImageClassificationOutputElement]: """ Provide input image and text labels to predict text labels for the image. Args: image (`Union[str, Path, bytes, BinaryIO]`): The input image to caption. It can be raw bytes, an image file, or a URL to an online image. candidate_labels (`List[str]`): The candidate labels for this image labels (`List[str]`, *optional*): (deprecated) List of string possible labels. There must be at least 2 labels. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot image classification model will be used. 
hypothesis_template (`str`, *optional*): The sentence used in conjunction with `candidate_labels` to attempt the image classification by replacing the placeholder with the candidate labels. Returns: `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.zero_shot_image_classification( ... "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg", ... labels=["dog", "cat", "horse"], ... ) [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...] ``` """ # Raise ValueError if fewer than 2 labels are provided if len(candidate_labels) < 2: raise ValueError("You must specify at least 2 classes to compare.") model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={ "candidate_labels": candidate_labels, "hypothesis_template": hypothesis_template, }, headers=self.headers, model=model_id, api_key=self.token, ) response = self._inner_post(request_parameters) return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response) @_deprecate_method( version="0.33.0", message=( "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)." " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider." ), ) def list_deployed_models( self, frameworks: Union[None, str, Literal["all"], List[str]] = None ) -> Dict[str, List[str]]: """ List models deployed on the HF Serverless Inference API service. This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that are supported and account for 95% of the hosted models. However, if you want a complete list of models you can specify `frameworks="all"` as input. Alternatively, if you know beforehand which framework you are interested in, you can also restrict the search to this one (e.g. `frameworks="text-generation-inference"`). The more frameworks are checked, the more time it will take. This endpoint method does not return a live list of all models available for the HF Inference API service. It searches over a cached list of models that were recently available and the list may not be up to date. If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`]. This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to check its availability, you can directly use [`~InferenceClient.get_model_status`]. Args: frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*): The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to "all", all available frameworks will be tested. It is also possible to provide a single framework or a custom set of frameworks to check. Returns: `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
Example: ```python >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() # Discover zero-shot-classification models currently deployed >>> models = client.list_deployed_models() >>> models["zero-shot-classification"] ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...] # List from only 1 framework >>> client.list_deployed_models("text-generation-inference") {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...} ``` """ if self.provider != "hf-inference": raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.") # Resolve which frameworks to check if frameworks is None: frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS elif frameworks == "all": frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS elif isinstance(frameworks, str): frameworks = [frameworks] frameworks = list(set(frameworks)) # Fetch them iteratively models_by_task: Dict[str, List[str]] = {} def _unpack_response(framework: str, items: List[Dict]) -> None: for model in items: if framework == "sentence-transformers": # Model running with the `sentence-transformers` framework can work with both tasks even if not # branded as such in the API response models_by_task.setdefault("feature-extraction", []).append(model["model_id"]) models_by_task.setdefault("sentence-similarity", []).append(model["model_id"]) else: models_by_task.setdefault(model["task"], []).append(model["model_id"]) for framework in frameworks: response = get_session().get( f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token) ) hf_raise_for_status(response) _unpack_response(framework, response.json()) # Sort alphabetically for discoverability and return for task, models in models_by_task.items(): models_by_task[task] = sorted(set(models), key=lambda x: x.lower()) return models_by_task def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]: """ Get information about the deployed endpoint. This endpoint is only available on endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI). Endpoints powered by `transformers` return an empty payload. Args: model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. Returns: `Dict[str, Any]`: Information about the endpoint. 
Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") >>> client.get_endpoint_info() { 'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct', 'model_sha': None, 'model_dtype': 'torch.float16', 'model_device_type': 'cuda', 'model_pipeline_tag': None, 'max_concurrent_requests': 128, 'max_best_of': 2, 'max_stop_sequences': 4, 'max_input_length': 8191, 'max_total_tokens': 8192, 'waiting_served_ratio': 0.3, 'max_batch_total_tokens': 1259392, 'max_waiting_tokens': 20, 'max_batch_size': None, 'validation_workers': 32, 'max_client_batch_size': 4, 'version': '2.0.2', 'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214', 'docker_label': 'sha-dccab72' } ``` """ if self.provider != "hf-inference": raise ValueError(f"Getting endpoint info is not supported on '{self.provider}'.") model = model or self.model if model is None: raise ValueError("Model id not provided.") if model.startswith(("http://", "https://")): url = model.rstrip("/") + "/info" else: url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info" response = get_session().get(url, headers=build_hf_headers(token=self.token)) hf_raise_for_status(response) return response.json() def health_check(self, model: Optional[str] = None) -> bool: """ Check the health of the deployed endpoint. Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI). For Inference API, please use [`InferenceClient.get_model_status`] instead. Args: model (`str`, *optional*): URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. Returns: `bool`: True if everything is working fine. Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud") >>> client.health_check() True ``` """ if self.provider != "hf-inference": raise ValueError(f"Health check is not supported on '{self.provider}'.") model = model or self.model if model is None: raise ValueError("Model id not provided.") if not model.startswith(("http://", "https://")): raise ValueError( "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`." ) url = model.rstrip("/") + "/health" response = get_session().get(url, headers=build_hf_headers(token=self.token)) return response.status_code == 200 @_deprecate_method( version="0.33.0", message=( "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)." " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers." ), ) def get_model_status(self, model: Optional[str] = None) -> ModelStatus: """ Get the status of a model hosted on the HF Inference API. This endpoint is mostly useful when you already know which model you want to use and want to check its availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`]. Args: model (`str`, *optional*): Identifier of the model for which the status will be checked. If model is not provided, the model associated with this instance of [`InferenceClient`] will be used. Only HF Inference API service can be checked so the identifier cannot be a URL. Returns: [`ModelStatus`]: An instance of ModelStatus dataclass, containing information about the state of the model: load, state, compute type and framework.
Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct") ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference') ``` """ if self.provider != "hf-inference": raise ValueError(f"Getting model status is not supported on '{self.provider}'.") model = model or self.model if model is None: raise ValueError("Model id not provided.") if model.startswith("https://"): raise NotImplementedError("Model status is only available for Inference API endpoints.") url = f"{constants.INFERENCE_ENDPOINT}/status/{model}" response = get_session().get(url, headers=build_hf_headers(token=self.token)) hf_raise_for_status(response) response_data = response.json() if "error" in response_data: raise ValueError(response_data["error"]) return ModelStatus( loaded=response_data["loaded"], state=response_data["state"], compute_type=response_data["compute_type"], framework=response_data["framework"], ) @property def chat(self) -> "ProxyClientChat": return ProxyClientChat(self) class _ProxyClient: """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" def __init__(self, client: InferenceClient): self._client = client class ProxyClientChat(_ProxyClient): """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" @property def completions(self) -> "ProxyClientChatCompletions": return ProxyClientChatCompletions(self._client) class ProxyClientChatCompletions(_ProxyClient): """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" @property def create(self): return self._client.chat_completion huggingface_hub-0.31.1/src/huggingface_hub/inference/_common.py000066400000000000000000000343771500667546600245710ustar00rootroot00000000000000# coding=utf-8 # Copyright 2023-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Contains utilities used by both the sync and async inference clients.""" import base64 import io import json import logging from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path from typing import ( TYPE_CHECKING, Any, AsyncIterable, BinaryIO, ContextManager, Dict, Generator, Iterable, List, Literal, NoReturn, Optional, Union, overload, ) from requests import HTTPError from huggingface_hub.errors import ( GenerationError, IncompleteGenerationError, OverloadedError, TextGenerationError, UnknownError, ValidationError, ) from ..utils import get_session, is_aiohttp_available, is_numpy_available, is_pillow_available from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput if TYPE_CHECKING: from aiohttp import ClientResponse, ClientSession from PIL.Image import Image # TYPES UrlT = str PathT = Union[str, Path] BinaryT = Union[bytes, BinaryIO] ContentT = Union[BinaryT, PathT, UrlT] # Use to set a Accept: image/png header TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"} logger = logging.getLogger(__name__) @dataclass class RequestParameters: url: str task: str model: Optional[str] json: Optional[Union[str, Dict, List]] data: Optional[ContentT] headers: Dict[str, Any] # Add dataclass for ModelStatus. We use this dataclass in get_model_status function. @dataclass class ModelStatus: """ This Dataclass represents the model status in the HF Inference API. Args: loaded (`bool`): If the model is currently loaded into HF's Inference API. Models are loaded on-demand, leading to the user's first request taking longer. If a model is loaded, you can be assured that it is in a healthy state. state (`str`): The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'. If a model's state is 'Loadable', it's not too big and has a supported backend. Loadable models are automatically loaded when the user first requests inference on the endpoint. This means it is transparent for the user to load a model, except that the first call takes longer to complete. compute_type (`Dict`): Information about the compute resource the model is using or will use, such as 'gpu' type and number of replicas. framework (`str`): The name of the framework that the model was built with, such as 'transformers' or 'text-generation-inference'. """ loaded: bool state: str compute_type: Dict framework: str ## IMPORT UTILS def _import_aiohttp(): # Make sure `aiohttp` is installed on the machine. if not is_aiohttp_available(): raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).") import aiohttp return aiohttp def _import_numpy(): """Make sure `numpy` is installed on the machine.""" if not is_numpy_available(): raise ImportError("Please install numpy to use deal with embeddings (`pip install numpy`).") import numpy return numpy def _import_pil_image(): """Make sure `PIL` is installed on the machine.""" if not is_pillow_available(): raise ImportError( "Please install Pillow to use deal with images (`pip install Pillow`). If you don't want the image to be" " post-processed, use `client.post(...)` and get the raw response from the server." ) from PIL import Image return Image ## ENCODING / DECODING UTILS @overload def _open_as_binary( content: ContentT, ) -> ContextManager[BinaryT]: ... # means "if input is not None, output is not None" @overload def _open_as_binary( content: Literal[None], ) -> ContextManager[Literal[None]]: ... 
# means "if input is None, output is None" @contextmanager # type: ignore def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]: """Open `content` as a binary file, either from a URL, a local path, or raw bytes. Do nothing if `content` is None, TODO: handle a PIL.Image as input TODO: handle base64 as input """ # If content is a string => must be either a URL or a path if isinstance(content, str): if content.startswith("https://") or content.startswith("http://"): logger.debug(f"Downloading content from {content}") yield get_session().get(content).content # TODO: retrieve as stream and pipe to post request ? return content = Path(content) if not content.exists(): raise FileNotFoundError( f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local" " file. To pass raw content, please encode it as bytes first." ) # If content is a Path => open it if isinstance(content, Path): logger.debug(f"Opening content from {content}") with content.open("rb") as f: yield f else: # Otherwise: already a file-like object or None yield content def _b64_encode(content: ContentT) -> str: """Encode a raw file (image, audio) into base64. Can be bytes, an opened file, a path or a URL.""" with _open_as_binary(content) as data: data_as_bytes = data if isinstance(data, bytes) else data.read() return base64.b64encode(data_as_bytes).decode() def _b64_to_image(encoded_image: str) -> "Image": """Parse a base64-encoded string into a PIL Image.""" Image = _import_pil_image() return Image.open(io.BytesIO(base64.b64decode(encoded_image))) def _bytes_to_list(content: bytes) -> List: """Parse bytes from a Response object into a Python list. Expects the response body to be JSON-encoded data. NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect. """ return json.loads(content.decode()) def _bytes_to_dict(content: bytes) -> Dict: """Parse bytes from a Response object into a Python dictionary. Expects the response body to be JSON-encoded data. NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect. """ return json.loads(content.decode()) def _bytes_to_image(content: bytes) -> "Image": """Parse bytes from a Response object into a PIL Image. Expects the response body to be raw bytes. To deal with b64 encoded images, use `_b64_to_image` instead. 
""" Image = _import_pil_image() return Image.open(io.BytesIO(content)) def _as_dict(response: Union[bytes, Dict]) -> Dict: return json.loads(response) if isinstance(response, bytes) else response ## PAYLOAD UTILS ## STREAMING UTILS def _stream_text_generation_response( bytes_output_as_lines: Iterable[bytes], details: bool ) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]: """Used in `InferenceClient.text_generation`.""" # Parse ServerSentEvents for byte_payload in bytes_output_as_lines: try: output = _format_text_generation_stream_output(byte_payload, details) except StopIteration: break if output is not None: yield output async def _async_stream_text_generation_response( bytes_output_as_lines: AsyncIterable[bytes], details: bool ) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]: """Used in `AsyncInferenceClient.text_generation`.""" # Parse ServerSentEvents async for byte_payload in bytes_output_as_lines: try: output = _format_text_generation_stream_output(byte_payload, details) except StopIteration: break if output is not None: yield output def _format_text_generation_stream_output( byte_payload: bytes, details: bool ) -> Optional[Union[str, TextGenerationStreamOutput]]: if not byte_payload.startswith(b"data:"): return None # empty line if byte_payload.strip() == b"data: [DONE]": raise StopIteration("[DONE] signal received.") # Decode payload payload = byte_payload.decode("utf-8") json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) # Either an error as being returned if json_payload.get("error") is not None: raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type")) # Or parse token payload output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload) return output.token.text if not details else output def _stream_chat_completion_response( bytes_lines: Iterable[bytes], ) -> Iterable[ChatCompletionStreamOutput]: """Used in `InferenceClient.chat_completion` if model is served with TGI.""" for item in bytes_lines: try: output = _format_chat_completion_stream_output(item) except StopIteration: break if output is not None: yield output async def _async_stream_chat_completion_response( bytes_lines: AsyncIterable[bytes], ) -> AsyncIterable[ChatCompletionStreamOutput]: """Used in `AsyncInferenceClient.chat_completion`.""" async for item in bytes_lines: try: output = _format_chat_completion_stream_output(item) except StopIteration: break if output is not None: yield output def _format_chat_completion_stream_output( byte_payload: bytes, ) -> Optional[ChatCompletionStreamOutput]: if not byte_payload.startswith(b"data:"): return None # empty line if byte_payload.strip() == b"data: [DONE]": raise StopIteration("[DONE] signal received.") # Decode payload payload = byte_payload.decode("utf-8") json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) # Either an error as being returned if json_payload.get("error") is not None: raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type")) # Or parse token payload return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload) async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]: async for byte_payload in response.content: yield byte_payload.strip() await client.close() # "TGI servers" are servers running with the `text-generation-inference` backend. # This backend is the go-to solution to run large language models at scale. However, # for some smaller models (e.g. 
"gpt2") the default `transformers` + `api-inference` # solution is still in use. # # Both approaches have very similar APIs, but not exactly the same. What we do first in # the `text_generation` method is to assume the model is served via TGI. If we realize # it's not the case (i.e. we receive an HTTP 400 Bad Request), we fallback to the # default API with a warning message. When that's the case, We remember the unsupported # attributes for this model in the `_UNSUPPORTED_TEXT_GENERATION_KWARGS` global variable. # # In addition, TGI servers have a built-in API route for chat-completion, which is not # available on the default API. We use this route to provide a more consistent behavior # when available. # # For more details, see https://github.com/huggingface/text-generation-inference and # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task. _UNSUPPORTED_TEXT_GENERATION_KWARGS: Dict[Optional[str], List[str]] = {} def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: List[str]) -> None: _UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, []).extend(unsupported_kwargs) def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]: return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, []) # TEXT GENERATION ERRORS # ---------------------- # Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation # inference project (https://github.com/huggingface/text-generation-inference). # ---------------------- def raise_text_generation_error(http_error: HTTPError) -> NoReturn: """ Try to parse text-generation-inference error message and raise HTTPError in any case. Args: error (`HTTPError`): The HTTPError that have been raised. """ # Try to parse a Text Generation Inference error try: # Hacky way to retrieve payload in case of aiohttp error payload = getattr(http_error, "response_error_payload", None) or http_error.response.json() error = payload.get("error") error_type = payload.get("error_type") except Exception: # no payload raise http_error # If error_type => more information than `hf_raise_for_status` if error_type is not None: exception = _parse_text_generation_error(error, error_type) raise exception from http_error # Otherwise, fallback to default error raise http_error def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError: if error_type == "generation": return GenerationError(error) # type: ignore if error_type == "incomplete_generation": return IncompleteGenerationError(error) # type: ignore if error_type == "overloaded": return OverloadedError(error) # type: ignore if error_type == "validation": return ValidationError(error) # type: ignore return UnknownError(error) # type: ignore huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/000077500000000000000000000000001500667546600246475ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/__init__.py000066400000000000000000000000001500667546600267460ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/_async_client.py000066400000000000000000005071721500667546600300470ustar00rootroot00000000000000# coding=utf-8 # Copyright 2023-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # WARNING # This entire file has been adapted from the sync-client code in `src/huggingface_hub/inference/_client.py`. # Any change in InferenceClient will be automatically reflected in AsyncInferenceClient. # To re-generate the code, run `make style` or `python ./utils/generate_async_inference_client.py --update`. # WARNING import asyncio import base64 import logging import re import warnings from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload from huggingface_hub import constants from huggingface_hub.errors import InferenceTimeoutError from huggingface_hub.inference._common import ( TASKS_EXPECTING_IMAGES, ContentT, ModelStatus, RequestParameters, _async_stream_chat_completion_response, _async_stream_text_generation_response, _b64_encode, _b64_to_image, _bytes_to_dict, _bytes_to_image, _bytes_to_list, _get_unsupported_text_generation_kwargs, _import_numpy, _open_as_binary, _set_unsupported_text_generation_kwargs, raise_text_generation_error, ) from huggingface_hub.inference._generated.types import ( AudioClassificationOutputElement, AudioClassificationOutputTransform, AudioToAudioOutputElement, AutomaticSpeechRecognitionOutput, ChatCompletionInputGrammarType, ChatCompletionInputStreamOptions, ChatCompletionInputTool, ChatCompletionInputToolChoiceClass, ChatCompletionInputToolChoiceEnum, ChatCompletionOutput, ChatCompletionStreamOutput, DocumentQuestionAnsweringOutputElement, FillMaskOutputElement, ImageClassificationOutputElement, ImageClassificationOutputTransform, ImageSegmentationOutputElement, ImageSegmentationSubtask, ImageToImageTargetSize, ImageToTextOutput, ObjectDetectionOutputElement, Padding, QuestionAnsweringOutputElement, SummarizationOutput, SummarizationTruncationStrategy, TableQuestionAnsweringOutputElement, TextClassificationOutputElement, TextClassificationOutputTransform, TextGenerationInputGrammarType, TextGenerationOutput, TextGenerationStreamOutput, TextToSpeechEarlyStoppingEnum, TokenClassificationAggregationStrategy, TokenClassificationOutputElement, TranslationOutput, TranslationTruncationStrategy, VisualQuestionAnsweringOutputElement, ZeroShotClassificationOutputElement, ZeroShotImageClassificationOutputElement, ) from huggingface_hub.inference._providers import PROVIDER_T, get_provider_helper from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status from huggingface_hub.utils._auth import get_token from huggingface_hub.utils._deprecation import _deprecate_method from .._common import _async_yield_from, _import_aiohttp if TYPE_CHECKING: import numpy as np from aiohttp import ClientResponse, ClientSession from PIL.Image import Image logger = logging.getLogger(__name__) MODEL_KWARGS_NOT_USED_REGEX = re.compile(r"The following `model_kwargs` are not used by the model: \[(.*?)\]") class AsyncInferenceClient: """ Initialize a new Inference Client. [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used seamlessly with either the (free) Inference API, self-hosted Inference Endpoints, or third-party Inference Providers. 
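    Example (a minimal sketch; the model id below is only illustrative and any supported model id or Inference Endpoint URL can be used):
    ```py
    # Must be run in an async context
    >>> from huggingface_hub import AsyncInferenceClient
    >>> client = AsyncInferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")
    >>> await client.chat_completion([{"role": "user", "content": "Hello!"}], max_tokens=32)
    ```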
Args: model (`str`, `optional`): The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct` or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is automatically selected for the task. Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2 arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) documentation for details). When passing a URL as `model`, the client will not append any suffix path to it. provider (`str`, *optional*): Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, `"sambanova"` or `"together"`. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. If model is a URL or `base_url` is passed, then `provider` is not used. token (`str`, *optional*): Hugging Face token. Will default to the locally saved token if not provided. Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2 arguments are mutually exclusive and have the exact same behavior. timeout (`float`, `optional`): The maximum number of seconds to wait for a response from the server. Defaults to None, meaning it will loop until the server is available. headers (`Dict[str, str]`, `optional`): Additional headers to send to the server. By default only the authorization and user-agent headers are sent. Values in this dictionary will override the default values. bill_to (`str`, `optional`): The billing account to use for the requests. By default the requests are billed on the user's account. Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub. cookies (`Dict[str, str]`, `optional`): Additional cookies to send to the server. trust_env ('bool', 'optional'): Trust environment settings for proxy configuration if the parameter is `True` (`False` by default). proxies (`Any`, `optional`): Proxies to use for the request. base_url (`str`, `optional`): Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`] follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None. api_key (`str`, `optional`): Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClient`] follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None. """ def __init__( self, model: Optional[str] = None, *, provider: Union[Literal["auto"], PROVIDER_T, None] = None, token: Optional[str] = None, timeout: Optional[float] = None, headers: Optional[Dict[str, str]] = None, cookies: Optional[Dict[str, str]] = None, trust_env: bool = False, proxies: Optional[Any] = None, bill_to: Optional[str] = None, # OpenAI compatibility base_url: Optional[str] = None, api_key: Optional[str] = None, ) -> None: if model is not None and base_url is not None: raise ValueError( "Received both `model` and `base_url` arguments. Please provide only one of them."
" `base_url` is an alias for `model` to make the API compatible with OpenAI's client." " If using `base_url` for chat completion, the `/chat/completions` suffix path will be appended to the base url." " When passing a URL as `model`, the client will not append any suffix path to it." ) if token is not None and api_key is not None: raise ValueError( "Received both `token` and `api_key` arguments. Please provide only one of them." " `api_key` is an alias for `token` to make the API compatible with OpenAI's client." " It has the exact same behavior as `token`." ) token = token if token is not None else api_key if isinstance(token, bool): # Legacy behavior: previously it was possible to pass `token=False` to disable authentication. This is not # supported anymore as authentication is required. Better to explicitly raise here rather than risking # sending the locally saved token without the user knowing about it. if token is False: raise ValueError( "Cannot use `token=False` to disable authentication as authentication is required to run Inference." ) warnings.warn( "Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. " "Please use `token=None` instead (default).", DeprecationWarning, ) token = get_token() self.model: Optional[str] = base_url or model self.token: Optional[str] = token self.headers = {**headers} if headers is not None else {} if bill_to is not None: if ( constants.HUGGINGFACE_HEADER_X_BILL_TO in self.headers and self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] != bill_to ): warnings.warn( f"Overriding existing '{self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO]}' value in headers with '{bill_to}'.", UserWarning, ) self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] = bill_to if token is not None and not token.startswith("hf_"): warnings.warn( "You've provided an external provider's API key, so requests will be billed directly by the provider. " "The `bill_to` parameter is only applicable for Hugging Face billing and will be ignored.", UserWarning, ) # Configure provider self.provider = provider self.cookies = cookies self.timeout = timeout self.trust_env = trust_env self.proxies = proxies # Keep track of the sessions to close them properly self._sessions: Dict["ClientSession", Set["ClientResponse"]] = dict() def __repr__(self): return f"<InferenceClient(model='{self.model}', timeout={self.timeout})>" @overload async def _inner_post( # type: ignore[misc] self, request_parameters: RequestParameters, *, stream: Literal[False] = ... ) -> bytes: ... @overload async def _inner_post( # type: ignore[misc] self, request_parameters: RequestParameters, *, stream: Literal[True] = ... ) -> AsyncIterable[bytes]: ... @overload async def _inner_post( self, request_parameters: RequestParameters, *, stream: bool = False ) -> Union[bytes, AsyncIterable[bytes]]: ...
async def _inner_post( self, request_parameters: RequestParameters, *, stream: bool = False ) -> Union[bytes, AsyncIterable[bytes]]: """Make a request to the inference server.""" aiohttp = _import_aiohttp() # TODO: this should be handled in provider helpers directly if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers: request_parameters.headers["Accept"] = "image/png" with _open_as_binary(request_parameters.data) as data_as_binary: # Do not use context manager as we don't want to close the connection immediately when returning # a stream session = self._get_client_session(headers=request_parameters.headers) try: response = await session.post( request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies ) response_error_payload = None if response.status != 200: try: response_error_payload = await response.json() # get payload before connection closed except Exception: pass response.raise_for_status() if stream: return _async_yield_from(session, response) else: content = await response.read() await session.close() return content except asyncio.TimeoutError as error: await session.close() # Convert any `TimeoutError` to a `InferenceTimeoutError` raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore except aiohttp.ClientResponseError as error: error.response_error_payload = response_error_payload await session.close() raise error except Exception: await session.close() raise async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_value, traceback): await self.close() def __del__(self): if len(self._sessions) > 0: warnings.warn( "Deleting 'AsyncInferenceClient' client but some sessions are still open. " "This can happen if you've stopped streaming data from the server before the stream was complete. " "To close the client properly, you must call `await client.close()` " "or use an async context (e.g. `async with AsyncInferenceClient(): ...`." ) async def close(self): """Close all open sessions. By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you are streaming data from the server and you stop before the stream is complete, you must call this method to close the session properly. Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`). """ await asyncio.gather(*[session.close() for session in self._sessions.keys()]) async def audio_classification( self, audio: ContentT, *, model: Optional[str] = None, top_k: Optional[int] = None, function_to_apply: Optional["AudioClassificationOutputTransform"] = None, ) -> List[AudioClassificationOutputElement]: """ Perform audio classification on the provided audio content. Args: audio (Union[str, Path, bytes, BinaryIO]): The audio content to classify. It can be raw audio bytes, a local audio file, or a URL pointing to an audio file. model (`str`, *optional*): The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for audio classification will be used. top_k (`int`, *optional*): When specified, limits the output to the top K most probable classes. function_to_apply (`"AudioClassificationOutputTransform"`, *optional*): The function to apply to the model outputs in order to retrieve the scores. 
Returns: `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.audio_classification("audio.flac") [ AudioClassificationOutputElement(score=0.4976358711719513, label='hap'), AudioClassificationOutputElement(score=0.3677836060523987, label='neu'), ... ] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="audio-classification", model=model_id) request_parameters = provider_helper.prepare_request( inputs=audio, parameters={"function_to_apply": function_to_apply, "top_k": top_k}, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return AudioClassificationOutputElement.parse_obj_as_list(response) async def audio_to_audio( self, audio: ContentT, *, model: Optional[str] = None, ) -> List[AudioToAudioOutputElement]: """ Performs multiple tasks related to audio-to-audio depending on the model (e.g. speech enhancement, source separation). Args: audio (Union[str, Path, bytes, BinaryIO]): The audio content for the model. It can be raw audio bytes, a local audio file, or a URL pointing to an audio file. model (`str`, *optional*): The model can be any model which takes an audio file and returns another audio file. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for audio_to_audio will be used. Returns: `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing the audio label, content-type, and audio content in blob. Raises: `InferenceTimeoutError`: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> audio_output = await client.audio_to_audio("audio.flac") >>> for i, item in enumerate(audio_output): >>> with open(f"output_{i}.flac", "wb") as f: f.write(item.blob) ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="audio-to-audio", model=model_id) request_parameters = provider_helper.prepare_request( inputs=audio, parameters={}, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) audio_output = AudioToAudioOutputElement.parse_obj_as_list(response) for item in audio_output: item.blob = base64.b64decode(item.blob) return audio_output async def automatic_speech_recognition( self, audio: ContentT, *, model: Optional[str] = None, extra_body: Optional[Dict] = None, ) -> AutomaticSpeechRecognitionOutput: """ Perform automatic speech recognition (ASR or audio-to-text) on the given audio content. Args: audio (Union[str, Path, bytes, BinaryIO]): The content to transcribe. It can be raw audio bytes, local audio file, or a URL to an audio file. model (`str`, *optional*): The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint.
If not provided, the default recommended model for ASR will be used. extra_body (`Dict`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> (await client.automatic_speech_recognition("hello_world.flac")).text "hello world" ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition", model=model_id) request_parameters = provider_helper.prepare_request( inputs=audio, parameters={**(extra_body or {})}, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response) @overload async def chat_completion( # type: ignore self, messages: List[Dict], *, model: Optional[str] = None, stream: Literal[False] = False, frequency_penalty: Optional[float] = None, logit_bias: Optional[List[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, stop: Optional[List[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, tools: Optional[List[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, ) -> ChatCompletionOutput: ... @overload async def chat_completion( # type: ignore self, messages: List[Dict], *, model: Optional[str] = None, stream: Literal[True] = True, frequency_penalty: Optional[float] = None, logit_bias: Optional[List[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, stop: Optional[List[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, tools: Optional[List[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, ) -> AsyncIterable[ChatCompletionStreamOutput]: ...
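    # The `chat_completion` overloads (two above, one below) only differ in how `stream`
    # is typed: `Literal[False]` maps to `ChatCompletionOutput`, `Literal[True]` to
    # `AsyncIterable[ChatCompletionStreamOutput]`, and plain `bool` covers the case where
    # the flag is only known at runtime. Illustrative usage, assuming `client` is an
    # `AsyncInferenceClient` and `messages` is a list of {"role": ..., "content": ...}
    # dicts as in the docstring examples below:
    #
    #     output = await client.chat_completion(messages)               # ChatCompletionOutput
    #     stream = await client.chat_completion(messages, stream=True)  # async iterable
    #     async for chunk in stream:
    #         print(chunk.choices[0].delta.content)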
@overload async def chat_completion( self, messages: List[Dict], *, model: Optional[str] = None, stream: bool = False, frequency_penalty: Optional[float] = None, logit_bias: Optional[List[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, stop: Optional[List[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, tools: Optional[List[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: ... async def chat_completion( self, messages: List[Dict], *, model: Optional[str] = None, stream: bool = False, # Parameters from ChatCompletionInput (handled manually) frequency_penalty: Optional[float] = None, logit_bias: Optional[List[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, stop: Optional[List[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, tools: Optional[List[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: """ A method for completing conversations using a specified language model. The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client. Inputs and outputs are strictly the same and using either syntax will yield the same results. Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility) for more details about OpenAI's compatibility. You can pass provider-specific parameters to the model by using the `extra_body` argument. Args: messages (List of [`ChatCompletionInputMessage`]): Conversation history consisting of roles and content pairs. model (`str`, *optional*): The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used. See https://huggingface.co/tasks/text-generation for more details. If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`]. frequency_penalty (`float`, *optional*): Penalizes new tokens based on their existing frequency in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0. logit_bias (`List[float]`, *optional*): Adjusts the likelihood of specific tokens appearing in the generated output. logprobs (`bool`, *optional*): Whether to return log probabilities of the output tokens or not. 
If true, returns the log probabilities of each output token returned in the content of message. max_tokens (`int`, *optional*): Maximum number of tokens allowed in the response. Defaults to 100. n (`int`, *optional*): The number of completions to generate for each prompt. presence_penalty (`float`, *optional*): Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. response_format ([`ChatCompletionInputGrammarType`], *optional*): Grammar constraints. Can be either a JSONSchema or a regex. seed (Optional[`int`], *optional*): Seed for reproducible control flow. Defaults to None. stop (`List[str]`, *optional*): Up to four strings which trigger the end of the response. Defaults to None. stream (`bool`, *optional*): Enable realtime streaming of responses. Defaults to False. stream_options ([`ChatCompletionInputStreamOptions`], *optional*): Options for streaming completions. temperature (`float`, *optional*): Controls randomness of the generations. Lower values ensure less random completions. Range: [0, 2]. Defaults to 1.0. top_logprobs (`int`, *optional*): An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. top_p (`float`, *optional*): Fraction of the most likely next words to sample from. Must be between 0 and 1. Defaults to 1.0. tool_choice ([`ChatCompletionInputToolChoiceClass`] or [`ChatCompletionInputToolChoiceEnum`], *optional*): The tool to use for the completion. Defaults to "auto". tool_prompt (`str`, *optional*): A prompt to be appended before the tools. tools (List of [`ChatCompletionInputTool`], *optional*): A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. extra_body (`Dict`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: [`ChatCompletionOutput`] or Iterable of [`ChatCompletionStreamOutput`]: Generated text returned from the server: - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default). - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`]. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. 
Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> messages = [{"role": "user", "content": "What is the capital of France?"}] >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") >>> await client.chat_completion(messages, max_tokens=100) ChatCompletionOutput( choices=[ ChatCompletionOutputComplete( finish_reason='eos_token', index=0, message=ChatCompletionOutputMessage( role='assistant', content='The capital of France is Paris.', name=None, tool_calls=None ), logprobs=None ) ], created=1719907176, id='', model='meta-llama/Meta-Llama-3-8B-Instruct', object='text_completion', system_fingerprint='2.0.4-sha-f426a33', usage=ChatCompletionOutputUsage( completion_tokens=8, prompt_tokens=17, total_tokens=25 ) ) ``` Example using streaming: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> messages = [{"role": "user", "content": "What is the capital of France?"}] >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") >>> async for token in await client.chat_completion(messages, max_tokens=10, stream=True): ... print(token) ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504) ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504) (...) ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504) ``` Example using OpenAI's syntax: ```py # Must be run in an async context # instead of `from openai import OpenAI` from huggingface_hub import AsyncInferenceClient # instead of `client = OpenAI(...)` client = AsyncInferenceClient( base_url=..., api_key=..., ) output = await client.chat.completions.create( model="meta-llama/Meta-Llama-3-8B-Instruct", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Count to 10"}, ], stream=True, max_tokens=1024, ) for chunk in output: print(chunk.choices[0].delta.content) ``` Example using a third-party provider directly with extra (provider-specific) parameters. Usage will be billed on your Together AI account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="together", # Use Together AI provider ... api_key="", # Pass your Together API key directly ... ) >>> client.chat_completion( ... model="meta-llama/Meta-Llama-3-8B-Instruct", ... messages=[{"role": "user", "content": "What is the capital of France?"}], ... extra_body={"safety_model": "Meta-Llama/Llama-Guard-7b"}, ... ) ``` Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="sambanova", # Use Sambanova provider ... api_key="hf_...", # Pass your HF token ... ) >>> client.chat_completion( ... model="meta-llama/Meta-Llama-3-8B-Instruct", ... messages=[{"role": "user", "content": "What is the capital of France?"}], ... 
) ``` Example using Image + Text as input: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient # provide a remote URL >>> image_url ="https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" # or a base64-encoded image >>> image_path = "/path/to/image.jpeg" >>> with open(image_path, "rb") as f: ... base64_image = base64.b64encode(f.read()).decode("utf-8") >>> image_url = f"data:image/jpeg;base64,{base64_image}" >>> client = AsyncInferenceClient("meta-llama/Llama-3.2-11B-Vision-Instruct") >>> output = await client.chat.completions.create( ... messages=[ ... { ... "role": "user", ... "content": [ ... { ... "type": "image_url", ... "image_url": {"url": image_url}, ... }, ... { ... "type": "text", ... "text": "Describe this image in one sentence.", ... }, ... ], ... }, ... ], ... ) >>> output The image depicts the iconic Statue of Liberty situated in New York Harbor, New York, on a clear day. ``` Example using tools: ```py # Must be run in an async context >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") >>> messages = [ ... { ... "role": "system", ... "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.", ... }, ... { ... "role": "user", ... "content": "What's the weather like the next 3 days in San Francisco, CA?", ... }, ... ] >>> tools = [ ... { ... "type": "function", ... "function": { ... "name": "get_current_weather", ... "description": "Get the current weather", ... "parameters": { ... "type": "object", ... "properties": { ... "location": { ... "type": "string", ... "description": "The city and state, e.g. San Francisco, CA", ... }, ... "format": { ... "type": "string", ... "enum": ["celsius", "fahrenheit"], ... "description": "The temperature unit to use. Infer this from the users location.", ... }, ... }, ... "required": ["location", "format"], ... }, ... }, ... }, ... { ... "type": "function", ... "function": { ... "name": "get_n_day_weather_forecast", ... "description": "Get an N-day weather forecast", ... "parameters": { ... "type": "object", ... "properties": { ... "location": { ... "type": "string", ... "description": "The city and state, e.g. San Francisco, CA", ... }, ... "format": { ... "type": "string", ... "enum": ["celsius", "fahrenheit"], ... "description": "The temperature unit to use. Infer this from the users location.", ... }, ... "num_days": { ... "type": "integer", ... "description": "The number of days to forecast", ... }, ... }, ... "required": ["location", "format", "num_days"], ... }, ... }, ... }, ... ] >>> response = await client.chat_completion( ... model="meta-llama/Meta-Llama-3-70B-Instruct", ... messages=messages, ... tools=tools, ... tool_choice="auto", ... max_tokens=500, ... ) >>> response.choices[0].message.tool_calls[0].function ChatCompletionOutputFunctionDefinition( arguments={ 'location': 'San Francisco, CA', 'format': 'fahrenheit', 'num_days': 3 }, name='get_n_day_weather_forecast', description=None ) ``` Example using response_format: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") >>> messages = [ ... { ... "role": "user", ... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?", ... }, ... ] >>> response_format = { ... "type": "json", ... "value": { ... "properties": { ... 
"location": {"type": "string"}, ... "activity": {"type": "string"}, ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, ... "animals": {"type": "array", "items": {"type": "string"}}, ... }, ... "required": ["location", "activity", "animals_seen", "animals"], ... }, ... } >>> response = await client.chat_completion( ... messages=messages, ... response_format=response_format, ... max_tokens=500, ... ) >>> response.choices[0].message.content '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}' ``` """ # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently. # `self.model` takes precedence over 'model' argument for building URL. # `model` takes precedence for payload value. model_id_or_url = self.model or model payload_model = model or self.model # Get the provider helper provider_helper = get_provider_helper( self.provider, task="conversational", model=model_id_or_url if model_id_or_url is not None and model_id_or_url.startswith(("http://", "https://")) else payload_model, ) # Prepare the payload parameters = { "model": payload_model, "frequency_penalty": frequency_penalty, "logit_bias": logit_bias, "logprobs": logprobs, "max_tokens": max_tokens, "n": n, "presence_penalty": presence_penalty, "response_format": response_format, "seed": seed, "stop": stop, "temperature": temperature, "tool_choice": tool_choice, "tool_prompt": tool_prompt, "tools": tools, "top_logprobs": top_logprobs, "top_p": top_p, "stream": stream, "stream_options": stream_options, **(extra_body or {}), } request_parameters = provider_helper.prepare_request( inputs=messages, parameters=parameters, headers=self.headers, model=model_id_or_url, api_key=self.token, ) data = await self._inner_post(request_parameters, stream=stream) if stream: return _async_stream_chat_completion_response(data) # type: ignore[arg-type] return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type] async def document_question_answering( self, image: ContentT, question: str, *, model: Optional[str] = None, doc_stride: Optional[int] = None, handle_impossible_answer: Optional[bool] = None, lang: Optional[str] = None, max_answer_len: Optional[int] = None, max_question_len: Optional[int] = None, max_seq_len: Optional[int] = None, top_k: Optional[int] = None, word_boxes: Optional[List[Union[List[float], str]]] = None, ) -> List[DocumentQuestionAnsweringOutputElement]: """ Answer questions on document images. Args: image (`Union[str, Path, bytes, BinaryIO]`): The input image for the context. It can be raw bytes, an image file, or a URL to an online image. question (`str`): Question to be answered. model (`str`, *optional*): The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used. Defaults to None. doc_stride (`int`, *optional*): If the words in the document are too long to fit with the question for the model, it will be split in several chunks with some overlap. This argument controls the size of that overlap. handle_impossible_answer (`bool`, *optional*): Whether to accept impossible as an answer lang (`str`, *optional*): Language to use while running OCR. Defaults to english. max_answer_len (`int`, *optional*): The maximum length of predicted answers (e.g., only answers with a shorter length are considered). 
max_question_len (`int`, *optional*): The maximum length of the question after tokenization. It will be truncated if needed. max_seq_len (`int`, *optional*): The maximum length of the total sentence (context + question) in tokens of each chunk passed to the model. The context will be split in several chunks (using doc_stride as overlap) if needed. top_k (`int`, *optional*): The number of answers to return (will be chosen by order of likelihood). Can return less than top_k answers if there are not enough options available within the context. word_boxes (`List[Union[List[float], str`, *optional*): A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR step and use the provided bounding boxes instead. Returns: `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?") [DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16)] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="document-question-answering", model=model_id) inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} request_parameters = provider_helper.prepare_request( inputs=inputs, parameters={ "doc_stride": doc_stride, "handle_impossible_answer": handle_impossible_answer, "lang": lang, "max_answer_len": max_answer_len, "max_question_len": max_question_len, "max_seq_len": max_seq_len, "top_k": top_k, "word_boxes": word_boxes, }, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response) async def feature_extraction( self, text: str, *, normalize: Optional[bool] = None, prompt_name: Optional[str] = None, truncate: Optional[bool] = None, truncation_direction: Optional[Literal["Left", "Right"]] = None, model: Optional[str] = None, ) -> "np.ndarray": """ Generate embeddings for a given text. Args: text (`str`): The text to embed. model (`str`, *optional*): The model to use for the feature extraction task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended feature extraction model will be used. Defaults to None. normalize (`bool`, *optional*): Whether to normalize the embeddings or not. Only available on server powered by Text-Embedding-Inference. prompt_name (`str`, *optional*): The name of the prompt that should be used by for encoding. If not set, no prompt will be applied. Must be a key in the `Sentence Transformers` configuration `prompts` dictionary. For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",...}, then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" because the prompt text will be prepended before any text to encode. 
truncate (`bool`, *optional*): Whether to truncate the embeddings or not. Only available on server powered by Text-Embedding-Inference. truncation_direction (`Literal["Left", "Right"]`, *optional*): Which side of the input should be truncated when `truncate=True` is passed. Returns: `np.ndarray`: The embedding representing the input text as a float32 numpy array. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.feature_extraction("Hi, who are you?") array([[ 2.424802 , 2.93384 , 1.1750331 , ..., 1.240499, -0.13776633, -0.7889173 ], [-0.42943227, -0.6364878 , -1.693462 , ..., 0.41978157, -2.4336355 , 0.6162071 ], ..., [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32) ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="feature-extraction", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={ "normalize": normalize, "prompt_name": prompt_name, "truncate": truncate, "truncation_direction": truncation_direction, }, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) np = _import_numpy() return np.array(provider_helper.get_response(response), dtype="float32") async def fill_mask( self, text: str, *, model: Optional[str] = None, targets: Optional[List[str]] = None, top_k: Optional[int] = None, ) -> List[FillMaskOutputElement]: """ Fill in a hole with a missing word (token to be precise). Args: text (`str`): a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask). model (`str`, *optional*): The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used. targets (`List[str`, *optional*): When passed, the model will limit the scores to the passed targets instead of looking up in the whole vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first resulting token will be used (with a warning, and that might be slower). top_k (`int`, *optional*): When passed, overrides the number of predictions to return. Returns: `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated probability, token reference, and completed text. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. 
Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.fill_mask("The goal of life is .") [ FillMaskOutputElement(score=0.06897063553333282, token=11098, token_str=' happiness', sequence='The goal of life is happiness.'), FillMaskOutputElement(score=0.06554922461509705, token=45075, token_str=' immortality', sequence='The goal of life is immortality.') ] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="fill-mask", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={"targets": targets, "top_k": top_k}, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return FillMaskOutputElement.parse_obj_as_list(response) async def image_classification( self, image: ContentT, *, model: Optional[str] = None, function_to_apply: Optional["ImageClassificationOutputTransform"] = None, top_k: Optional[int] = None, ) -> List[ImageClassificationOutputElement]: """ Perform image classification on the given image using the specified model. Args: image (`Union[str, Path, bytes, BinaryIO]`): The image to classify. It can be raw bytes, an image file, or a URL to an online image. model (`str`, *optional*): The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used. function_to_apply (`"ImageClassificationOutputTransform"`, *optional*): The function to apply to the model outputs in order to retrieve the scores. top_k (`int`, *optional*): When specified, limits the output to the top K most probable classes. Returns: `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="image-classification", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={"function_to_apply": function_to_apply, "top_k": top_k}, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return ImageClassificationOutputElement.parse_obj_as_list(response) async def image_segmentation( self, image: ContentT, *, model: Optional[str] = None, mask_threshold: Optional[float] = None, overlap_mask_area_threshold: Optional[float] = None, subtask: Optional["ImageSegmentationSubtask"] = None, threshold: Optional[float] = None, ) -> List[ImageSegmentationOutputElement]: """ Perform image segmentation on the given image using the specified model. You must have `PIL` installed if you want to work with images (`pip install Pillow`). Args: image (`Union[str, Path, bytes, BinaryIO]`): The image to segment. 
It can be raw bytes, an image file, or a URL to an online image. model (`str`, *optional*): The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used. mask_threshold (`float`, *optional*): Threshold to use when turning the predicted masks into binary values. overlap_mask_area_threshold (`float`, *optional*): Mask overlap threshold to eliminate small, disconnected segments. subtask (`"ImageSegmentationSubtask"`, *optional*): Segmentation task to be performed, depending on model capabilities. threshold (`float`, *optional*): Probability threshold to filter out predicted masks. Returns: `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.image_segmentation("cat.jpg") [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=), ...] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="image-segmentation", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={ "mask_threshold": mask_threshold, "overlap_mask_area_threshold": overlap_mask_area_threshold, "subtask": subtask, "threshold": threshold, }, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) output = ImageSegmentationOutputElement.parse_obj_as_list(response) for item in output: item.mask = _b64_to_image(item.mask) # type: ignore [assignment] return output async def image_to_image( self, image: ContentT, prompt: Optional[str] = None, *, negative_prompt: Optional[str] = None, num_inference_steps: Optional[int] = None, guidance_scale: Optional[float] = None, model: Optional[str] = None, target_size: Optional[ImageToImageTargetSize] = None, **kwargs, ) -> "Image": """ Perform image-to-image translation using a specified model. You must have `PIL` installed if you want to work with images (`pip install Pillow`). Args: image (`Union[str, Path, bytes, BinaryIO]`): The input image for translation. It can be raw bytes, an image file, or a URL to an online image. prompt (`str`, *optional*): The text prompt to guide the image generation. negative_prompt (`str`, *optional*): One prompt to guide what NOT to include in image generation. num_inference_steps (`int`, *optional*): For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*): For diffusion models. A higher guidance scale value encourages the model to generate images closely linked to the text prompt at the expense of lower image quality. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. target_size (`ImageToImageTargetSize`, *optional*): The size in pixel of the output image. Returns: `Image`: The translated image. 
Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> image = await client.image_to_image("cat.jpg", prompt="turn the cat into a tiger") >>> image.save("tiger.jpg") ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={ "prompt": prompt, "negative_prompt": negative_prompt, "target_size": target_size, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, **kwargs, }, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return _bytes_to_image(response) async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput: """ Takes an input image and return text. Models can have very different outputs depending on your use case (image captioning, optical character recognition (OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities. Args: image (`Union[str, Path, bytes, BinaryIO]`): The input image to caption. It can be raw bytes, an image file, or a URL to an online image.. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. Returns: [`ImageToTextOutput`]: The generated text. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.image_to_text("cat.jpg") 'a cat standing in a grassy field ' >>> await client.image_to_text("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") 'a dog laying on the grass next to a flower pot ' ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="image-to-text", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={}, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) output = ImageToTextOutput.parse_obj(response) return output[0] if isinstance(output, list) else output async def object_detection( self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None ) -> List[ObjectDetectionOutputElement]: """ Perform object detection on the given image using the specified model. You must have `PIL` installed if you want to work with images (`pip install Pillow`). Args: image (`Union[str, Path, bytes, BinaryIO]`): The image to detect objects on. It can be raw bytes, an image file, or a URL to an online image. model (`str`, *optional*): The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used. 
threshold (`float`, *optional*): The probability necessary to make a prediction. Returns: `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. `ValueError`: If the request output is not a List. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.object_detection("people.jpg") [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="object-detection", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={"threshold": threshold}, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return ObjectDetectionOutputElement.parse_obj_as_list(response) async def question_answering( self, question: str, context: str, *, model: Optional[str] = None, align_to_words: Optional[bool] = None, doc_stride: Optional[int] = None, handle_impossible_answer: Optional[bool] = None, max_answer_len: Optional[int] = None, max_question_len: Optional[int] = None, max_seq_len: Optional[int] = None, top_k: Optional[int] = None, ) -> Union[QuestionAnsweringOutputElement, List[QuestionAnsweringOutputElement]]: """ Retrieve the answer to a question from a given text. Args: question (`str`): Question to be answered. context (`str`): The context of the question. model (`str`): The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. align_to_words (`bool`, *optional*): Attempts to align the answer to real words. Improves quality on space separated languages. Might hurt on non-space-separated languages (like Japanese or Chinese) doc_stride (`int`, *optional*): If the context is too long to fit with the question for the model, it will be split in several chunks with some overlap. This argument controls the size of that overlap. handle_impossible_answer (`bool`, *optional*): Whether to accept impossible as an answer. max_answer_len (`int`, *optional*): The maximum length of predicted answers (e.g., only answers with a shorter length are considered). max_question_len (`int`, *optional*): The maximum length of the question after tokenization. It will be truncated if needed. max_seq_len (`int`, *optional*): The maximum length of the total sentence (context + question) in tokens of each chunk passed to the model. The context will be split in several chunks (using docStride as overlap) if needed. top_k (`int`, *optional*): The number of answers to return (will be chosen by order of likelihood). Note that we return less than topk answers if there are not enough options available within the context. Returns: Union[`QuestionAnsweringOutputElement`, List[`QuestionAnsweringOutputElement`]]: When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`. When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. 
`aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.") QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11) ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="question-answering", model=model_id) request_parameters = provider_helper.prepare_request( inputs=None, parameters={ "align_to_words": align_to_words, "doc_stride": doc_stride, "handle_impossible_answer": handle_impossible_answer, "max_answer_len": max_answer_len, "max_question_len": max_question_len, "max_seq_len": max_seq_len, "top_k": top_k, }, extra_payload={"question": question, "context": context}, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) # Parse the response as a single `QuestionAnsweringOutputElement` when top_k is 1 or not provided, or a list of `QuestionAnsweringOutputElement` to ensure backward compatibility. output = QuestionAnsweringOutputElement.parse_obj(response) return output async def sentence_similarity( self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None ) -> List[float]: """ Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings. Args: sentence (`str`): The main sentence to compare to others. other_sentences (`List[str]`): The list of sentences to compare to. model (`str`, *optional*): The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended sentence similarity model will be used. Defaults to None. Returns: `List[float]`: The embedding representing the input text. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.sentence_similarity( ... "Machine learning is so easy.", ... other_sentences=[ ... "Deep learning is so straightforward.", ... "This is so difficult, like rocket science.", ... "I can't believe how much I struggled with this.", ... ], ... ) [0.7785726189613342, 0.45876261591911316, 0.2906220555305481] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="sentence-similarity", model=model_id) request_parameters = provider_helper.prepare_request( inputs={"source_sentence": sentence, "sentences": other_sentences}, parameters={}, extra_payload={}, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return _bytes_to_list(response) async def summarization( self, text: str, *, model: Optional[str] = None, clean_up_tokenization_spaces: Optional[bool] = None, generate_parameters: Optional[Dict[str, Any]] = None, truncation: Optional["SummarizationTruncationStrategy"] = None, ) -> SummarizationOutput: """ Generate a summary of a given text using a specified model. Args: text (`str`): The input text to summarize. 
model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for summarization will be used. clean_up_tokenization_spaces (`bool`, *optional*): Whether to clean up the potential extra spaces in the text output. generate_parameters (`Dict[str, Any]`, *optional*): Additional parametrization of the text generation algorithm. truncation (`"SummarizationTruncationStrategy"`, *optional*): The truncation strategy to use. Returns: [`SummarizationOutput`]: The generated summary text. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.summarization("The Eiffel tower...") SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....") ``` """ parameters = { "clean_up_tokenization_spaces": clean_up_tokenization_spaces, "generate_parameters": generate_parameters, "truncation": truncation, } model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="summarization", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters=parameters, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return SummarizationOutput.parse_obj_as_list(response)[0] async def table_question_answering( self, table: Dict[str, Any], query: str, *, model: Optional[str] = None, padding: Optional["Padding"] = None, sequential: Optional[bool] = None, truncation: Optional[bool] = None, ) -> TableQuestionAnsweringOutputElement: """ Retrieve the answer to a question from information given in a table. Args: table (`str`): A table of data represented as a dict of lists where entries are headers and the lists are all the values, all lists must have the same size. query (`str`): The query in plain text that you want to ask the table. model (`str`): The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. padding (`"Padding"`, *optional*): Activates and controls padding. sequential (`bool`, *optional*): Whether to do inference sequentially or as a batch. Batching is faster, but models like SQA require the inference to be done sequentially to extract relations within sequences, given their conversational nature. truncation (`bool`, *optional*): Activates and controls truncation. Returns: [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> query = "How many stars does the transformers repository have?" 
>>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]} >>> await client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq") TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE') ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="table-question-answering", model=model_id) request_parameters = provider_helper.prepare_request( inputs=None, parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation}, extra_payload={"query": query, "table": table}, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response) async def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]: """ Classifying a target category (a group) based on a set of attributes. Args: table (`Dict[str, Any]`): Set of attributes to classify. model (`str`, *optional*): The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended tabular classification model will be used. Defaults to None. Returns: `List`: a list of labels, one per row in the initial table. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> table = { ... "fixed_acidity": ["7.4", "7.8", "10.3"], ... "volatile_acidity": ["0.7", "0.88", "0.32"], ... "citric_acid": ["0", "0", "0.45"], ... "residual_sugar": ["1.9", "2.6", "6.4"], ... "chlorides": ["0.076", "0.098", "0.073"], ... "free_sulfur_dioxide": ["11", "25", "5"], ... "total_sulfur_dioxide": ["34", "67", "13"], ... "density": ["0.9978", "0.9968", "0.9976"], ... "pH": ["3.51", "3.2", "3.23"], ... "sulphates": ["0.56", "0.68", "0.82"], ... "alcohol": ["9.4", "9.8", "12.6"], ... } >>> await client.tabular_classification(table=table, model="julien-c/wine-quality") ["5", "5", "5"] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="tabular-classification", model=model_id) request_parameters = provider_helper.prepare_request( inputs=None, extra_payload={"table": table}, parameters={}, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return _bytes_to_list(response) async def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]: """ Predicting a numerical target value given a set of attributes/features in a table. Args: table (`Dict[str, Any]`): Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical. model (`str`, *optional*): The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended tabular regression model will be used. Defaults to None. Returns: `List`: a list of predicted numerical target values. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. 
`aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> table = { ... "Height": ["11.52", "12.48", "12.3778"], ... "Length1": ["23.2", "24", "23.9"], ... "Length2": ["25.4", "26.3", "26.5"], ... "Length3": ["30", "31.2", "31.1"], ... "Species": ["Bream", "Bream", "Bream"], ... "Width": ["4.02", "4.3056", "4.6961"], ... } >>> await client.tabular_regression(table, model="scikit-learn/Fish-Weight") [110, 120, 130] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="tabular-regression", model=model_id) request_parameters = provider_helper.prepare_request( inputs=None, parameters={}, extra_payload={"table": table}, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return _bytes_to_list(response) async def text_classification( self, text: str, *, model: Optional[str] = None, top_k: Optional[int] = None, function_to_apply: Optional["TextClassificationOutputTransform"] = None, ) -> List[TextClassificationOutputElement]: """ Perform text classification (e.g. sentiment-analysis) on the given text. Args: text (`str`): A string to be classified. model (`str`, *optional*): The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. Defaults to None. top_k (`int`, *optional*): When specified, limits the output to the top K most probable classes. function_to_apply (`"TextClassificationOutputTransform"`, *optional*): The function to apply to the model outputs in order to retrieve the scores. Returns: `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. 
Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.text_classification("I like you") [ TextClassificationOutputElement(label='POSITIVE', score=0.9998695850372314), TextClassificationOutputElement(label='NEGATIVE', score=0.0001304351753788069), ] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="text-classification", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={ "function_to_apply": function_to_apply, "top_k": top_k, }, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value] @overload async def text_generation( # type: ignore self, prompt: str, *, details: Literal[False] = ..., stream: Literal[False] = ..., model: Optional[str] = None, # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) adapter_id: Optional[str] = None, best_of: Optional[int] = None, decoder_input_details: Optional[bool] = None, do_sample: Optional[bool] = False, # Manual default value frequency_penalty: Optional[float] = None, grammar: Optional[TextGenerationInputGrammarType] = None, max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, stop: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, ) -> str: ... @overload async def text_generation( # type: ignore self, prompt: str, *, details: Literal[True] = ..., stream: Literal[False] = ..., model: Optional[str] = None, # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) adapter_id: Optional[str] = None, best_of: Optional[int] = None, decoder_input_details: Optional[bool] = None, do_sample: Optional[bool] = False, # Manual default value frequency_penalty: Optional[float] = None, grammar: Optional[TextGenerationInputGrammarType] = None, max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, stop: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, ) -> TextGenerationOutput: ... 
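    # NOTE: the @overload stubs above and below only differ in their `details` / `stream`
    # literals; they exist so that type checkers can infer the correct return type for each
    # combination. The concrete implementation follows the last overload.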
@overload async def text_generation( # type: ignore self, prompt: str, *, details: Literal[False] = ..., stream: Literal[True] = ..., model: Optional[str] = None, # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) adapter_id: Optional[str] = None, best_of: Optional[int] = None, decoder_input_details: Optional[bool] = None, do_sample: Optional[bool] = False, # Manual default value frequency_penalty: Optional[float] = None, grammar: Optional[TextGenerationInputGrammarType] = None, max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, stop: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, ) -> AsyncIterable[str]: ... @overload async def text_generation( # type: ignore self, prompt: str, *, details: Literal[True] = ..., stream: Literal[True] = ..., model: Optional[str] = None, # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) adapter_id: Optional[str] = None, best_of: Optional[int] = None, decoder_input_details: Optional[bool] = None, do_sample: Optional[bool] = False, # Manual default value frequency_penalty: Optional[float] = None, grammar: Optional[TextGenerationInputGrammarType] = None, max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, stop: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, ) -> AsyncIterable[TextGenerationStreamOutput]: ... @overload async def text_generation( self, prompt: str, *, details: Literal[True] = ..., stream: bool = ..., model: Optional[str] = None, # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) adapter_id: Optional[str] = None, best_of: Optional[int] = None, decoder_input_details: Optional[bool] = None, do_sample: Optional[bool] = False, # Manual default value frequency_penalty: Optional[float] = None, grammar: Optional[TextGenerationInputGrammarType] = None, max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, stop: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, ) -> Union[TextGenerationOutput, AsyncIterable[TextGenerationStreamOutput]]: ... 
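    # Quick reference for the overloads above (what `text_generation` returns):
    #   details=False, stream=False  -> str
    #   details=True,  stream=False  -> TextGenerationOutput
    #   details=False, stream=True   -> AsyncIterable[str]
    #   details=True,  stream=True   -> AsyncIterable[TextGenerationStreamOutput]
    # Illustrative sketch only (assumes an async context and a TGI-served text-generation
    # model; not part of the library itself):
    #
    #   client = AsyncInferenceClient()
    #   text = await client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
    #   async for token in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
    #       print(token, end="")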
async def text_generation( self, prompt: str, *, details: bool = False, stream: bool = False, model: Optional[str] = None, # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) adapter_id: Optional[str] = None, best_of: Optional[int] = None, decoder_input_details: Optional[bool] = None, do_sample: Optional[bool] = False, # Manual default value frequency_penalty: Optional[float] = None, grammar: Optional[TextGenerationInputGrammarType] = None, max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, stop: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, ) -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]: """ Given a prompt, generate the following text. If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method. It accepts a list of messages instead of a single text prompt and handles the chat templating for you. Args: prompt (`str`): Input text. details (`bool`, *optional*): By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens, probabilities, seed, finish reason, etc.). Only available for models running with the `text-generation-inference` backend. stream (`bool`, *optional*): By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of tokens to be returned. Only available for models running with the `text-generation-inference` backend. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. adapter_id (`str`, *optional*): Lora adapter id. best_of (`int`, *optional*): Generate best_of sequences and return the one with the highest token logprobs. decoder_input_details (`bool`, *optional*): Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken into account. Defaults to `False`. do_sample (`bool`, *optional*): Activate logits sampling frequency_penalty (`float`, *optional*): Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. grammar ([`TextGenerationInputGrammarType`], *optional*): Grammar constraints. Can be either a JSONSchema or a regex. max_new_tokens (`int`, *optional*): Maximum number of generated tokens. Defaults to 100. repetition_penalty (`float`, *optional*): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. return_full_text (`bool`, *optional*): Whether to prepend the prompt to the generated text seed (`int`, *optional*): Random sampling seed stop (`List[str]`, *optional*): Stop generating tokens if a member of `stop` is generated. stop_sequences (`List[str]`, *optional*): Deprecated argument. Use `stop` instead. temperature (`float`, *optional*): The value used to modulate the logits distribution.
top_n_tokens (`int`, *optional*): Return information about the `top_n_tokens` most likely tokens at each generation step, instead of just the sampled token. top_k (`int`, *optional*): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`, *optional*): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. truncate (`int`, *optional*): Truncate input tokens to the given size. typical_p (`float`, *optional*): Typical Decoding mass. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`, *optional*): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Returns: `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`: Generated text returned from the server: - if `stream=False` and `details=False`, the generated text is returned as a `str` (default) - if `stream=True` and `details=False`, the generated text is returned token by token as an `Iterable[str]` - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`] - if `details=True` and `stream=True`, the generated text is returned token by token as an iterable of [`~huggingface_hub.TextGenerationStreamOutput`] Raises: `ValidationError`: If input values are not valid. No HTTP call is made to the server. [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() # Case 1: generate text >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12) '100% open source and built to be easy to use.' # Case 2: iterate over the generated tokens. Useful for large generation. >>> async for token in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True): ... print(token) 100 % open source and built to be easy to use . # Case 3: get more details about the generation process. >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True) TextGenerationOutput( generated_text='100% open source and built to be easy to use.', details=TextGenerationDetails( finish_reason='length', generated_tokens=12, seed=None, prefill=[ TextGenerationPrefillOutputToken(id=487, text='The', logprob=None), TextGenerationPrefillOutputToken(id=53789, text=' hugging', logprob=-13.171875), (...) TextGenerationPrefillOutputToken(id=204, text=' ', logprob=-7.0390625) ], tokens=[ TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), TokenElement(id=16, text='%', logprob=-0.0463562, special=False), (...) TokenElement(id=25, text='.', logprob=-0.5703125, special=False) ], best_of_sequences=None ) ) # Case 4: iterate over the generated tokens with more details. # Last object is more complete, containing the full generated text and the finish reason. >>> async for details in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True): ... print(details) ...
TextGenerationStreamOutput(token=TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None) TextGenerationStreamOutput(token=TokenElement( id=25, text='.', logprob=-0.5703125, special=False), generated_text='100% open source and built to be easy to use.', details=TextGenerationStreamOutputStreamDetails(finish_reason='length', generated_tokens=12, seed=None) ) # Case 5: generate constrained output using grammar >>> response = await client.text_generation( ... prompt="I saw a puppy a cat and a raccoon during my bike ride in the park", ... model="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", ... max_new_tokens=100, ... repetition_penalty=1.3, ... grammar={ ... "type": "json", ... "value": { ... "properties": { ... "location": {"type": "string"}, ... "activity": {"type": "string"}, ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, ... "animals": {"type": "array", "items": {"type": "string"}}, ... }, ... "required": ["location", "activity", "animals_seen", "animals"], ... }, ... }, ... ) >>> json.loads(response) { "activity": "bike riding", "animals": ["puppy", "cat", "raccoon"], "animals_seen": 3, "location": "park" } ``` """ if decoder_input_details and not details: warnings.warn( "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that" " the output from the server will be truncated." ) decoder_input_details = False if stop_sequences is not None: warnings.warn( "`stop_sequences` is a deprecated argument for `text_generation` task" " and will be removed in version '0.28.0'. 
Use `stop` instead.", FutureWarning, ) if stop is None: stop = stop_sequences # use deprecated arg if provided # Build payload parameters = { "adapter_id": adapter_id, "best_of": best_of, "decoder_input_details": decoder_input_details, "details": details, "do_sample": do_sample, "frequency_penalty": frequency_penalty, "grammar": grammar, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "return_full_text": return_full_text, "seed": seed, "stop": stop if stop is not None else [], "temperature": temperature, "top_k": top_k, "top_n_tokens": top_n_tokens, "top_p": top_p, "truncate": truncate, "typical_p": typical_p, "watermark": watermark, } # Remove some parameters if not a TGI server unsupported_kwargs = _get_unsupported_text_generation_kwargs(model) if len(unsupported_kwargs) > 0: # The server does not support some parameters # => means it is not a TGI server # => remove unsupported parameters and warn the user ignored_parameters = [] for key in unsupported_kwargs: if parameters.get(key): ignored_parameters.append(key) parameters.pop(key, None) if len(ignored_parameters) > 0: warnings.warn( "API endpoint/model for text-generation is not served via TGI. Ignoring following parameters:" f" {', '.join(ignored_parameters)}.", UserWarning, ) if details: warnings.warn( "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will" " be ignored meaning only the generated text will be returned.", UserWarning, ) details = False if stream: raise ValueError( "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream." " Please pass `stream=False` as input." ) model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="text-generation", model=model_id) request_parameters = provider_helper.prepare_request( inputs=prompt, parameters=parameters, extra_payload={"stream": stream}, headers=self.headers, model=model_id, api_key=self.token, ) # Handle errors separately for more precise error messages try: bytes_output = await self._inner_post(request_parameters, stream=stream) except _import_aiohttp().ClientResponseError as e: match = MODEL_KWARGS_NOT_USED_REGEX.search(e.response_error_payload["error"]) if e.status == 400 and match: unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")] _set_unsupported_text_generation_kwargs(model, unused_params) return await self.text_generation( # type: ignore prompt=prompt, details=details, stream=stream, model=model_id, adapter_id=adapter_id, best_of=best_of, decoder_input_details=decoder_input_details, do_sample=do_sample, frequency_penalty=frequency_penalty, grammar=grammar, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, return_full_text=return_full_text, seed=seed, stop=stop, temperature=temperature, top_k=top_k, top_n_tokens=top_n_tokens, top_p=top_p, truncate=truncate, typical_p=typical_p, watermark=watermark, ) raise_text_generation_error(e) # Parse output if stream: return _async_stream_text_generation_response(bytes_output, details) # type: ignore data = _bytes_to_dict(bytes_output) # type: ignore[arg-type] # Data can be a single element (dict) or an iterable of dicts where we select the first element of. 
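        # e.g. (illustrative): a TGI server typically returns {"generated_text": ..., "details": ...},
        # while a plain transformers-backed endpoint typically returns [{"generated_text": ...}].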
if isinstance(data, list): data = data[0] response = provider_helper.get_response(data, request_parameters) return TextGenerationOutput.parse_obj_as_instance(response) if details else response["generated_text"] async def text_to_image( self, prompt: str, *, negative_prompt: Optional[str] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: Optional[int] = None, guidance_scale: Optional[float] = None, model: Optional[str] = None, scheduler: Optional[str] = None, seed: Optional[int] = None, extra_body: Optional[Dict[str, Any]] = None, ) -> "Image": """ Generate an image based on a given text using a specified model. You must have `PIL` installed if you want to work with images (`pip install Pillow`). You can pass provider-specific parameters to the model by using the `extra_body` argument. Args: prompt (`str`): The prompt to generate an image from. negative_prompt (`str`, *optional*): One prompt to guide what NOT to include in image generation. height (`int`, *optional*): The height in pixels of the output image width (`int`, *optional*): The width in pixels of the output image num_inference_steps (`int`, *optional*): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*): A higher guidance scale value encourages the model to generate images closely linked to the text prompt, but values too high may cause saturation and other artifacts. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text-to-image model will be used. Defaults to None. scheduler (`str`, *optional*): Override the scheduler with a compatible one. seed (`int`, *optional*): Seed for the random number generator. extra_body (`Dict[str, Any]`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: `Image`: The generated image. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> image = await client.text_to_image("An astronaut riding a horse on the moon.") >>> image.save("astronaut.png") >>> image = await client.text_to_image( ... "An astronaut riding a horse on the moon.", ... negative_prompt="low resolution, blurry", ... model="stabilityai/stable-diffusion-2-1", ... ) >>> image.save("better_astronaut.png") ``` Example using a third-party provider directly. Usage will be billed on your fal.ai account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="fal-ai", # Use fal.ai provider ... api_key="fal-ai-api-key", # Pass your fal.ai API key ... ) >>> image = client.text_to_image( ... "A majestic lion in a fantasy forest", ... model="black-forest-labs/FLUX.1-schnell", ... ) >>> image.save("lion.png") ``` Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="replicate", # Use replicate provider ... api_key="hf_...", # Pass your HF token ... ) >>> image = client.text_to_image( ... 
"An astronaut riding a horse on the moon.", ... model="black-forest-labs/FLUX.1-dev", ... ) >>> image.save("astronaut.png") ``` Example using Replicate provider with extra parameters ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="replicate", # Use replicate provider ... api_key="hf_...", # Pass your HF token ... ) >>> image = client.text_to_image( ... "An astronaut riding a horse on the moon.", ... model="black-forest-labs/FLUX.1-schnell", ... extra_body={"output_quality": 100}, ... ) >>> image.save("astronaut.png") ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id) request_parameters = provider_helper.prepare_request( inputs=prompt, parameters={ "negative_prompt": negative_prompt, "height": height, "width": width, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, "scheduler": scheduler, "seed": seed, **(extra_body or {}), }, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) response = provider_helper.get_response(response) return _bytes_to_image(response) async def text_to_video( self, prompt: str, *, model: Optional[str] = None, guidance_scale: Optional[float] = None, negative_prompt: Optional[List[str]] = None, num_frames: Optional[float] = None, num_inference_steps: Optional[int] = None, seed: Optional[int] = None, extra_body: Optional[Dict[str, Any]] = None, ) -> bytes: """ Generate a video based on a given text. You can pass provider-specific parameters to the model by using the `extra_body` argument. Args: prompt (`str`): The prompt to generate a video from. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text-to-video model will be used. Defaults to None. guidance_scale (`float`, *optional*): A higher guidance scale value encourages the model to generate videos closely linked to the text prompt, but values too high may cause saturation and other artifacts. negative_prompt (`List[str]`, *optional*): One or several prompt to guide what NOT to include in video generation. num_frames (`float`, *optional*): The num_frames parameter determines how many video frames are generated. num_inference_steps (`int`, *optional*): The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. seed (`int`, *optional*): Seed for the random number generator. extra_body (`Dict[str, Any]`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: `bytes`: The generated video. Example: Example using a third-party provider directly. Usage will be billed on your fal.ai account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="fal-ai", # Using fal.ai provider ... api_key="fal-ai-api-key", # Pass your fal.ai API key ... ) >>> video = client.text_to_video( ... "A majestic lion running in a fantasy forest", ... model="tencent/HunyuanVideo", ... ) >>> with open("lion.mp4", "wb") as file: ... file.write(video) ``` Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... 
provider="replicate", # Using replicate provider ... api_key="hf_...", # Pass your HF token ... ) >>> video = client.text_to_video( ... "A cat running in a park", ... model="genmo/mochi-1-preview", ... ) >>> with open("cat.mp4", "wb") as file: ... file.write(video) ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id) request_parameters = provider_helper.prepare_request( inputs=prompt, parameters={ "guidance_scale": guidance_scale, "negative_prompt": negative_prompt, "num_frames": num_frames, "num_inference_steps": num_inference_steps, "seed": seed, **(extra_body or {}), }, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) response = provider_helper.get_response(response, request_parameters) return response async def text_to_speech( self, text: str, *, model: Optional[str] = None, do_sample: Optional[bool] = None, early_stopping: Optional[Union[bool, "TextToSpeechEarlyStoppingEnum"]] = None, epsilon_cutoff: Optional[float] = None, eta_cutoff: Optional[float] = None, max_length: Optional[int] = None, max_new_tokens: Optional[int] = None, min_length: Optional[int] = None, min_new_tokens: Optional[int] = None, num_beam_groups: Optional[int] = None, num_beams: Optional[int] = None, penalty_alpha: Optional[float] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, typical_p: Optional[float] = None, use_cache: Optional[bool] = None, extra_body: Optional[Dict[str, Any]] = None, ) -> bytes: """ Synthesize an audio of a voice pronouncing a given text. You can pass provider-specific parameters to the model by using the `extra_body` argument. Args: text (`str`): The text to synthesize. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text-to-speech model will be used. Defaults to None. do_sample (`bool`, *optional*): Whether to use sampling instead of greedy decoding when generating new tokens. early_stopping (`Union[bool, "TextToSpeechEarlyStoppingEnum"]`, *optional*): Controls the stopping condition for beam-based methods. epsilon_cutoff (`float`, *optional*): If set to float strictly between 0 and 1, only tokens with a conditional probability greater than epsilon_cutoff will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details. eta_cutoff (`float`, *optional*): Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between 0 and 1, a token is only considered if it is greater than either eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details. max_length (`int`, *optional*): The maximum length (in tokens) of the generated text, including the input. max_new_tokens (`int`, *optional*): The maximum number of tokens to generate. Takes precedence over max_length. min_length (`int`, *optional*): The minimum length (in tokens) of the generated text, including the input. 
min_new_tokens (`int`, *optional*): The minimum number of tokens to generate. Takes precedence over min_length. num_beam_groups (`int`, *optional*): Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. num_beams (`int`, *optional*): Number of beams to use for beam search. penalty_alpha (`float`, *optional*): The value balances the model confidence and the degeneration penalty in contrastive search decoding. temperature (`float`, *optional*): The value used to modulate the next token probabilities. top_k (`int`, *optional*): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`, *optional*): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. typical_p (`float`, *optional*): Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to typical_p or higher are kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. use_cache (`bool`, *optional*): Whether the model should use the past last key/values attentions to speed up decoding extra_body (`Dict[str, Any]`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: `bytes`: The generated audio. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from pathlib import Path >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> audio = await client.text_to_speech("Hello world") >>> Path("hello_world.flac").write_bytes(audio) ``` Example using a third-party provider directly. Usage will be billed on your Replicate account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="replicate", ... api_key="your-replicate-api-key", # Pass your Replicate API key directly ... ) >>> audio = client.text_to_speech( ... text="Hello world", ... model="OuteAI/OuteTTS-0.3-500M", ... ) >>> Path("hello_world.flac").write_bytes(audio) ``` Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="replicate", ... api_key="hf_...", # Pass your HF token ... ) >>> audio =client.text_to_speech( ... text="Hello world", ... model="OuteAI/OuteTTS-0.3-500M", ... ) >>> Path("hello_world.flac").write_bytes(audio) ``` Example using Replicate provider with extra parameters ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient( ... provider="replicate", # Use replicate provider ... api_key="hf_...", # Pass your HF token ... ) >>> audio = client.text_to_speech( ... "Hello, my name is Kororo, an awesome text-to-speech model.", ... model="hexgrad/Kokoro-82M", ... extra_body={"voice": "af_nicole"}, ... 
) >>> Path("hello.flac").write_bytes(audio) ``` Example music-gen using "YuE-s1-7B-anneal-en-cot" on fal.ai ```py >>> from huggingface_hub import InferenceClient >>> lyrics = ''' ... [verse] ... In the town where I was born ... Lived a man who sailed to sea ... And he told us of his life ... In the land of submarines ... So we sailed on to the sun ... 'Til we found a sea of green ... And we lived beneath the waves ... In our yellow submarine ... [chorus] ... We all live in a yellow submarine ... Yellow submarine, yellow submarine ... We all live in a yellow submarine ... Yellow submarine, yellow submarine ... ''' >>> genres = "pavarotti-style tenor voice" >>> client = InferenceClient( ... provider="fal-ai", ... model="m-a-p/YuE-s1-7B-anneal-en-cot", ... api_key=..., ... ) >>> audio = client.text_to_speech(lyrics, extra_body={"genres": genres}) >>> with open("output.mp3", "wb") as f: ... f.write(audio) ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="text-to-speech", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={ "do_sample": do_sample, "early_stopping": early_stopping, "epsilon_cutoff": epsilon_cutoff, "eta_cutoff": eta_cutoff, "max_length": max_length, "max_new_tokens": max_new_tokens, "min_length": min_length, "min_new_tokens": min_new_tokens, "num_beam_groups": num_beam_groups, "num_beams": num_beams, "penalty_alpha": penalty_alpha, "temperature": temperature, "top_k": top_k, "top_p": top_p, "typical_p": typical_p, "use_cache": use_cache, **(extra_body or {}), }, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) response = provider_helper.get_response(response) return response async def token_classification( self, text: str, *, model: Optional[str] = None, aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None, ignore_labels: Optional[List[str]] = None, stride: Optional[int] = None, ) -> List[TokenClassificationOutputElement]: """ Perform token classification on the given text. Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. Args: text (`str`): A string to be classified. model (`str`, *optional*): The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used. Defaults to None. aggregation_strategy (`"TokenClassificationAggregationStrategy"`, *optional*): The strategy used to fuse tokens based on model predictions ignore_labels (`List[str]`, *optional*): A list of labels to ignore stride (`int`, *optional*): The number of overlapping tokens between chunks when splitting the input text. Returns: `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503.
Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica") [ TokenClassificationOutputElement( entity_group='PER', score=0.9971321225166321, word='Sarah Jessica Parker', start=11, end=31, ), TokenClassificationOutputElement( entity_group='PER', score=0.9773476123809814, word='Jessica', start=52, end=59, ) ] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="token-classification", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={ "aggregation_strategy": aggregation_strategy, "ignore_labels": ignore_labels, "stride": stride, }, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return TokenClassificationOutputElement.parse_obj_as_list(response) async def translation( self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None, clean_up_tokenization_spaces: Optional[bool] = None, truncation: Optional["TranslationTruncationStrategy"] = None, generate_parameters: Optional[Dict[str, Any]] = None, ) -> TranslationOutput: """ Convert text from one language to another. Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for your specific use case. Source and target languages usually depend on the model. However, it is possible to specify source and target languages for certain models. If you are working with one of these models, you can use `src_lang` and `tgt_lang` arguments to pass the relevant information. Args: text (`str`): A string to be translated. model (`str`, *optional*): The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended translation model will be used. Defaults to None. src_lang (`str`, *optional*): The source language of the text. Required for models that can translate from multiple languages. tgt_lang (`str`, *optional*): Target language to translate to. Required for models that can translate to multiple languages. clean_up_tokenization_spaces (`bool`, *optional*): Whether to clean up the potential extra spaces in the text output. truncation (`"TranslationTruncationStrategy"`, *optional*): The truncation strategy to use. generate_parameters (`Dict[str, Any]`, *optional*): Additional parametrization of the text generation algorithm. Returns: [`TranslationOutput`]: The generated translated text. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. `ValueError`: If only one of the `src_lang` and `tgt_lang` arguments are provided. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.translation("My name is Wolfgang and I live in Berlin") 'Mein Name ist Wolfgang und ich lebe in Berlin.' 
>>> await client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr") TranslationOutput(translation_text='Je m'appelle Wolfgang et je vis à Berlin.') ``` Specifying languages: ```py >>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX") "Mon nom est Sarah Jessica Parker mais vous pouvez m'appeler Jessica" ``` """ # Throw error if only one of `src_lang` and `tgt_lang` was given if src_lang is not None and tgt_lang is None: raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.") if src_lang is None and tgt_lang is not None: raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.") model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="translation", model=model_id) request_parameters = provider_helper.prepare_request( inputs=text, parameters={ "src_lang": src_lang, "tgt_lang": tgt_lang, "clean_up_tokenization_spaces": clean_up_tokenization_spaces, "truncation": truncation, "generate_parameters": generate_parameters, }, headers=self.headers, model=model_id, api_key=self.token, ) response = await self._inner_post(request_parameters) return TranslationOutput.parse_obj_as_list(response)[0] async def visual_question_answering( self, image: ContentT, question: str, *, model: Optional[str] = None, top_k: Optional[int] = None, ) -> List[VisualQuestionAnsweringOutputElement]: """ Answering open-ended questions based on an image. Args: image (`Union[str, Path, bytes, BinaryIO]`): The input image for the context. It can be raw bytes, an image file, or a URL to an online image. question (`str`): Question to be answered. model (`str`, *optional*): The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used. Defaults to None. top_k (`int`, *optional*): The number of answers to return (will be chosen by order of likelihood). Note that we return less than topk answers if there are not enough options available within the context. Returns: `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. Raises: `InferenceTimeoutError`: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.visual_question_answering( ... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg", ... question="What is the animal doing?" ... 
) [ VisualQuestionAnsweringOutputElement(score=0.778609573841095, answer='laying down'), VisualQuestionAnsweringOutputElement(score=0.6957435607910156, answer='sitting'), ] ``` """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="visual-question-answering", model=model_id) request_parameters = provider_helper.prepare_request( inputs=image, parameters={"top_k": top_k}, headers=self.headers, model=model_id, api_key=self.token, extra_payload={"question": question, "image": _b64_encode(image)}, ) response = await self._inner_post(request_parameters) return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response) async def zero_shot_classification( self, text: str, candidate_labels: List[str], *, multi_label: Optional[bool] = False, hypothesis_template: Optional[str] = None, model: Optional[str] = None, ) -> List[ZeroShotClassificationOutputElement]: """ Provide as input a text and a set of candidate labels to classify the input text. Args: text (`str`): The input text to classify. candidate_labels (`List[str]`): The set of possible class labels to classify the text into. labels (`List[str]`, *optional*): (deprecated) List of strings. Each string is the verbalization of a possible label for the input text. multi_label (`bool`, *optional*): Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of the label likelihoods for each sequence is 1. If true, the labels are considered independent and probabilities are normalized for each candidate. hypothesis_template (`str`, *optional*): The sentence used in conjunction with `candidate_labels` to attempt the text classification by replacing the placeholder with the candidate labels. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot classification model will be used. Returns: `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. Example with `multi_label=False`: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> text = ( ... "A new model offers an explanation for how the Galilean satellites formed around the solar system's" ... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling" ... " mysteries when he went for a run up a hill in Nice, France." ... 
)
        >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
        >>> await client.zero_shot_classification(text, labels)
        [
            ZeroShotClassificationOutputElement(label='scientific discovery', score=0.7961668968200684),
            ZeroShotClassificationOutputElement(label='space & cosmos', score=0.18570658564567566),
            ZeroShotClassificationOutputElement(label='microbiology', score=0.00730885099619627),
            ZeroShotClassificationOutputElement(label='archeology', score=0.006258360575884581),
            ZeroShotClassificationOutputElement(label='robots', score=0.004559356719255447),
        ]
        >>> await client.zero_shot_classification(text, labels, multi_label=True)
        [
            ZeroShotClassificationOutputElement(label='scientific discovery', score=0.9829297661781311),
            ZeroShotClassificationOutputElement(label='space & cosmos', score=0.755190908908844),
            ZeroShotClassificationOutputElement(label='microbiology', score=0.0005462635890580714),
            ZeroShotClassificationOutputElement(label='archeology', score=0.00047131875180639327),
            ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354),
        ]
        ```

        Example with `multi_label=True` and a custom `hypothesis_template`:
        ```py
        # Must be run in an async context
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.zero_shot_classification(
        ...     text="I really like our dinner and I'm very happy. I don't like the weather though.",
        ...     candidate_labels=["positive", "negative", "pessimistic", "optimistic"],
        ...     multi_label=True,
        ...     hypothesis_template="This text is {} towards the weather"
        ... )
        [
            ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467),
            ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134),
            ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062),
            ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363)
        ]
        ```
        """
        model_id = model or self.model
        provider_helper = get_provider_helper(self.provider, task="zero-shot-classification", model=model_id)
        request_parameters = provider_helper.prepare_request(
            inputs=text,
            parameters={
                "candidate_labels": candidate_labels,
                "multi_label": multi_label,
                "hypothesis_template": hypothesis_template,
            },
            headers=self.headers,
            model=model_id,
            api_key=self.token,
        )
        response = await self._inner_post(request_parameters)
        output = _bytes_to_dict(response)
        return [
            ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score})
            for label, score in zip(output["labels"], output["scores"])
        ]

    async def zero_shot_image_classification(
        self,
        image: ContentT,
        candidate_labels: List[str],
        *,
        model: Optional[str] = None,
        hypothesis_template: Optional[str] = None,
        # deprecated argument
        labels: List[str] = None,  # type: ignore
    ) -> List[ZeroShotImageClassificationOutputElement]:
        """
        Provide input image and text labels to predict text labels for the image.

        Args:
            image (`Union[str, Path, bytes, BinaryIO]`):
                The input image to classify. It can be raw bytes, an image file, or a URL to an online image.
            candidate_labels (`List[str]`):
                The candidate labels for this image.
            labels (`List[str]`, *optional*):
                (deprecated) List of possible labels as strings. There must be at least 2 labels.
            model (`str`, *optional*):
                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                deployed Inference Endpoint. This parameter overrides the model defined at the instance level.
                If not provided, the default recommended zero-shot image classification model will be used.
            hypothesis_template (`str`, *optional*):
                The sentence used in conjunction with `candidate_labels` to attempt the image classification by
                replacing the placeholder with the candidate labels.

        Returns:
            `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`]
            items containing the predicted labels and their confidence.

        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
            `aiohttp.ClientResponseError`:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
        ```py
        # Must be run in an async context
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.zero_shot_image_classification(
        ...     "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
        ...     candidate_labels=["dog", "cat", "horse"],
        ... )
        [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
        ```
        """
        # Raise ValueError if fewer than 2 labels are provided
        if len(candidate_labels) < 2:
            raise ValueError("You must specify at least 2 classes to compare.")

        model_id = model or self.model
        provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification", model=model_id)
        request_parameters = provider_helper.prepare_request(
            inputs=image,
            parameters={
                "candidate_labels": candidate_labels,
                "hypothesis_template": hypothesis_template,
            },
            headers=self.headers,
            model=model_id,
            api_key=self.token,
        )
        response = await self._inner_post(request_parameters)
        return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

    @_deprecate_method(
        version="0.33.0",
        message=(
            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
            " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
        ),
    )
    async def list_deployed_models(
        self, frameworks: Union[None, str, Literal["all"], List[str]] = None
    ) -> Dict[str, List[str]]:
        """
        List models deployed on the HF Serverless Inference API service.

        This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks
        that are supported and account for 95% of the hosted models. However, if you want a complete list of models
        you can specify `frameworks="all"` as input. Alternatively, if you know beforehand which framework you are
        interested in, you can also restrict the search to this one (e.g. `frameworks="text-generation-inference"`).
        The more frameworks that are checked, the longer it will take.

        This endpoint method does not return a live list of all models available for the HF Inference API service.
        It searches over a cached list of models that were recently available and the list may not be up to date.
        If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].

        This endpoint method is mostly useful for discoverability. If you already know which model you want to use
        and want to check its availability, you can directly use [`~InferenceClient.get_model_status`].

        Args:
            frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
                The frameworks to filter on. By default only a subset of the available frameworks are tested. If set
                to "all", all available frameworks will be tested. It is also possible to provide a single framework
                or a custom set of frameworks to check.

        Returns:
            `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.

        Example:
        ```py
        # Must be run in an async context
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()

        # Discover zero-shot-classification models currently deployed
        >>> models = await client.list_deployed_models()
        >>> models["zero-shot-classification"]
        ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]

        # List from only 1 framework
        >>> await client.list_deployed_models("text-generation-inference")
        {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
        ```
        """
        if self.provider != "hf-inference":
            raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")

        # Resolve which frameworks to check
        if frameworks is None:
            frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
        elif frameworks == "all":
            frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
        elif isinstance(frameworks, str):
            frameworks = [frameworks]
        frameworks = list(set(frameworks))

        # Fetch them iteratively
        models_by_task: Dict[str, List[str]] = {}

        def _unpack_response(framework: str, items: List[Dict]) -> None:
            for model in items:
                if framework == "sentence-transformers":
                    # Models running with the `sentence-transformers` framework can work with both tasks even if not
                    # branded as such in the API response
                    models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
                    models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
                else:
                    models_by_task.setdefault(model["task"], []).append(model["model_id"])

        for framework in frameworks:
            response = get_session().get(
                f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
            )
            hf_raise_for_status(response)
            _unpack_response(framework, response.json())

        # Sort alphabetically for discoverability and return
        for task, models in models_by_task.items():
            models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
        return models_by_task

    def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
        aiohttp = _import_aiohttp()
        client_headers = self.headers.copy()
        if headers is not None:
            client_headers.update(headers)

        # Return a new aiohttp ClientSession with correct settings.
        session = aiohttp.ClientSession(
            headers=client_headers,
            cookies=self.cookies,
            timeout=aiohttp.ClientTimeout(self.timeout),
            trust_env=self.trust_env,
        )

        # Keep track of sessions to close them later
        self._sessions[session] = set()

        # Override the `._request` method to register responses to be closed
        session._wrapped_request = session._request

        async def _request(method, url, **kwargs):
            response = await session._wrapped_request(method, url, **kwargs)
            self._sessions[session].add(response)
            return response

        session._request = _request

        # Override the 'close' method to
        # 1. close ongoing responses
        # 2. deregister the session when closed
        session._close = session.close

        async def close_session():
            for response in self._sessions[session]:
                response.close()
            await session._close()
            self._sessions.pop(session, None)

        session.close = close_session
        return session

    async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
        """
        Get information about the deployed endpoint.

        This endpoint is only available on endpoints powered by Text-Generation-Inference (TGI) or
        Text-Embedding-Inference (TEI). Endpoints powered by `transformers` return an empty payload.

        Args:
            model (`str`, *optional*):
                The model to use for inference.
Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. Returns: `Dict[str, Any]`: Information about the endpoint. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") >>> await client.get_endpoint_info() { 'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct', 'model_sha': None, 'model_dtype': 'torch.float16', 'model_device_type': 'cuda', 'model_pipeline_tag': None, 'max_concurrent_requests': 128, 'max_best_of': 2, 'max_stop_sequences': 4, 'max_input_length': 8191, 'max_total_tokens': 8192, 'waiting_served_ratio': 0.3, 'max_batch_total_tokens': 1259392, 'max_waiting_tokens': 20, 'max_batch_size': None, 'validation_workers': 32, 'max_client_batch_size': 4, 'version': '2.0.2', 'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214', 'docker_label': 'sha-dccab72' } ``` """ if self.provider != "hf-inference": raise ValueError(f"Getting endpoint info is not supported on '{self.provider}'.") model = model or self.model if model is None: raise ValueError("Model id not provided.") if model.startswith(("http://", "https://")): url = model.rstrip("/") + "/info" else: url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info" async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: response = await client.get(url, proxy=self.proxies) response.raise_for_status() return await response.json() async def health_check(self, model: Optional[str] = None) -> bool: """ Check the health of the deployed endpoint. Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI). For Inference API, please use [`InferenceClient.get_model_status`] instead. Args: model (`str`, *optional*): URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. Returns: `bool`: True if everything is working fine. Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud") >>> await client.health_check() True ``` """ if self.provider != "hf-inference": raise ValueError(f"Health check is not supported on '{self.provider}'.") model = model or self.model if model is None: raise ValueError("Model id not provided.") if not model.startswith(("http://", "https://")): raise ValueError( "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`." ) url = model.rstrip("/") + "/health" async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: response = await client.get(url, proxy=self.proxies) return response.status == 200 @_deprecate_method( version="0.33.0", message=( "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)." " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers." ), ) async def get_model_status(self, model: Optional[str] = None) -> ModelStatus: """ Get the status of a model hosted on the HF Inference API. This endpoint is mostly useful when you already know which model you want to use and want to check its availability. 
If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].

        Args:
            model (`str`, *optional*):
                Identifier of the model for which the status will be checked. If model is not provided, the model
                associated with this instance of [`InferenceClient`] will be used. Only the HF Inference API service
                can be checked, so the identifier cannot be a URL.

        Returns:
            [`ModelStatus`]: An instance of ModelStatus dataclass, containing information about the state of the
            model: load, state, compute type and framework.

        Example:
        ```py
        # Must be run in an async context
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient()
        >>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
        ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
        ```
        """
        if self.provider != "hf-inference":
            raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
        model = model or self.model
        if model is None:
            raise ValueError("Model id not provided.")
        if model.startswith("https://"):
            raise NotImplementedError("Model status is only available for Inference API endpoints.")
        url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"

        async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
            response = await client.get(url, proxy=self.proxies)
            response.raise_for_status()
            response_data = await response.json()

        if "error" in response_data:
            raise ValueError(response_data["error"])

        return ModelStatus(
            loaded=response_data["loaded"],
            state=response_data["state"],
            compute_type=response_data["compute_type"],
            framework=response_data["framework"],
        )

    @property
    def chat(self) -> "ProxyClientChat":
        return ProxyClientChat(self)


class _ProxyClient:
    """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""

    def __init__(self, client: AsyncInferenceClient):
        self._client = client


class ProxyClientChat(_ProxyClient):
    """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""

    @property
    def completions(self) -> "ProxyClientChatCompletions":
        return ProxyClientChatCompletions(self._client)


class ProxyClientChatCompletions(_ProxyClient):
    """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""

    @property
    def create(self):
        return self._client.chat_completion
huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/000077500000000000000000000000001500667546600260135ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/__init__.py000066400000000000000000000142431500667546600301300ustar00rootroot00000000000000# This file is auto-generated by `utils/generate_inference_types.py`.
# Do not modify it manually.
# # ruff: noqa: F401 from .audio_classification import ( AudioClassificationInput, AudioClassificationOutputElement, AudioClassificationOutputTransform, AudioClassificationParameters, ) from .audio_to_audio import AudioToAudioInput, AudioToAudioOutputElement from .automatic_speech_recognition import ( AutomaticSpeechRecognitionEarlyStoppingEnum, AutomaticSpeechRecognitionGenerationParameters, AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput, AutomaticSpeechRecognitionOutputChunk, AutomaticSpeechRecognitionParameters, ) from .base import BaseInferenceType from .chat_completion import ( ChatCompletionInput, ChatCompletionInputFunctionDefinition, ChatCompletionInputFunctionName, ChatCompletionInputGrammarType, ChatCompletionInputGrammarTypeType, ChatCompletionInputMessage, ChatCompletionInputMessageChunk, ChatCompletionInputMessageChunkType, ChatCompletionInputStreamOptions, ChatCompletionInputTool, ChatCompletionInputToolCall, ChatCompletionInputToolChoiceClass, ChatCompletionInputToolChoiceEnum, ChatCompletionInputURL, ChatCompletionOutput, ChatCompletionOutputComplete, ChatCompletionOutputFunctionDefinition, ChatCompletionOutputLogprob, ChatCompletionOutputLogprobs, ChatCompletionOutputMessage, ChatCompletionOutputToolCall, ChatCompletionOutputTopLogprob, ChatCompletionOutputUsage, ChatCompletionStreamOutput, ChatCompletionStreamOutputChoice, ChatCompletionStreamOutputDelta, ChatCompletionStreamOutputDeltaToolCall, ChatCompletionStreamOutputFunction, ChatCompletionStreamOutputLogprob, ChatCompletionStreamOutputLogprobs, ChatCompletionStreamOutputTopLogprob, ChatCompletionStreamOutputUsage, ) from .depth_estimation import DepthEstimationInput, DepthEstimationOutput from .document_question_answering import ( DocumentQuestionAnsweringInput, DocumentQuestionAnsweringInputData, DocumentQuestionAnsweringOutputElement, DocumentQuestionAnsweringParameters, ) from .feature_extraction import FeatureExtractionInput, FeatureExtractionInputTruncationDirection from .fill_mask import FillMaskInput, FillMaskOutputElement, FillMaskParameters from .image_classification import ( ImageClassificationInput, ImageClassificationOutputElement, ImageClassificationOutputTransform, ImageClassificationParameters, ) from .image_segmentation import ( ImageSegmentationInput, ImageSegmentationOutputElement, ImageSegmentationParameters, ImageSegmentationSubtask, ) from .image_to_image import ImageToImageInput, ImageToImageOutput, ImageToImageParameters, ImageToImageTargetSize from .image_to_text import ( ImageToTextEarlyStoppingEnum, ImageToTextGenerationParameters, ImageToTextInput, ImageToTextOutput, ImageToTextParameters, ) from .object_detection import ( ObjectDetectionBoundingBox, ObjectDetectionInput, ObjectDetectionOutputElement, ObjectDetectionParameters, ) from .question_answering import ( QuestionAnsweringInput, QuestionAnsweringInputData, QuestionAnsweringOutputElement, QuestionAnsweringParameters, ) from .sentence_similarity import SentenceSimilarityInput, SentenceSimilarityInputData from .summarization import ( SummarizationInput, SummarizationOutput, SummarizationParameters, SummarizationTruncationStrategy, ) from .table_question_answering import ( Padding, TableQuestionAnsweringInput, TableQuestionAnsweringInputData, TableQuestionAnsweringOutputElement, TableQuestionAnsweringParameters, ) from .text2text_generation import ( Text2TextGenerationInput, Text2TextGenerationOutput, Text2TextGenerationParameters, Text2TextGenerationTruncationStrategy, ) from .text_classification import ( 
TextClassificationInput, TextClassificationOutputElement, TextClassificationOutputTransform, TextClassificationParameters, ) from .text_generation import ( TextGenerationInput, TextGenerationInputGenerateParameters, TextGenerationInputGrammarType, TextGenerationOutput, TextGenerationOutputBestOfSequence, TextGenerationOutputDetails, TextGenerationOutputFinishReason, TextGenerationOutputPrefillToken, TextGenerationOutputToken, TextGenerationStreamOutput, TextGenerationStreamOutputStreamDetails, TextGenerationStreamOutputToken, TypeEnum, ) from .text_to_audio import ( TextToAudioEarlyStoppingEnum, TextToAudioGenerationParameters, TextToAudioInput, TextToAudioOutput, TextToAudioParameters, ) from .text_to_image import TextToImageInput, TextToImageOutput, TextToImageParameters from .text_to_speech import ( TextToSpeechEarlyStoppingEnum, TextToSpeechGenerationParameters, TextToSpeechInput, TextToSpeechOutput, TextToSpeechParameters, ) from .text_to_video import TextToVideoInput, TextToVideoOutput, TextToVideoParameters from .token_classification import ( TokenClassificationAggregationStrategy, TokenClassificationInput, TokenClassificationOutputElement, TokenClassificationParameters, ) from .translation import TranslationInput, TranslationOutput, TranslationParameters, TranslationTruncationStrategy from .video_classification import ( VideoClassificationInput, VideoClassificationOutputElement, VideoClassificationOutputTransform, VideoClassificationParameters, ) from .visual_question_answering import ( VisualQuestionAnsweringInput, VisualQuestionAnsweringInputData, VisualQuestionAnsweringOutputElement, VisualQuestionAnsweringParameters, ) from .zero_shot_classification import ( ZeroShotClassificationInput, ZeroShotClassificationOutputElement, ZeroShotClassificationParameters, ) from .zero_shot_image_classification import ( ZeroShotImageClassificationInput, ZeroShotImageClassificationOutputElement, ZeroShotImageClassificationParameters, ) from .zero_shot_object_detection import ( ZeroShotObjectDetectionBoundingBox, ZeroShotObjectDetectionInput, ZeroShotObjectDetectionOutputElement, ZeroShotObjectDetectionParameters, ) huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/audio_classification.py000066400000000000000000000030451500667546600325430ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Literal, Optional from .base import BaseInferenceType, dataclass_with_extra AudioClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] @dataclass_with_extra class AudioClassificationParameters(BaseInferenceType): """Additional inference parameters for Audio Classification""" function_to_apply: Optional["AudioClassificationOutputTransform"] = None """The function to apply to the model outputs in order to retrieve the scores.""" top_k: Optional[int] = None """When specified, limits the output to the top K most probable classes.""" @dataclass_with_extra class AudioClassificationInput(BaseInferenceType): """Inputs for Audio Classification inference""" inputs: str """The input audio data as a base64-encoded string. If no `parameters` are provided, you can also provide the audio data as a raw bytes payload. 
""" parameters: Optional[AudioClassificationParameters] = None """Additional inference parameters for Audio Classification""" @dataclass_with_extra class AudioClassificationOutputElement(BaseInferenceType): """Outputs for Audio Classification inference""" label: str """The predicted class label.""" score: float """The corresponding probability.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/audio_to_audio.py000066400000000000000000000015731500667546600313570ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class AudioToAudioInput(BaseInferenceType): """Inputs for Audio to Audio inference""" inputs: Any """The input audio data""" @dataclass_with_extra class AudioToAudioOutputElement(BaseInferenceType): """Outputs of inference for the Audio To Audio task A generated audio file with its label. """ blob: Any """The generated audio file.""" content_type: str """The content type of audio file.""" label: str """The label of the audio file.""" automatic_speech_recognition.py000066400000000000000000000126131500667546600342260ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import List, Literal, Optional, Union from .base import BaseInferenceType, dataclass_with_extra AutomaticSpeechRecognitionEarlyStoppingEnum = Literal["never"] @dataclass_with_extra class AutomaticSpeechRecognitionGenerationParameters(BaseInferenceType): """Parametrization of the text generation process""" do_sample: Optional[bool] = None """Whether to use sampling instead of greedy decoding when generating new tokens.""" early_stopping: Optional[Union[bool, "AutomaticSpeechRecognitionEarlyStoppingEnum"]] = None """Controls the stopping condition for beam-based methods.""" epsilon_cutoff: Optional[float] = None """If set to float strictly between 0 and 1, only tokens with a conditional probability greater than epsilon_cutoff will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details. """ eta_cutoff: Optional[float] = None """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between 0 and 1, a token is only considered if it is greater than either eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details. 
""" max_length: Optional[int] = None """The maximum length (in tokens) of the generated text, including the input.""" max_new_tokens: Optional[int] = None """The maximum number of tokens to generate. Takes precedence over max_length.""" min_length: Optional[int] = None """The minimum length (in tokens) of the generated text, including the input.""" min_new_tokens: Optional[int] = None """The minimum number of tokens to generate. Takes precedence over min_length.""" num_beam_groups: Optional[int] = None """Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. """ num_beams: Optional[int] = None """Number of beams to use for beam search.""" penalty_alpha: Optional[float] = None """The value balances the model confidence and the degeneration penalty in contrastive search decoding. """ temperature: Optional[float] = None """The value used to modulate the next token probabilities.""" top_k: Optional[int] = None """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" top_p: Optional[float] = None """If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. """ typical_p: Optional[float] = None """Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to typical_p or higher are kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. """ use_cache: Optional[bool] = None """Whether the model should use the past last key/values attentions to speed up decoding""" @dataclass_with_extra class AutomaticSpeechRecognitionParameters(BaseInferenceType): """Additional inference parameters for Automatic Speech Recognition""" generation_parameters: Optional[AutomaticSpeechRecognitionGenerationParameters] = None """Parametrization of the text generation process""" return_timestamps: Optional[bool] = None """Whether to output corresponding timestamps with the generated text""" @dataclass_with_extra class AutomaticSpeechRecognitionInput(BaseInferenceType): """Inputs for Automatic Speech Recognition inference""" inputs: str """The input audio data as a base64-encoded string. If no `parameters` are provided, you can also provide the audio data as a raw bytes payload. """ parameters: Optional[AutomaticSpeechRecognitionParameters] = None """Additional inference parameters for Automatic Speech Recognition""" @dataclass_with_extra class AutomaticSpeechRecognitionOutputChunk(BaseInferenceType): text: str """A chunk of text identified by the model""" timestamp: List[float] """The start and end timestamps corresponding with the text""" @dataclass_with_extra class AutomaticSpeechRecognitionOutput(BaseInferenceType): """Outputs of inference for the Automatic Speech Recognition task""" text: str """The recognized text.""" chunks: Optional[List[AutomaticSpeechRecognitionOutputChunk]] = None """When returnTimestamps is enabled, chunks contains a list of audio chunks identified by the model. """ huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/base.py000066400000000000000000000151371500667546600273060ustar00rootroot00000000000000# Copyright 2024 The HuggingFace Team. 
All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains a base class for all inference types.""" import inspect import json from dataclasses import asdict, dataclass from typing import Any, Dict, List, Type, TypeVar, Union, get_args T = TypeVar("T", bound="BaseInferenceType") def _repr_with_extra(self): fields = list(self.__dataclass_fields__.keys()) other_fields = list(k for k in self.__dict__ if k not in fields) return f"{self.__class__.__name__}({', '.join(f'{k}={self.__dict__[k]!r}' for k in fields + other_fields)})" def dataclass_with_extra(cls: Type[T]) -> Type[T]: """Decorator to add a custom __repr__ method to a dataclass, showing all fields, including extra ones. This decorator only works with dataclasses that inherit from `BaseInferenceType`. """ cls = dataclass(cls) cls.__repr__ = _repr_with_extra # type: ignore[method-assign] return cls @dataclass class BaseInferenceType(dict): """Base class for all inference types. Object is a dataclass and a dict for backward compatibility but plan is to remove the dict part in the future. Handle parsing from dict, list and json strings in a permissive way to ensure future-compatibility (e.g. all fields are made optional, and non-expected fields are added as dict attributes). """ @classmethod def parse_obj_as_list(cls: Type[T], data: Union[bytes, str, List, Dict]) -> List[T]: """Alias to parse server response and return a single instance. See `parse_obj` for more details. """ output = cls.parse_obj(data) if not isinstance(output, list): raise ValueError(f"Invalid input data for {cls}. Expected a list, but got {type(output)}.") return output @classmethod def parse_obj_as_instance(cls: Type[T], data: Union[bytes, str, List, Dict]) -> T: """Alias to parse server response and return a single instance. See `parse_obj` for more details. """ output = cls.parse_obj(data) if isinstance(output, list): raise ValueError(f"Invalid input data for {cls}. Expected a single instance, but got a list.") return output @classmethod def parse_obj(cls: Type[T], data: Union[bytes, str, List, Dict]) -> Union[List[T], T]: """Parse server response as a dataclass or list of dataclasses. To enable future-compatibility, we want to handle cases where the server return more fields than expected. In such cases, we don't want to raise an error but still create the dataclass object. Remaining fields are added as dict attributes. 
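
        Example (illustrative sketch, not taken from the generated specs): assuming a hypothetical
        subclass ``LabelOutput`` that declares a single ``label: str`` field, an unexpected ``score``
        field returned by the server does not raise an error; it is kept both as a dict entry and as
        an extra attribute:

        ```py
        # Hypothetical subclass, shown for illustration only:
        # @dataclass_with_extra
        # class LabelOutput(BaseInferenceType):
        #     label: Optional[str] = None
        #
        # output = LabelOutput.parse_obj(b'{"label": "cat", "score": 0.9}')
        # assert output.label == "cat"     # regular dataclass field
        # assert output["score"] == 0.9    # unexpected field, kept for forward-compatibility
        ```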
""" # Parse server response (from bytes) if isinstance(data, bytes): data = data.decode() if isinstance(data, str): data = json.loads(data) # If a list, parse each item individually if isinstance(data, List): return [cls.parse_obj(d) for d in data] # type: ignore [misc] # At this point, we expect a dict if not isinstance(data, dict): raise ValueError(f"Invalid data type: {type(data)}") init_values = {} other_values = {} for key, value in data.items(): key = normalize_key(key) if key in cls.__dataclass_fields__ and cls.__dataclass_fields__[key].init: if isinstance(value, dict) or isinstance(value, list): field_type = cls.__dataclass_fields__[key].type # if `field_type` is a `BaseInferenceType`, parse it if inspect.isclass(field_type) and issubclass(field_type, BaseInferenceType): value = field_type.parse_obj(value) # otherwise, recursively parse nested dataclasses (if possible) # `get_args` returns handle Union and Optional for us else: expected_types = get_args(field_type) for expected_type in expected_types: if getattr(expected_type, "_name", None) == "List": expected_type = get_args(expected_type)[ 0 ] # assume same type for all items in the list if inspect.isclass(expected_type) and issubclass(expected_type, BaseInferenceType): value = expected_type.parse_obj(value) break init_values[key] = value else: other_values[key] = value # Make all missing fields default to None # => ensure that dataclass initialization will never fail even if the server does not return all fields. for key in cls.__dataclass_fields__: if key not in init_values: init_values[key] = None # Initialize dataclass with expected values item = cls(**init_values) # Add remaining fields as dict attributes item.update(other_values) # Add remaining fields as extra dataclass fields. # They won't be part of the dataclass fields but will be accessible as attributes. # Use @dataclass_with_extra to show them in __repr__. item.__dict__.update(other_values) return item def __post_init__(self): self.update(asdict(self)) def __setitem__(self, __key: Any, __value: Any) -> None: # Hacky way to keep dataclass values in sync when dict is updated super().__setitem__(__key, __value) if __key in self.__dataclass_fields__ and getattr(self, __key, None) != __value: self.__setattr__(__key, __value) return def __setattr__(self, __name: str, __value: Any) -> None: # Hacky way to keep dict values is sync when dataclass is updated super().__setattr__(__name, __value) if self.get(__name) != __value: self[__name] = __value return def normalize_key(key: str) -> str: # e.g "content-type" -> "content_type", "Accept" -> "accept" return key.replace("-", "_").replace(" ", "_").lower() huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/chat_completion.py000066400000000000000000000240021500667546600315330ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
from typing import Any, List, Literal, Optional, Union from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class ChatCompletionInputURL(BaseInferenceType): url: str ChatCompletionInputMessageChunkType = Literal["text", "image_url"] @dataclass_with_extra class ChatCompletionInputMessageChunk(BaseInferenceType): type: "ChatCompletionInputMessageChunkType" image_url: Optional[ChatCompletionInputURL] = None text: Optional[str] = None @dataclass_with_extra class ChatCompletionInputFunctionDefinition(BaseInferenceType): name: str parameters: Any description: Optional[str] = None @dataclass_with_extra class ChatCompletionInputToolCall(BaseInferenceType): function: ChatCompletionInputFunctionDefinition id: str type: str @dataclass_with_extra class ChatCompletionInputMessage(BaseInferenceType): role: str content: Optional[Union[List[ChatCompletionInputMessageChunk], str]] = None name: Optional[str] = None tool_calls: Optional[List[ChatCompletionInputToolCall]] = None ChatCompletionInputGrammarTypeType = Literal["json", "regex", "json_schema"] @dataclass_with_extra class ChatCompletionInputGrammarType(BaseInferenceType): type: "ChatCompletionInputGrammarTypeType" value: Any """A string that represents a [JSON Schema](https://json-schema.org/). JSON Schema is a declarative language that allows to annotate JSON documents with types and descriptions. """ @dataclass_with_extra class ChatCompletionInputStreamOptions(BaseInferenceType): include_usage: Optional[bool] = None """If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value. """ @dataclass_with_extra class ChatCompletionInputFunctionName(BaseInferenceType): name: str @dataclass_with_extra class ChatCompletionInputToolChoiceClass(BaseInferenceType): function: ChatCompletionInputFunctionName ChatCompletionInputToolChoiceEnum = Literal["auto", "none", "required"] @dataclass_with_extra class ChatCompletionInputTool(BaseInferenceType): function: ChatCompletionInputFunctionDefinition type: str @dataclass_with_extra class ChatCompletionInput(BaseInferenceType): """Chat Completion Input. Auto-generated from TGI specs. For more details, check out https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. """ messages: List[ChatCompletionInputMessage] """A list of messages comprising the conversation so far.""" frequency_penalty: Optional[float] = None """Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. """ logit_bias: Optional[List[float]] = None """UNUSED Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. """ logprobs: Optional[bool] = None """Whether to return log probabilities of the output tokens or not. 
If true, returns the log probabilities of each output token returned in the content of message. """ max_tokens: Optional[int] = None """The maximum number of tokens that can be generated in the chat completion.""" model: Optional[str] = None """[UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. """ n: Optional[int] = None """UNUSED How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. """ presence_penalty: Optional[float] = None """Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics """ response_format: Optional[ChatCompletionInputGrammarType] = None seed: Optional[int] = None stop: Optional[List[str]] = None """Up to 4 sequences where the API will stop generating further tokens.""" stream: Optional[bool] = None stream_options: Optional[ChatCompletionInputStreamOptions] = None temperature: Optional[float] = None """What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both. """ tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None tool_prompt: Optional[str] = None """A prompt to be appended before the tools""" tools: Optional[List[ChatCompletionInputTool]] = None """A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. """ top_logprobs: Optional[int] = None """An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. """ top_p: Optional[float] = None """An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. 
""" @dataclass_with_extra class ChatCompletionOutputTopLogprob(BaseInferenceType): logprob: float token: str @dataclass_with_extra class ChatCompletionOutputLogprob(BaseInferenceType): logprob: float token: str top_logprobs: List[ChatCompletionOutputTopLogprob] @dataclass_with_extra class ChatCompletionOutputLogprobs(BaseInferenceType): content: List[ChatCompletionOutputLogprob] @dataclass_with_extra class ChatCompletionOutputFunctionDefinition(BaseInferenceType): arguments: str name: str description: Optional[str] = None @dataclass_with_extra class ChatCompletionOutputToolCall(BaseInferenceType): function: ChatCompletionOutputFunctionDefinition id: str type: str @dataclass_with_extra class ChatCompletionOutputMessage(BaseInferenceType): role: str content: Optional[str] = None tool_call_id: Optional[str] = None tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None @dataclass_with_extra class ChatCompletionOutputComplete(BaseInferenceType): finish_reason: str index: int message: ChatCompletionOutputMessage logprobs: Optional[ChatCompletionOutputLogprobs] = None @dataclass_with_extra class ChatCompletionOutputUsage(BaseInferenceType): completion_tokens: int prompt_tokens: int total_tokens: int @dataclass_with_extra class ChatCompletionOutput(BaseInferenceType): """Chat Completion Output. Auto-generated from TGI specs. For more details, check out https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. """ choices: List[ChatCompletionOutputComplete] created: int id: str model: str system_fingerprint: str usage: ChatCompletionOutputUsage @dataclass_with_extra class ChatCompletionStreamOutputFunction(BaseInferenceType): arguments: str name: Optional[str] = None @dataclass_with_extra class ChatCompletionStreamOutputDeltaToolCall(BaseInferenceType): function: ChatCompletionStreamOutputFunction id: str index: int type: str @dataclass_with_extra class ChatCompletionStreamOutputDelta(BaseInferenceType): role: str content: Optional[str] = None tool_call_id: Optional[str] = None tool_calls: Optional[List[ChatCompletionStreamOutputDeltaToolCall]] = None @dataclass_with_extra class ChatCompletionStreamOutputTopLogprob(BaseInferenceType): logprob: float token: str @dataclass_with_extra class ChatCompletionStreamOutputLogprob(BaseInferenceType): logprob: float token: str top_logprobs: List[ChatCompletionStreamOutputTopLogprob] @dataclass_with_extra class ChatCompletionStreamOutputLogprobs(BaseInferenceType): content: List[ChatCompletionStreamOutputLogprob] @dataclass_with_extra class ChatCompletionStreamOutputChoice(BaseInferenceType): delta: ChatCompletionStreamOutputDelta index: int finish_reason: Optional[str] = None logprobs: Optional[ChatCompletionStreamOutputLogprobs] = None @dataclass_with_extra class ChatCompletionStreamOutputUsage(BaseInferenceType): completion_tokens: int prompt_tokens: int total_tokens: int @dataclass_with_extra class ChatCompletionStreamOutput(BaseInferenceType): """Chat Completion Stream Output. Auto-generated from TGI specs. For more details, check out https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. 
""" choices: List[ChatCompletionStreamOutputChoice] created: int id: str model: str system_fingerprint: str usage: Optional[ChatCompletionStreamOutputUsage] = None huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/depth_estimation.py000066400000000000000000000016411500667546600317270ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, Dict, Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class DepthEstimationInput(BaseInferenceType): """Inputs for Depth Estimation inference""" inputs: Any """The input image data""" parameters: Optional[Dict[str, Any]] = None """Additional inference parameters for Depth Estimation""" @dataclass_with_extra class DepthEstimationOutput(BaseInferenceType): """Outputs of inference for the Depth Estimation task""" depth: Any """The predicted depth as an image""" predicted_depth: Any """The predicted depth as a tensor""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/document_question_answering.py000066400000000000000000000062021500667546600342070ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, List, Optional, Union from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class DocumentQuestionAnsweringInputData(BaseInferenceType): """One (document, question) pair to answer""" image: Any """The image on which the question is asked""" question: str """A question to ask of the document""" @dataclass_with_extra class DocumentQuestionAnsweringParameters(BaseInferenceType): """Additional inference parameters for Document Question Answering""" doc_stride: Optional[int] = None """If the words in the document are too long to fit with the question for the model, it will be split in several chunks with some overlap. This argument controls the size of that overlap. """ handle_impossible_answer: Optional[bool] = None """Whether to accept impossible as an answer""" lang: Optional[str] = None """Language to use while running OCR. Defaults to english.""" max_answer_len: Optional[int] = None """The maximum length of predicted answers (e.g., only answers with a shorter length are considered). """ max_question_len: Optional[int] = None """The maximum length of the question after tokenization. It will be truncated if needed.""" max_seq_len: Optional[int] = None """The maximum length of the total sentence (context + question) in tokens of each chunk passed to the model. The context will be split in several chunks (using doc_stride as overlap) if needed. """ top_k: Optional[int] = None """The number of answers to return (will be chosen by order of likelihood). Can return less than top_k answers if there are not enough options available within the context. """ word_boxes: Optional[List[Union[List[float], str]]] = None """A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR step and use the provided bounding boxes instead. 
""" @dataclass_with_extra class DocumentQuestionAnsweringInput(BaseInferenceType): """Inputs for Document Question Answering inference""" inputs: DocumentQuestionAnsweringInputData """One (document, question) pair to answer""" parameters: Optional[DocumentQuestionAnsweringParameters] = None """Additional inference parameters for Document Question Answering""" @dataclass_with_extra class DocumentQuestionAnsweringOutputElement(BaseInferenceType): """Outputs of inference for the Document Question Answering task""" answer: str """The answer to the question.""" end: int """The end word index of the answer (in the OCR’d version of the input or provided word boxes). """ score: float """The probability associated to the answer.""" start: int """The start word index of the answer (in the OCR’d version of the input or provided word boxes). """ huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/feature_extraction.py000066400000000000000000000030011500667546600322520ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import List, Literal, Optional, Union from .base import BaseInferenceType, dataclass_with_extra FeatureExtractionInputTruncationDirection = Literal["Left", "Right"] @dataclass_with_extra class FeatureExtractionInput(BaseInferenceType): """Feature Extraction Input. Auto-generated from TEI specs. For more details, check out https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tei-import.ts. """ inputs: Union[List[str], str] """The text or list of texts to embed.""" normalize: Optional[bool] = None prompt_name: Optional[str] = None """The name of the prompt that should be used by for encoding. If not set, no prompt will be applied. Must be a key in the `sentence-transformers` configuration `prompts` dictionary. For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" because the prompt text will be prepended before any text to encode. """ truncate: Optional[bool] = None truncation_direction: Optional["FeatureExtractionInputTruncationDirection"] = None huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/fill_mask.py000066400000000000000000000032541500667546600303320ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, List, Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class FillMaskParameters(BaseInferenceType): """Additional inference parameters for Fill Mask""" targets: Optional[List[str]] = None """When passed, the model will limit the scores to the passed targets instead of looking up in the whole vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first resulting token will be used (with a warning, and that might be slower). 
""" top_k: Optional[int] = None """When passed, overrides the number of predictions to return.""" @dataclass_with_extra class FillMaskInput(BaseInferenceType): """Inputs for Fill Mask inference""" inputs: str """The text with masked tokens""" parameters: Optional[FillMaskParameters] = None """Additional inference parameters for Fill Mask""" @dataclass_with_extra class FillMaskOutputElement(BaseInferenceType): """Outputs of inference for the Fill Mask task""" score: float """The corresponding probability""" sequence: str """The corresponding input with the mask token prediction.""" token: int """The predicted token id (to replace the masked one).""" token_str: Any fill_mask_output_token_str: Optional[str] = None """The predicted token (to replace the masked one).""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/image_classification.py000066400000000000000000000030611500667546600325220ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Literal, Optional from .base import BaseInferenceType, dataclass_with_extra ImageClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] @dataclass_with_extra class ImageClassificationParameters(BaseInferenceType): """Additional inference parameters for Image Classification""" function_to_apply: Optional["ImageClassificationOutputTransform"] = None """The function to apply to the model outputs in order to retrieve the scores.""" top_k: Optional[int] = None """When specified, limits the output to the top K most probable classes.""" @dataclass_with_extra class ImageClassificationInput(BaseInferenceType): """Inputs for Image Classification inference""" inputs: str """The input image data as a base64-encoded string. If no `parameters` are provided, you can also provide the image data as a raw bytes payload. """ parameters: Optional[ImageClassificationParameters] = None """Additional inference parameters for Image Classification""" @dataclass_with_extra class ImageClassificationOutputElement(BaseInferenceType): """Outputs of inference for the Image Classification task""" label: str """The predicted class label.""" score: float """The corresponding probability.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/image_segmentation.py000066400000000000000000000036361500667546600322340ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
from typing import Literal, Optional from .base import BaseInferenceType, dataclass_with_extra ImageSegmentationSubtask = Literal["instance", "panoptic", "semantic"] @dataclass_with_extra class ImageSegmentationParameters(BaseInferenceType): """Additional inference parameters for Image Segmentation""" mask_threshold: Optional[float] = None """Threshold to use when turning the predicted masks into binary values.""" overlap_mask_area_threshold: Optional[float] = None """Mask overlap threshold to eliminate small, disconnected segments.""" subtask: Optional["ImageSegmentationSubtask"] = None """Segmentation task to be performed, depending on model capabilities.""" threshold: Optional[float] = None """Probability threshold to filter out predicted masks.""" @dataclass_with_extra class ImageSegmentationInput(BaseInferenceType): """Inputs for Image Segmentation inference""" inputs: str """The input image data as a base64-encoded string. If no `parameters` are provided, you can also provide the image data as a raw bytes payload. """ parameters: Optional[ImageSegmentationParameters] = None """Additional inference parameters for Image Segmentation""" @dataclass_with_extra class ImageSegmentationOutputElement(BaseInferenceType): """Outputs of inference for the Image Segmentation task A predicted mask / segment """ label: str """The label of the predicted segment.""" mask: str """The corresponding mask as a black-and-white image (base64-encoded).""" score: Optional[float] = None """The score or confidence degree the model has.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/image_to_image.py000066400000000000000000000037741500667546600313260ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class ImageToImageTargetSize(BaseInferenceType): """The size in pixel of the output image.""" height: int width: int @dataclass_with_extra class ImageToImageParameters(BaseInferenceType): """Additional inference parameters for Image To Image""" guidance_scale: Optional[float] = None """For diffusion models. A higher guidance scale value encourages the model to generate images closely linked to the text prompt at the expense of lower image quality. """ negative_prompt: Optional[str] = None """One prompt to guide what NOT to include in image generation.""" num_inference_steps: Optional[int] = None """For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. """ prompt: Optional[str] = None """The text prompt to guide the image generation.""" target_size: Optional[ImageToImageTargetSize] = None """The size in pixel of the output image.""" @dataclass_with_extra class ImageToImageInput(BaseInferenceType): """Inputs for Image To Image inference""" inputs: str """The input image data as a base64-encoded string. If no `parameters` are provided, you can also provide the image data as a raw bytes payload. 
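    For illustration only, a minimal request sketch (the base64 payload and parameter
    values are placeholders):

        >>> from huggingface_hub.inference._generated.types.image_to_image import (
        ...     ImageToImageInput, ImageToImageParameters, ImageToImageTargetSize,
        ... )
        >>> request = ImageToImageInput(
        ...     inputs="<base64-encoded image>",
        ...     parameters=ImageToImageParameters(
        ...         prompt="Turn the cat into a tiger.",
        ...         num_inference_steps=25,
        ...         target_size=ImageToImageTargetSize(height=512, width=512),
        ...     ),
        ... )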
""" parameters: Optional[ImageToImageParameters] = None """Additional inference parameters for Image To Image""" @dataclass_with_extra class ImageToImageOutput(BaseInferenceType): """Outputs of inference for the Image To Image task""" image: Any """The output image returned as raw bytes in the payload.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/image_to_text.py000066400000000000000000000113121500667546600312130ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, Literal, Optional, Union from .base import BaseInferenceType, dataclass_with_extra ImageToTextEarlyStoppingEnum = Literal["never"] @dataclass_with_extra class ImageToTextGenerationParameters(BaseInferenceType): """Parametrization of the text generation process""" do_sample: Optional[bool] = None """Whether to use sampling instead of greedy decoding when generating new tokens.""" early_stopping: Optional[Union[bool, "ImageToTextEarlyStoppingEnum"]] = None """Controls the stopping condition for beam-based methods.""" epsilon_cutoff: Optional[float] = None """If set to float strictly between 0 and 1, only tokens with a conditional probability greater than epsilon_cutoff will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details. """ eta_cutoff: Optional[float] = None """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between 0 and 1, a token is only considered if it is greater than either eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details. """ max_length: Optional[int] = None """The maximum length (in tokens) of the generated text, including the input.""" max_new_tokens: Optional[int] = None """The maximum number of tokens to generate. Takes precedence over max_length.""" min_length: Optional[int] = None """The minimum length (in tokens) of the generated text, including the input.""" min_new_tokens: Optional[int] = None """The minimum number of tokens to generate. Takes precedence over min_length.""" num_beam_groups: Optional[int] = None """Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. """ num_beams: Optional[int] = None """Number of beams to use for beam search.""" penalty_alpha: Optional[float] = None """The value balances the model confidence and the degeneration penalty in contrastive search decoding. 
""" temperature: Optional[float] = None """The value used to modulate the next token probabilities.""" top_k: Optional[int] = None """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" top_p: Optional[float] = None """If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. """ typical_p: Optional[float] = None """Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to typical_p or higher are kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. """ use_cache: Optional[bool] = None """Whether the model should use the past last key/values attentions to speed up decoding""" @dataclass_with_extra class ImageToTextParameters(BaseInferenceType): """Additional inference parameters for Image To Text""" generation_parameters: Optional[ImageToTextGenerationParameters] = None """Parametrization of the text generation process""" max_new_tokens: Optional[int] = None """The amount of maximum tokens to generate.""" @dataclass_with_extra class ImageToTextInput(BaseInferenceType): """Inputs for Image To Text inference""" inputs: Any """The input image data""" parameters: Optional[ImageToTextParameters] = None """Additional inference parameters for Image To Text""" @dataclass_with_extra class ImageToTextOutput(BaseInferenceType): """Outputs of inference for the Image To Text task""" generated_text: Any image_to_text_output_generated_text: Optional[str] = None """The generated text.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/object_detection.py000066400000000000000000000037201500667546600316730ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class ObjectDetectionParameters(BaseInferenceType): """Additional inference parameters for Object Detection""" threshold: Optional[float] = None """The probability necessary to make a prediction.""" @dataclass_with_extra class ObjectDetectionInput(BaseInferenceType): """Inputs for Object Detection inference""" inputs: str """The input image data as a base64-encoded string. If no `parameters` are provided, you can also provide the image data as a raw bytes payload. """ parameters: Optional[ObjectDetectionParameters] = None """Additional inference parameters for Object Detection""" @dataclass_with_extra class ObjectDetectionBoundingBox(BaseInferenceType): """The predicted bounding box. Coordinates are relative to the top left corner of the input image. 
""" xmax: int """The x-coordinate of the bottom-right corner of the bounding box.""" xmin: int """The x-coordinate of the top-left corner of the bounding box.""" ymax: int """The y-coordinate of the bottom-right corner of the bounding box.""" ymin: int """The y-coordinate of the top-left corner of the bounding box.""" @dataclass_with_extra class ObjectDetectionOutputElement(BaseInferenceType): """Outputs of inference for the Object Detection task""" box: ObjectDetectionBoundingBox """The predicted bounding box. Coordinates are relative to the top left corner of the input image. """ label: str """The predicted label for the bounding box.""" score: float """The associated score / probability.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/question_answering.py000066400000000000000000000055221500667546600323150ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class QuestionAnsweringInputData(BaseInferenceType): """One (context, question) pair to answer""" context: str """The context to be used for answering the question""" question: str """The question to be answered""" @dataclass_with_extra class QuestionAnsweringParameters(BaseInferenceType): """Additional inference parameters for Question Answering""" align_to_words: Optional[bool] = None """Attempts to align the answer to real words. Improves quality on space separated languages. Might hurt on non-space-separated languages (like Japanese or Chinese) """ doc_stride: Optional[int] = None """If the context is too long to fit with the question for the model, it will be split in several chunks with some overlap. This argument controls the size of that overlap. """ handle_impossible_answer: Optional[bool] = None """Whether to accept impossible as an answer.""" max_answer_len: Optional[int] = None """The maximum length of predicted answers (e.g., only answers with a shorter length are considered). """ max_question_len: Optional[int] = None """The maximum length of the question after tokenization. It will be truncated if needed.""" max_seq_len: Optional[int] = None """The maximum length of the total sentence (context + question) in tokens of each chunk passed to the model. The context will be split in several chunks (using docStride as overlap) if needed. """ top_k: Optional[int] = None """The number of answers to return (will be chosen by order of likelihood). Note that we return less than topk answers if there are not enough options available within the context. 
""" @dataclass_with_extra class QuestionAnsweringInput(BaseInferenceType): """Inputs for Question Answering inference""" inputs: QuestionAnsweringInputData """One (context, question) pair to answer""" parameters: Optional[QuestionAnsweringParameters] = None """Additional inference parameters for Question Answering""" @dataclass_with_extra class QuestionAnsweringOutputElement(BaseInferenceType): """Outputs of inference for the Question Answering task""" answer: str """The answer to the question.""" end: int """The character position in the input where the answer ends.""" score: float """The probability associated to the answer.""" start: int """The character position in the input where the answer begins.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/sentence_similarity.py000066400000000000000000000020341500667546600324360ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, Dict, List, Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class SentenceSimilarityInputData(BaseInferenceType): sentences: List[str] """A list of strings which will be compared against the source_sentence.""" source_sentence: str """The string that you wish to compare the other strings with. This can be a phrase, sentence, or longer passage, depending on the model being used. """ @dataclass_with_extra class SentenceSimilarityInput(BaseInferenceType): """Inputs for Sentence similarity inference""" inputs: SentenceSimilarityInputData parameters: Optional[Dict[str, Any]] = None """Additional inference parameters for Sentence Similarity""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/summarization.py000066400000000000000000000027171500667546600312760ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
from typing import Any, Dict, Literal, Optional from .base import BaseInferenceType, dataclass_with_extra SummarizationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"] @dataclass_with_extra class SummarizationParameters(BaseInferenceType): """Additional inference parameters for summarization.""" clean_up_tokenization_spaces: Optional[bool] = None """Whether to clean up the potential extra spaces in the text output.""" generate_parameters: Optional[Dict[str, Any]] = None """Additional parametrization of the text generation algorithm.""" truncation: Optional["SummarizationTruncationStrategy"] = None """The truncation strategy to use.""" @dataclass_with_extra class SummarizationInput(BaseInferenceType): """Inputs for Summarization inference""" inputs: str """The input text to summarize.""" parameters: Optional[SummarizationParameters] = None """Additional inference parameters for summarization.""" @dataclass_with_extra class SummarizationOutput(BaseInferenceType): """Outputs of inference for the Summarization task""" summary_text: str """The summarized text.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/table_question_answering.py000066400000000000000000000043651500667546600334700ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Dict, List, Literal, Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class TableQuestionAnsweringInputData(BaseInferenceType): """One (table, question) pair to answer""" question: str """The question to be answered about the table""" table: Dict[str, List[str]] """The table to serve as context for the questions""" Padding = Literal["do_not_pad", "longest", "max_length"] @dataclass_with_extra class TableQuestionAnsweringParameters(BaseInferenceType): """Additional inference parameters for Table Question Answering""" padding: Optional["Padding"] = None """Activates and controls padding.""" sequential: Optional[bool] = None """Whether to do inference sequentially or as a batch. Batching is faster, but models like SQA require the inference to be done sequentially to extract relations within sequences, given their conversational nature. """ truncation: Optional[bool] = None """Activates and controls truncation.""" @dataclass_with_extra class TableQuestionAnsweringInput(BaseInferenceType): """Inputs for Table Question Answering inference""" inputs: TableQuestionAnsweringInputData """One (table, question) pair to answer""" parameters: Optional[TableQuestionAnsweringParameters] = None """Additional inference parameters for Table Question Answering""" @dataclass_with_extra class TableQuestionAnsweringOutputElement(BaseInferenceType): """Outputs of inference for the Table Question Answering task""" answer: str """The answer of the question given the table. If there is an aggregator, the answer will be preceded by `AGGREGATOR >`. 
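    For illustration only, a hypothetical aggregated result could look like this
    (all values are made up):

        >>> from huggingface_hub.inference._generated.types.table_question_answering import (
        ...     TableQuestionAnsweringOutputElement,
        ... )
        >>> element = TableQuestionAnsweringOutputElement(
        ...     answer="SUM > 36542, 4512",
        ...     cells=["36542", "4512"],
        ...     coordinates=[[0, 1], [1, 1]],
        ...     aggregator="SUM",
        ... )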
""" cells: List[str] """List of strings made up of the answer cell values.""" coordinates: List[List[int]] """Coordinates of the cells of the answers.""" aggregator: Optional[str] = None """If the model has an aggregator, this returns the aggregator.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/text2text_generation.py000066400000000000000000000031111500667546600325470ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, Dict, Literal, Optional from .base import BaseInferenceType, dataclass_with_extra Text2TextGenerationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"] @dataclass_with_extra class Text2TextGenerationParameters(BaseInferenceType): """Additional inference parameters for Text2text Generation""" clean_up_tokenization_spaces: Optional[bool] = None """Whether to clean up the potential extra spaces in the text output.""" generate_parameters: Optional[Dict[str, Any]] = None """Additional parametrization of the text generation algorithm""" truncation: Optional["Text2TextGenerationTruncationStrategy"] = None """The truncation strategy to use""" @dataclass_with_extra class Text2TextGenerationInput(BaseInferenceType): """Inputs for Text2text Generation inference""" inputs: str """The input text data""" parameters: Optional[Text2TextGenerationParameters] = None """Additional inference parameters for Text2text Generation""" @dataclass_with_extra class Text2TextGenerationOutput(BaseInferenceType): """Outputs of inference for the Text2text Generation task""" generated_text: Any text2_text_generation_output_generated_text: Optional[str] = None """The generated text.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/text_classification.py000066400000000000000000000026451500667546600324330ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
from typing import Literal, Optional from .base import BaseInferenceType, dataclass_with_extra TextClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] @dataclass_with_extra class TextClassificationParameters(BaseInferenceType): """Additional inference parameters for Text Classification""" function_to_apply: Optional["TextClassificationOutputTransform"] = None """The function to apply to the model outputs in order to retrieve the scores.""" top_k: Optional[int] = None """When specified, limits the output to the top K most probable classes.""" @dataclass_with_extra class TextClassificationInput(BaseInferenceType): """Inputs for Text Classification inference""" inputs: str """The text to classify""" parameters: Optional[TextClassificationParameters] = None """Additional inference parameters for Text Classification""" @dataclass_with_extra class TextClassificationOutputElement(BaseInferenceType): """Outputs of inference for the Text Classification task""" label: str """The predicted class label.""" score: float """The corresponding probability.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/text_generation.py000066400000000000000000000134421500667546600315700ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, List, Literal, Optional from .base import BaseInferenceType, dataclass_with_extra TypeEnum = Literal["json", "regex", "json_schema"] @dataclass_with_extra class TextGenerationInputGrammarType(BaseInferenceType): type: "TypeEnum" value: Any """A string that represents a [JSON Schema](https://json-schema.org/). JSON Schema is a declarative language that allows to annotate JSON documents with types and descriptions. """ @dataclass_with_extra class TextGenerationInputGenerateParameters(BaseInferenceType): adapter_id: Optional[str] = None """Lora adapter id""" best_of: Optional[int] = None """Generate best_of sequences and return the one if the highest token logprobs.""" decoder_input_details: Optional[bool] = None """Whether to return decoder input token logprobs and ids.""" details: Optional[bool] = None """Whether to return generation details.""" do_sample: Optional[bool] = None """Activate logits sampling.""" frequency_penalty: Optional[float] = None """The parameter for frequency penalty. 1.0 means no penalty Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. """ grammar: Optional[TextGenerationInputGrammarType] = None max_new_tokens: Optional[int] = None """Maximum number of tokens to generate.""" repetition_penalty: Optional[float] = None """The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. 
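    For illustration only, a complete request sketch (prompt and parameter values are
    placeholders):

        >>> from huggingface_hub.inference._generated.types.text_generation import (
        ...     TextGenerationInput, TextGenerationInputGenerateParameters,
        ... )
        >>> request = TextGenerationInput(
        ...     inputs="What is Deep Learning?",
        ...     parameters=TextGenerationInputGenerateParameters(
        ...         max_new_tokens=64, temperature=0.7, repetition_penalty=1.1
        ...     ),
        ... )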
""" return_full_text: Optional[bool] = None """Whether to prepend the prompt to the generated text""" seed: Optional[int] = None """Random sampling seed.""" stop: Optional[List[str]] = None """Stop generating tokens if a member of `stop` is generated.""" temperature: Optional[float] = None """The value used to module the logits distribution.""" top_k: Optional[int] = None """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" top_n_tokens: Optional[int] = None """The number of highest probability vocabulary tokens to keep for top-n-filtering.""" top_p: Optional[float] = None """Top-p value for nucleus sampling.""" truncate: Optional[int] = None """Truncate inputs tokens to the given size.""" typical_p: Optional[float] = None """Typical Decoding mass See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information. """ watermark: Optional[bool] = None """Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226). """ @dataclass_with_extra class TextGenerationInput(BaseInferenceType): """Text Generation Input. Auto-generated from TGI specs. For more details, check out https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. """ inputs: str parameters: Optional[TextGenerationInputGenerateParameters] = None stream: Optional[bool] = None TextGenerationOutputFinishReason = Literal["length", "eos_token", "stop_sequence"] @dataclass_with_extra class TextGenerationOutputPrefillToken(BaseInferenceType): id: int logprob: float text: str @dataclass_with_extra class TextGenerationOutputToken(BaseInferenceType): id: int logprob: float special: bool text: str @dataclass_with_extra class TextGenerationOutputBestOfSequence(BaseInferenceType): finish_reason: "TextGenerationOutputFinishReason" generated_text: str generated_tokens: int prefill: List[TextGenerationOutputPrefillToken] tokens: List[TextGenerationOutputToken] seed: Optional[int] = None top_tokens: Optional[List[List[TextGenerationOutputToken]]] = None @dataclass_with_extra class TextGenerationOutputDetails(BaseInferenceType): finish_reason: "TextGenerationOutputFinishReason" generated_tokens: int prefill: List[TextGenerationOutputPrefillToken] tokens: List[TextGenerationOutputToken] best_of_sequences: Optional[List[TextGenerationOutputBestOfSequence]] = None seed: Optional[int] = None top_tokens: Optional[List[List[TextGenerationOutputToken]]] = None @dataclass_with_extra class TextGenerationOutput(BaseInferenceType): """Text Generation Output. Auto-generated from TGI specs. For more details, check out https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. """ generated_text: str details: Optional[TextGenerationOutputDetails] = None @dataclass_with_extra class TextGenerationStreamOutputStreamDetails(BaseInferenceType): finish_reason: "TextGenerationOutputFinishReason" generated_tokens: int input_length: int seed: Optional[int] = None @dataclass_with_extra class TextGenerationStreamOutputToken(BaseInferenceType): id: int logprob: float special: bool text: str @dataclass_with_extra class TextGenerationStreamOutput(BaseInferenceType): """Text Generation Stream Output. Auto-generated from TGI specs. For more details, check out https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. 
""" index: int token: TextGenerationStreamOutputToken details: Optional[TextGenerationStreamOutputStreamDetails] = None generated_text: Optional[str] = None top_tokens: Optional[List[TextGenerationStreamOutputToken]] = None huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/text_to_audio.py000066400000000000000000000112051500667546600312330ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, Literal, Optional, Union from .base import BaseInferenceType, dataclass_with_extra TextToAudioEarlyStoppingEnum = Literal["never"] @dataclass_with_extra class TextToAudioGenerationParameters(BaseInferenceType): """Parametrization of the text generation process""" do_sample: Optional[bool] = None """Whether to use sampling instead of greedy decoding when generating new tokens.""" early_stopping: Optional[Union[bool, "TextToAudioEarlyStoppingEnum"]] = None """Controls the stopping condition for beam-based methods.""" epsilon_cutoff: Optional[float] = None """If set to float strictly between 0 and 1, only tokens with a conditional probability greater than epsilon_cutoff will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details. """ eta_cutoff: Optional[float] = None """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between 0 and 1, a token is only considered if it is greater than either eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details. """ max_length: Optional[int] = None """The maximum length (in tokens) of the generated text, including the input.""" max_new_tokens: Optional[int] = None """The maximum number of tokens to generate. Takes precedence over max_length.""" min_length: Optional[int] = None """The minimum length (in tokens) of the generated text, including the input.""" min_new_tokens: Optional[int] = None """The minimum number of tokens to generate. Takes precedence over min_length.""" num_beam_groups: Optional[int] = None """Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. """ num_beams: Optional[int] = None """Number of beams to use for beam search.""" penalty_alpha: Optional[float] = None """The value balances the model confidence and the degeneration penalty in contrastive search decoding. """ temperature: Optional[float] = None """The value used to modulate the next token probabilities.""" top_k: Optional[int] = None """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" top_p: Optional[float] = None """If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. 
""" typical_p: Optional[float] = None """Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to typical_p or higher are kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. """ use_cache: Optional[bool] = None """Whether the model should use the past last key/values attentions to speed up decoding""" @dataclass_with_extra class TextToAudioParameters(BaseInferenceType): """Additional inference parameters for Text To Audio""" generation_parameters: Optional[TextToAudioGenerationParameters] = None """Parametrization of the text generation process""" @dataclass_with_extra class TextToAudioInput(BaseInferenceType): """Inputs for Text To Audio inference""" inputs: str """The input text data""" parameters: Optional[TextToAudioParameters] = None """Additional inference parameters for Text To Audio""" @dataclass_with_extra class TextToAudioOutput(BaseInferenceType): """Outputs of inference for the Text To Audio task""" audio: Any """The generated audio waveform.""" sampling_rate: float """The sampling rate of the generated audio waveform.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/text_to_image.py000066400000000000000000000035571500667546600312270ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class TextToImageParameters(BaseInferenceType): """Additional inference parameters for Text To Image""" guidance_scale: Optional[float] = None """A higher guidance scale value encourages the model to generate images closely linked to the text prompt, but values too high may cause saturation and other artifacts. """ height: Optional[int] = None """The height in pixels of the output image""" negative_prompt: Optional[str] = None """One prompt to guide what NOT to include in image generation.""" num_inference_steps: Optional[int] = None """The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. """ scheduler: Optional[str] = None """Override the scheduler with a compatible one.""" seed: Optional[int] = None """Seed for the random number generator.""" width: Optional[int] = None """The width in pixels of the output image""" @dataclass_with_extra class TextToImageInput(BaseInferenceType): """Inputs for Text To Image inference""" inputs: str """The input text data (sometimes called "prompt")""" parameters: Optional[TextToImageParameters] = None """Additional inference parameters for Text To Image""" @dataclass_with_extra class TextToImageOutput(BaseInferenceType): """Outputs of inference for the Text To Image task""" image: Any """The generated image returned as raw bytes in the payload.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/text_to_speech.py000066400000000000000000000112301500667546600313770ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. 
# # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, Literal, Optional, Union from .base import BaseInferenceType, dataclass_with_extra TextToSpeechEarlyStoppingEnum = Literal["never"] @dataclass_with_extra class TextToSpeechGenerationParameters(BaseInferenceType): """Parametrization of the text generation process""" do_sample: Optional[bool] = None """Whether to use sampling instead of greedy decoding when generating new tokens.""" early_stopping: Optional[Union[bool, "TextToSpeechEarlyStoppingEnum"]] = None """Controls the stopping condition for beam-based methods.""" epsilon_cutoff: Optional[float] = None """If set to float strictly between 0 and 1, only tokens with a conditional probability greater than epsilon_cutoff will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details. """ eta_cutoff: Optional[float] = None """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between 0 and 1, a token is only considered if it is greater than either eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details. """ max_length: Optional[int] = None """The maximum length (in tokens) of the generated text, including the input.""" max_new_tokens: Optional[int] = None """The maximum number of tokens to generate. Takes precedence over max_length.""" min_length: Optional[int] = None """The minimum length (in tokens) of the generated text, including the input.""" min_new_tokens: Optional[int] = None """The minimum number of tokens to generate. Takes precedence over min_length.""" num_beam_groups: Optional[int] = None """Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. """ num_beams: Optional[int] = None """Number of beams to use for beam search.""" penalty_alpha: Optional[float] = None """The value balances the model confidence and the degeneration penalty in contrastive search decoding. """ temperature: Optional[float] = None """The value used to modulate the next token probabilities.""" top_k: Optional[int] = None """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" top_p: Optional[float] = None """If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. """ typical_p: Optional[float] = None """Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to typical_p or higher are kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. 
""" use_cache: Optional[bool] = None """Whether the model should use the past last key/values attentions to speed up decoding""" @dataclass_with_extra class TextToSpeechParameters(BaseInferenceType): """Additional inference parameters for Text To Speech""" generation_parameters: Optional[TextToSpeechGenerationParameters] = None """Parametrization of the text generation process""" @dataclass_with_extra class TextToSpeechInput(BaseInferenceType): """Inputs for Text To Speech inference""" inputs: str """The input text data""" parameters: Optional[TextToSpeechParameters] = None """Additional inference parameters for Text To Speech""" @dataclass_with_extra class TextToSpeechOutput(BaseInferenceType): """Outputs of inference for the Text To Speech task""" audio: Any """The generated audio""" sampling_rate: Optional[float] = None """The sampling rate of the generated audio waveform.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/text_to_video.py000066400000000000000000000033761500667546600312520ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, List, Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class TextToVideoParameters(BaseInferenceType): """Additional inference parameters for Text To Video""" guidance_scale: Optional[float] = None """A higher guidance scale value encourages the model to generate videos closely linked to the text prompt, but values too high may cause saturation and other artifacts. """ negative_prompt: Optional[List[str]] = None """One or several prompt to guide what NOT to include in video generation.""" num_frames: Optional[float] = None """The num_frames parameter determines how many video frames are generated.""" num_inference_steps: Optional[int] = None """The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. """ seed: Optional[int] = None """Seed for the random number generator.""" @dataclass_with_extra class TextToVideoInput(BaseInferenceType): """Inputs for Text To Video inference""" inputs: str """The input text data (sometimes called "prompt")""" parameters: Optional[TextToVideoParameters] = None """Additional inference parameters for Text To Video""" @dataclass_with_extra class TextToVideoOutput(BaseInferenceType): """Outputs of inference for the Text To Video task""" video: Any """The generated video returned as raw bytes in the payload.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/token_classification.py000066400000000000000000000035731500667546600325700ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
from typing import List, Literal, Optional from .base import BaseInferenceType, dataclass_with_extra TokenClassificationAggregationStrategy = Literal["none", "simple", "first", "average", "max"] @dataclass_with_extra class TokenClassificationParameters(BaseInferenceType): """Additional inference parameters for Token Classification""" aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None """The strategy used to fuse tokens based on model predictions""" ignore_labels: Optional[List[str]] = None """A list of labels to ignore""" stride: Optional[int] = None """The number of overlapping tokens between chunks when splitting the input text.""" @dataclass_with_extra class TokenClassificationInput(BaseInferenceType): """Inputs for Token Classification inference""" inputs: str """The input text data""" parameters: Optional[TokenClassificationParameters] = None """Additional inference parameters for Token Classification""" @dataclass_with_extra class TokenClassificationOutputElement(BaseInferenceType): """Outputs of inference for the Token Classification task""" end: int """The character position in the input where this group ends.""" score: float """The associated score / probability""" start: int """The character position in the input where this group begins.""" word: str """The corresponding text""" entity: Optional[str] = None """The predicted label for a single token""" entity_group: Optional[str] = None """The predicted label for a group of one or more tokens""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/translation.py000066400000000000000000000033431500667546600307260ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, Dict, Literal, Optional from .base import BaseInferenceType, dataclass_with_extra TranslationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"] @dataclass_with_extra class TranslationParameters(BaseInferenceType): """Additional inference parameters for Translation""" clean_up_tokenization_spaces: Optional[bool] = None """Whether to clean up the potential extra spaces in the text output.""" generate_parameters: Optional[Dict[str, Any]] = None """Additional parametrization of the text generation algorithm.""" src_lang: Optional[str] = None """The source language of the text. Required for models that can translate from multiple languages. """ tgt_lang: Optional[str] = None """Target language to translate to. Required for models that can translate to multiple languages. 
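    For illustration only, a multilingual request could look as follows; the exact
    language codes depend on the model (mBART-style codes are shown here):

        >>> from huggingface_hub.inference._generated.types.translation import TranslationInput, TranslationParameters
        >>> request = TranslationInput(
        ...     inputs="Hello, how are you?",
        ...     parameters=TranslationParameters(src_lang="en_XX", tgt_lang="fr_XX"),
        ... )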
""" truncation: Optional["TranslationTruncationStrategy"] = None """The truncation strategy to use.""" @dataclass_with_extra class TranslationInput(BaseInferenceType): """Inputs for Translation inference""" inputs: str """The text to translate.""" parameters: Optional[TranslationParameters] = None """Additional inference parameters for Translation""" @dataclass_with_extra class TranslationOutput(BaseInferenceType): """Outputs of inference for the Translation task""" translation_text: str """The translated text.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/video_classification.py000066400000000000000000000032201500667546600325430ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, Literal, Optional from .base import BaseInferenceType, dataclass_with_extra VideoClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] @dataclass_with_extra class VideoClassificationParameters(BaseInferenceType): """Additional inference parameters for Video Classification""" frame_sampling_rate: Optional[int] = None """The sampling rate used to select frames from the video.""" function_to_apply: Optional["VideoClassificationOutputTransform"] = None """The function to apply to the model outputs in order to retrieve the scores.""" num_frames: Optional[int] = None """The number of sampled frames to consider for classification.""" top_k: Optional[int] = None """When specified, limits the output to the top K most probable classes.""" @dataclass_with_extra class VideoClassificationInput(BaseInferenceType): """Inputs for Video Classification inference""" inputs: Any """The input video data""" parameters: Optional[VideoClassificationParameters] = None """Additional inference parameters for Video Classification""" @dataclass_with_extra class VideoClassificationOutputElement(BaseInferenceType): """Outputs of inference for the Video Classification task""" label: str """The predicted class label.""" score: float """The corresponding probability.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/visual_question_answering.py000066400000000000000000000032111500667546600336710ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import Any, Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class VisualQuestionAnsweringInputData(BaseInferenceType): """One (image, question) pair to answer""" image: Any """The image.""" question: str """The question to answer based on the image.""" @dataclass_with_extra class VisualQuestionAnsweringParameters(BaseInferenceType): """Additional inference parameters for Visual Question Answering""" top_k: Optional[int] = None """The number of answers to return (will be chosen by order of likelihood). Note that we return less than topk answers if there are not enough options available within the context. 
""" @dataclass_with_extra class VisualQuestionAnsweringInput(BaseInferenceType): """Inputs for Visual Question Answering inference""" inputs: VisualQuestionAnsweringInputData """One (image, question) pair to answer""" parameters: Optional[VisualQuestionAnsweringParameters] = None """Additional inference parameters for Visual Question Answering""" @dataclass_with_extra class VisualQuestionAnsweringOutputElement(BaseInferenceType): """Outputs of inference for the Visual Question Answering task""" score: float """The associated score / probability""" answer: Optional[str] = None """The answer to the question""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/zero_shot_classification.py000066400000000000000000000033121500667546600334530ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import List, Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class ZeroShotClassificationParameters(BaseInferenceType): """Additional inference parameters for Zero Shot Classification""" candidate_labels: List[str] """The set of possible class labels to classify the text into.""" hypothesis_template: Optional[str] = None """The sentence used in conjunction with `candidate_labels` to attempt the text classification by replacing the placeholder with the candidate labels. """ multi_label: Optional[bool] = None """Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of the label likelihoods for each sequence is 1. If true, the labels are considered independent and probabilities are normalized for each candidate. """ @dataclass_with_extra class ZeroShotClassificationInput(BaseInferenceType): """Inputs for Zero Shot Classification inference""" inputs: str """The text to classify""" parameters: ZeroShotClassificationParameters """Additional inference parameters for Zero Shot Classification""" @dataclass_with_extra class ZeroShotClassificationOutputElement(BaseInferenceType): """Outputs of inference for the Zero Shot Classification task""" label: str """The predicted class label.""" score: float """The corresponding probability.""" zero_shot_image_classification.py000066400000000000000000000027171500667546600345460ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import List, Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class ZeroShotImageClassificationParameters(BaseInferenceType): """Additional inference parameters for Zero Shot Image Classification""" candidate_labels: List[str] """The candidate labels for this image""" hypothesis_template: Optional[str] = None """The sentence used in conjunction with `candidate_labels` to attempt the image classification by replacing the placeholder with the candidate labels. 
""" @dataclass_with_extra class ZeroShotImageClassificationInput(BaseInferenceType): """Inputs for Zero Shot Image Classification inference""" inputs: str """The input image data to classify as a base64-encoded string.""" parameters: ZeroShotImageClassificationParameters """Additional inference parameters for Zero Shot Image Classification""" @dataclass_with_extra class ZeroShotImageClassificationOutputElement(BaseInferenceType): """Outputs of inference for the Zero Shot Image Classification task""" label: str """The predicted class label.""" score: float """The corresponding probability.""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_generated/types/zero_shot_object_detection.py000066400000000000000000000031361500667546600337700ustar00rootroot00000000000000# Inference code generated from the JSON schema spec in @huggingface/tasks. # # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. from typing import List from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class ZeroShotObjectDetectionParameters(BaseInferenceType): """Additional inference parameters for Zero Shot Object Detection""" candidate_labels: List[str] """The candidate labels for this image""" @dataclass_with_extra class ZeroShotObjectDetectionInput(BaseInferenceType): """Inputs for Zero Shot Object Detection inference""" inputs: str """The input image data as a base64-encoded string.""" parameters: ZeroShotObjectDetectionParameters """Additional inference parameters for Zero Shot Object Detection""" @dataclass_with_extra class ZeroShotObjectDetectionBoundingBox(BaseInferenceType): """The predicted bounding box. Coordinates are relative to the top left corner of the input image. """ xmax: int xmin: int ymax: int ymin: int @dataclass_with_extra class ZeroShotObjectDetectionOutputElement(BaseInferenceType): """Outputs of inference for the Zero Shot Object Detection task""" box: ZeroShotObjectDetectionBoundingBox """The predicted bounding box. Coordinates are relative to the top left corner of the input image. 
""" label: str """A candidate label""" score: float """The associated score / probability""" huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/000077500000000000000000000000001500667546600247265ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/__init__.py000066400000000000000000000162671500667546600270530ustar00rootroot00000000000000from typing import Dict, Literal, Optional, Union from huggingface_hub.utils import logging from ._common import TaskProviderHelper, _fetch_inference_provider_mapping from .black_forest_labs import BlackForestLabsTextToImageTask from .cerebras import CerebrasConversationalTask from .cohere import CohereConversationalTask from .fal_ai import ( FalAIAutomaticSpeechRecognitionTask, FalAITextToImageTask, FalAITextToSpeechTask, FalAITextToVideoTask, ) from .fireworks_ai import FireworksAIConversationalTask from .hf_inference import ( HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceFeatureExtractionTask, HFInferenceTask, ) from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask from .openai import OpenAIConversationalTask from .replicate import ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask logger = logging.get_logger(__name__) PROVIDER_T = Literal[ "black-forest-labs", "cerebras", "cohere", "fal-ai", "fireworks-ai", "hf-inference", "hyperbolic", "nebius", "novita", "openai", "replicate", "sambanova", "together", ] PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]] PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = { "black-forest-labs": { "text-to-image": BlackForestLabsTextToImageTask(), }, "cerebras": { "conversational": CerebrasConversationalTask(), }, "cohere": { "conversational": CohereConversationalTask(), }, "fal-ai": { "automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(), "text-to-image": FalAITextToImageTask(), "text-to-speech": FalAITextToSpeechTask(), "text-to-video": FalAITextToVideoTask(), }, "fireworks-ai": { "conversational": FireworksAIConversationalTask(), }, "hf-inference": { "text-to-image": HFInferenceTask("text-to-image"), "conversational": HFInferenceConversational(), "text-generation": HFInferenceTask("text-generation"), "text-classification": HFInferenceTask("text-classification"), "question-answering": HFInferenceTask("question-answering"), "audio-classification": HFInferenceBinaryInputTask("audio-classification"), "automatic-speech-recognition": HFInferenceBinaryInputTask("automatic-speech-recognition"), "fill-mask": HFInferenceTask("fill-mask"), "feature-extraction": HFInferenceFeatureExtractionTask(), "image-classification": HFInferenceBinaryInputTask("image-classification"), "image-segmentation": HFInferenceBinaryInputTask("image-segmentation"), "document-question-answering": HFInferenceTask("document-question-answering"), "image-to-text": HFInferenceBinaryInputTask("image-to-text"), "object-detection": HFInferenceBinaryInputTask("object-detection"), "audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"), "zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"), 
"zero-shot-classification": HFInferenceTask("zero-shot-classification"), "image-to-image": HFInferenceBinaryInputTask("image-to-image"), "sentence-similarity": HFInferenceTask("sentence-similarity"), "table-question-answering": HFInferenceTask("table-question-answering"), "tabular-classification": HFInferenceTask("tabular-classification"), "text-to-speech": HFInferenceTask("text-to-speech"), "token-classification": HFInferenceTask("token-classification"), "translation": HFInferenceTask("translation"), "summarization": HFInferenceTask("summarization"), "visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"), }, "hyperbolic": { "text-to-image": HyperbolicTextToImageTask(), "conversational": HyperbolicTextGenerationTask("conversational"), "text-generation": HyperbolicTextGenerationTask("text-generation"), }, "nebius": { "text-to-image": NebiusTextToImageTask(), "conversational": NebiusConversationalTask(), "text-generation": NebiusTextGenerationTask(), }, "novita": { "text-generation": NovitaTextGenerationTask(), "conversational": NovitaConversationalTask(), "text-to-video": NovitaTextToVideoTask(), }, "openai": { "conversational": OpenAIConversationalTask(), }, "replicate": { "text-to-image": ReplicateTextToImageTask(), "text-to-speech": ReplicateTextToSpeechTask(), "text-to-video": ReplicateTask("text-to-video"), }, "sambanova": { "conversational": SambanovaConversationalTask(), "feature-extraction": SambanovaFeatureExtractionTask(), }, "together": { "text-to-image": TogetherTextToImageTask(), "conversational": TogetherConversationalTask(), "text-generation": TogetherTextGenerationTask(), }, } def get_provider_helper( provider: Optional[PROVIDER_OR_POLICY_T], task: str, model: Optional[str] ) -> TaskProviderHelper: """Get provider helper instance by name and task. Args: provider (`str`, *optional*): name of the provider, or "auto" to automatically select the provider for the model. task (`str`): Name of the task model (`str`, *optional*): Name of the model Returns: TaskProviderHelper: Helper instance for the specified provider and task Raises: ValueError: If provider or task is not supported """ if (model is None and provider in (None, "auto")) or ( model is not None and model.startswith(("http://", "https://")) ): provider = "hf-inference" if provider is None: logger.info( "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers." ) provider = "auto" if provider == "auto": if model is None: raise ValueError("Specifying a model is required when provider is 'auto'") provider_mapping = _fetch_inference_provider_mapping(model) provider = next(iter(provider_mapping)) provider_tasks = PROVIDERS.get(provider) # type: ignore if provider_tasks is None: raise ValueError( f"Provider '{provider}' not supported. Available values: 'auto' or any provider from {list(PROVIDERS.keys())}." "Passing 'auto' (default value) will automatically select the first provider available for the model, sorted " "by the user's order in https://hf.co/settings/inference-providers." ) if task not in provider_tasks: raise ValueError( f"Task '{task}' not supported for provider '{provider}'. 
Available tasks: {list(provider_tasks.keys())}" ) return provider_tasks[task] huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/_common.py000066400000000000000000000236011500667546600267310ustar00rootroot00000000000000from functools import lru_cache from typing import Any, Dict, Optional, Union from huggingface_hub import constants from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters from huggingface_hub.utils import build_hf_headers, get_token, logging logger = logging.get_logger(__name__) # Dev purposes only. # If you want to try to run inference for a new model locally before it's registered on huggingface.co # for a given Inference Provider, you can add it to the following dictionary. HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, InferenceProviderMapping]] = { # "HF model ID" => InferenceProviderMapping object initialized with "Model ID on Inference Provider's side" # # Example: # "Qwen/Qwen2.5-Coder-32B-Instruct": InferenceProviderMapping(hf_model_id="Qwen/Qwen2.5-Coder-32B-Instruct", # provider_id="Qwen2.5-Coder-32B-Instruct", # task="conversational", # status="live") "cerebras": {}, "cohere": {}, "fal-ai": {}, "fireworks-ai": {}, "hf-inference": {}, "hyperbolic": {}, "nebius": {}, "replicate": {}, "sambanova": {}, "together": {}, } def filter_none(d: Dict[str, Any]) -> Dict[str, Any]: return {k: v for k, v in d.items() if v is not None} class TaskProviderHelper: """Base class for task-specific provider helpers.""" def __init__(self, provider: str, base_url: str, task: str) -> None: self.provider = provider self.task = task self.base_url = base_url def prepare_request( self, *, inputs: Any, parameters: Dict[str, Any], headers: Dict, model: Optional[str], api_key: Optional[str], extra_payload: Optional[Dict[str, Any]] = None, ) -> RequestParameters: """ Prepare the request to be sent to the provider. Each step (api_key, model, headers, url, payload) can be customized in subclasses. """ # api_key from user, or local token, or raise error api_key = self._prepare_api_key(api_key) # mapped model from HF model ID provider_mapping_info = self._prepare_mapping_info(model) # default HF headers + user headers (to customize in subclasses) headers = self._prepare_headers(headers, api_key) # routed URL if HF token, or direct URL (to customize in '_prepare_route' in subclasses) url = self._prepare_url(api_key, provider_mapping_info.provider_id) # prepare payload (to customize in subclasses) payload = self._prepare_payload_as_dict(inputs, parameters, provider_mapping_info=provider_mapping_info) if payload is not None: payload = recursive_merge(payload, extra_payload or {}) # body data (to customize in subclasses) data = self._prepare_payload_as_bytes(inputs, parameters, provider_mapping_info, extra_payload) # check if both payload and data are set and return if payload is not None and data is not None: raise ValueError("Both payload and data cannot be set in the same request.") if payload is None and data is None: raise ValueError("Either payload or data must be set in the request.") return RequestParameters( url=url, task=self.task, model=provider_mapping_info.provider_id, json=payload, data=data, headers=headers ) def get_response( self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None, ) -> Any: """ Return the response in the expected format. 
Override this method in subclasses for customized response handling.""" return response def _prepare_api_key(self, api_key: Optional[str]) -> str: """Return the API key to use for the request. Usually not overwritten in subclasses.""" if api_key is None: api_key = get_token() if api_key is None: raise ValueError( f"You must provide an api_key to work with {self.provider} API or log in with `huggingface-cli login`." ) return api_key def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping: """Return the mapped model ID to use for the request. Usually not overwritten in subclasses.""" if model is None: raise ValueError(f"Please provide an HF model ID supported by {self.provider}.") # hardcoded mapping for local testing if HARDCODED_MODEL_INFERENCE_MAPPING.get(self.provider, {}).get(model): return HARDCODED_MODEL_INFERENCE_MAPPING[self.provider][model] provider_mapping = _fetch_inference_provider_mapping(model).get(self.provider) if provider_mapping is None: raise ValueError(f"Model {model} is not supported by provider {self.provider}.") if provider_mapping.task != self.task: raise ValueError( f"Model {model} is not supported for task {self.task} and provider {self.provider}. " f"Supported task: {provider_mapping.task}." ) if provider_mapping.status == "staging": logger.warning( f"Model {model} is in staging mode for provider {self.provider}. Meant for test purposes only." ) return provider_mapping def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: """Return the headers to use for the request. Override this method in subclasses for customized headers. """ return {**build_hf_headers(token=api_key), **headers} def _prepare_url(self, api_key: str, mapped_model: str) -> str: """Return the URL to use for the request. Usually not overwritten in subclasses.""" base_url = self._prepare_base_url(api_key) route = self._prepare_route(mapped_model, api_key) return f"{base_url.rstrip('/')}/{route.lstrip('/')}" def _prepare_base_url(self, api_key: str) -> str: """Return the base URL to use for the request. Usually not overwritten in subclasses.""" # Route to the proxy if the api_key is a HF TOKEN if api_key.startswith("hf_"): logger.info(f"Calling '{self.provider}' provider through Hugging Face router.") return constants.INFERENCE_PROXY_TEMPLATE.format(provider=self.provider) else: logger.info(f"Calling '{self.provider}' provider directly.") return self.base_url def _prepare_route(self, mapped_model: str, api_key: str) -> str: """Return the route to use for the request. Override this method in subclasses for customized routes. """ return "" def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: """Return the payload to use for the request, as a dict. Override this method in subclasses for customized payloads. Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value. """ return None def _prepare_payload_as_bytes( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping, extra_payload: Optional[Dict], ) -> Optional[bytes]: """Return the body to use for the request, as bytes. Override this method in subclasses for customized body data. Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value. """ return None class BaseConversationalTask(TaskProviderHelper): """ Base class for conversational (chat completion) tasks. 
The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/chat """ def __init__(self, provider: str, base_url: str): super().__init__(provider=provider, base_url=base_url, task="conversational") def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/chat/completions" def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return {"messages": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id} class BaseTextGenerationTask(TaskProviderHelper): """ Base class for text-generation (completion) tasks. The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/completions """ def __init__(self, provider: str, base_url: str): super().__init__(provider=provider, base_url=base_url, task="text-generation") def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/completions" def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return {"prompt": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id} @lru_cache(maxsize=None) def _fetch_inference_provider_mapping(model: str) -> Dict: """ Fetch provider mappings for a model from the Hub. """ from huggingface_hub.hf_api import HfApi info = HfApi().model_info(model, expand=["inferenceProviderMapping"]) provider_mapping = info.inference_provider_mapping if provider_mapping is None: raise ValueError(f"No provider mapping found for model {model}") return provider_mapping def recursive_merge(dict1: Dict, dict2: Dict) -> Dict: return { **dict1, **{ key: recursive_merge(dict1[key], value) if (key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict)) else value for key, value in dict2.items() }, } huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/black_forest_labs.py000066400000000000000000000054321500667546600307430ustar00rootroot00000000000000import time from typing import Any, Dict, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none from huggingface_hub.utils import logging from huggingface_hub.utils._http import get_session logger = logging.get_logger(__name__) MAX_POLLING_ATTEMPTS = 6 POLLING_INTERVAL = 1.0 class BlackForestLabsTextToImageTask(TaskProviderHelper): def __init__(self): super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai", task="text-to-image") def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: headers = super()._prepare_headers(headers, api_key) if not api_key.startswith("hf_"): _ = headers.pop("authorization") headers["X-Key"] = api_key return headers def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/v1/{mapped_model}" def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: parameters = filter_none(parameters) if "num_inference_steps" in parameters: parameters["steps"] = parameters.pop("num_inference_steps") if "guidance_scale" in parameters: parameters["guidance"] = parameters.pop("guidance_scale") return {"prompt": inputs, **parameters} def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> 
Any: """ Polling mechanism for Black Forest Labs since the API is asynchronous. """ url = _as_dict(response).get("polling_url") session = get_session() for _ in range(MAX_POLLING_ATTEMPTS): time.sleep(POLLING_INTERVAL) response = session.get(url, headers={"Content-Type": "application/json"}) # type: ignore response.raise_for_status() # type: ignore response_json: Dict = response.json() # type: ignore status = response_json.get("status") logger.info( f"Polling generation result from {url}. Current status: {status}. " f"Will retry after {POLLING_INTERVAL} seconds if not ready." ) if ( status == "Ready" and isinstance(response_json.get("result"), dict) and (sample_url := response_json["result"].get("sample")) ): image_resp = session.get(sample_url) image_resp.raise_for_status() return image_resp.content raise TimeoutError(f"Failed to get the image URL after {MAX_POLLING_ATTEMPTS} attempts.") huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/cerebras.py000066400000000000000000000003661500667546600270730ustar00rootroot00000000000000from huggingface_hub.inference._providers._common import BaseConversationalTask class CerebrasConversationalTask(BaseConversationalTask): def __init__(self): super().__init__(provider="cerebras", base_url="https://api.cerebras.ai") huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/cohere.py000066400000000000000000000006431500667546600265500ustar00rootroot00000000000000from huggingface_hub.inference._providers._common import ( BaseConversationalTask, ) _PROVIDER = "cohere" _BASE_URL = "https://api.cohere.com" class CohereConversationalTask(BaseConversationalTask): def __init__(self): super().__init__(provider=_PROVIDER, base_url=_BASE_URL) def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/compatibility/v1/chat/completions" huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/fal_ai.py000066400000000000000000000157721500667546600265270ustar00rootroot00000000000000import base64 import time from abc import ABC from typing import Any, Dict, Optional, Union from urllib.parse import urlparse from huggingface_hub import constants from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none from huggingface_hub.utils import get_session, hf_raise_for_status from huggingface_hub.utils.logging import get_logger logger = get_logger(__name__) # Arbitrary polling interval _POLLING_INTERVAL = 0.5 class FalAITask(TaskProviderHelper, ABC): def __init__(self, task: str): super().__init__(provider="fal-ai", base_url="https://fal.run", task=task) def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: headers = super()._prepare_headers(headers, api_key) if not api_key.startswith("hf_"): headers["authorization"] = f"Key {api_key}" return headers def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/{mapped_model}" class FalAIAutomaticSpeechRecognitionTask(FalAITask): def __init__(self): super().__init__("automatic-speech-recognition") def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: if isinstance(inputs, str) and inputs.startswith(("http://", "https://")): # If input is a URL, pass it directly audio_url = inputs else: # If input is a file path, read it first if isinstance(inputs, str): with open(inputs, "rb") as f: inputs = f.read() audio_b64 = 
base64.b64encode(inputs).decode() content_type = "audio/mpeg" audio_url = f"data:{content_type};base64,{audio_b64}" return {"audio_url": audio_url, **filter_none(parameters)} def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: text = _as_dict(response)["text"] if not isinstance(text, str): raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.") return text class FalAITextToImageTask(FalAITask): def __init__(self): super().__init__("text-to-image") def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: payload: Dict[str, Any] = { "prompt": inputs, **filter_none(parameters), } if "width" in payload and "height" in payload: payload["image_size"] = { "width": payload.pop("width"), "height": payload.pop("height"), } if provider_mapping_info.adapter_weights_path is not None: lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format( repo_id=provider_mapping_info.hf_model_id, revision="main", filename=provider_mapping_info.adapter_weights_path, ) payload["loras"] = [{"path": lora_path, "scale": 1}] if provider_mapping_info.provider_id == "fal-ai/lora": # little hack: fal requires the base model for stable-diffusion-based loras but not for flux-based # See payloads in https://fal.ai/models/fal-ai/lora/api vs https://fal.ai/models/fal-ai/flux-lora/api payload["model_name"] = "stabilityai/stable-diffusion-xl-base-1.0" return payload def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: url = _as_dict(response)["images"][0]["url"] return get_session().get(url).content class FalAITextToSpeechTask(FalAITask): def __init__(self): super().__init__("text-to-speech") def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return {"text": inputs, **filter_none(parameters)} def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: url = _as_dict(response)["audio"]["url"] return get_session().get(url).content class FalAITextToVideoTask(FalAITask): def __init__(self): super().__init__("text-to-video") def _prepare_base_url(self, api_key: str) -> str: if api_key.startswith("hf_"): return super()._prepare_base_url(api_key) else: logger.info(f"Calling '{self.provider}' provider directly.") return "https://queue.fal.run" def _prepare_route(self, mapped_model: str, api_key: str) -> str: if api_key.startswith("hf_"): # Use the queue subdomain for HF routing return f"/{mapped_model}?_subdomain=queue" return f"/{mapped_model}" def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return {"prompt": inputs, **filter_none(parameters)} def get_response( self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None, ) -> Any: response_dict = _as_dict(response) request_id = response_dict.get("request_id") if not request_id: raise ValueError("No request ID found in the response") if request_params is None: raise ValueError( "A `RequestParameters` object should be provided to get text-to-video responses with Fal AI." 
) # extract the base url and query params parsed_url = urlparse(request_params.url) # a bit hacky way to concatenate the provider name without parsing `parsed_url.path` base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}" query_param = f"?{parsed_url.query}" if parsed_url.query else "" # extracting the provider model id for status and result urls # from the response as it might be different from the mapped model in `request_params.url` model_id = urlparse(response_dict.get("response_url")).path status_url = f"{base_url}{str(model_id)}/status{query_param}" result_url = f"{base_url}{str(model_id)}{query_param}" status = response_dict.get("status") logger.info("Generating the video.. this can take several minutes.") while status != "COMPLETED": time.sleep(_POLLING_INTERVAL) status_response = get_session().get(status_url, headers=request_params.headers) hf_raise_for_status(status_response) status = status_response.json().get("status") response = get_session().get(result_url, headers=request_params.headers).json() url = _as_dict(response)["video"]["url"] return get_session().get(url).content huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/fireworks_ai.py000066400000000000000000000005211500667546600277620ustar00rootroot00000000000000from ._common import BaseConversationalTask class FireworksAIConversationalTask(BaseConversationalTask): def __init__(self): super().__init__(provider="fireworks-ai", base_url="https://api.fireworks.ai") def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/inference/v1/chat/completions" huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/hf_inference.py000066400000000000000000000175761500667546600277330ustar00rootroot00000000000000import json from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional, Union from huggingface_hub import constants from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _b64_encode, _bytes_to_dict, _open_as_binary from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none from huggingface_hub.utils import build_hf_headers, get_session, get_token, hf_raise_for_status class HFInferenceTask(TaskProviderHelper): """Base class for HF Inference API tasks.""" def __init__(self, task: str): super().__init__( provider="hf-inference", base_url=constants.INFERENCE_PROXY_TEMPLATE.format(provider="hf-inference"), task=task, ) def _prepare_api_key(self, api_key: Optional[str]) -> str: # special case: for HF Inference we allow not providing an API key return api_key or get_token() # type: ignore[return-value] def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping: if model is not None and model.startswith(("http://", "https://")): return InferenceProviderMapping(providerId=model, hf_model_id=model, task=self.task, status="live") model_id = model if model is not None else _fetch_recommended_models().get(self.task) if model_id is None: raise ValueError( f"Task {self.task} has no recommended model for HF Inference. Please specify a model" " explicitly. Visit https://huggingface.co/tasks for more info." ) _check_supported_task(model_id, self.task) return InferenceProviderMapping(providerId=model_id, hf_model_id=model_id, task=self.task, status="live") def _prepare_url(self, api_key: str, mapped_model: str) -> str: # hf-inference provider can handle URLs (e.g. 
Inference Endpoints or TGI deployment) if mapped_model.startswith(("http://", "https://")): return mapped_model return ( # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks. f"{self.base_url}/models/{mapped_model}/pipeline/{self.task}" if self.task in ("feature-extraction", "sentence-similarity") # Otherwise, we use the default endpoint else f"{self.base_url}/models/{mapped_model}" ) def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: if isinstance(inputs, bytes): raise ValueError(f"Unexpected binary input for task {self.task}.") if isinstance(inputs, Path): raise ValueError(f"Unexpected path input for task {self.task} (got {inputs})") return {"inputs": inputs, "parameters": filter_none(parameters)} class HFInferenceBinaryInputTask(HFInferenceTask): def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return None def _prepare_payload_as_bytes( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping, extra_payload: Optional[Dict], ) -> Optional[bytes]: parameters = filter_none({k: v for k, v in parameters.items() if v is not None}) extra_payload = extra_payload or {} has_parameters = len(parameters) > 0 or len(extra_payload) > 0 # Raise if not a binary object or a local path or a URL. if not isinstance(inputs, (bytes, Path)) and not isinstance(inputs, str): raise ValueError(f"Expected binary inputs or a local path or a URL. Got {inputs}") # Send inputs as raw content when no parameters are provided if not has_parameters: with _open_as_binary(inputs) as data: data_as_bytes = data if isinstance(data, bytes) else data.read() return data_as_bytes # Otherwise encode as b64 return json.dumps({"inputs": _b64_encode(inputs), "parameters": parameters, **extra_payload}).encode("utf-8") class HFInferenceConversational(HFInferenceTask): def __init__(self): super().__init__("conversational") def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: mapped_model = provider_mapping_info.provider_id payload_model = parameters.get("model") or mapped_model if payload_model is None or payload_model.startswith(("http://", "https://")): payload_model = "dummy" return {**filter_none(parameters), "model": payload_model, "messages": inputs} def _prepare_url(self, api_key: str, mapped_model: str) -> str: base_url = ( mapped_model if mapped_model.startswith(("http://", "https://")) else f"{constants.INFERENCE_PROXY_TEMPLATE.format(provider='hf-inference')}/models/{mapped_model}" ) return _build_chat_completion_url(base_url) def _build_chat_completion_url(model_url: str) -> str: # Strip trailing / model_url = model_url.rstrip("/") # Append /chat/completions if not already present if model_url.endswith("/v1"): model_url += "/chat/completions" # Append /v1/chat/completions if not already present if not model_url.endswith("/chat/completions"): model_url += "/v1/chat/completions" return model_url @lru_cache(maxsize=1) def _fetch_recommended_models() -> Dict[str, Optional[str]]: response = get_session().get(f"{constants.ENDPOINT}/api/tasks", headers=build_hf_headers()) hf_raise_for_status(response) return {task: next(iter(details["widgetModels"]), None) for task, details in response.json().items()} @lru_cache(maxsize=None) def _check_supported_task(model: str, task: str) -> None: from 
huggingface_hub.hf_api import HfApi model_info = HfApi().model_info(model) pipeline_tag = model_info.pipeline_tag tags = model_info.tags or [] is_conversational = "conversational" in tags if task in ("text-generation", "conversational"): if pipeline_tag == "text-generation": # text-generation + conversational tag -> both tasks allowed if is_conversational: return # text-generation without conversational tag -> only text-generation allowed if task == "text-generation": return raise ValueError(f"Model '{model}' doesn't support task '{task}'.") if pipeline_tag == "text2text-generation": if task == "text-generation": return raise ValueError(f"Model '{model}' doesn't support task '{task}'.") if pipeline_tag == "image-text-to-text": if is_conversational and task == "conversational": return # Only conversational allowed if tagged as conversational raise ValueError("Non-conversational image-text-to-text task is not supported.") if ( task in ("feature-extraction", "sentence-similarity") and pipeline_tag in ("feature-extraction", "sentence-similarity") and task in tags ): # feature-extraction and sentence-similarity are interchangeable for HF Inference return # For all other tasks, just check pipeline tag if pipeline_tag != task: raise ValueError( f"Model '{model}' doesn't support task '{task}'. Supported tasks: '{pipeline_tag}', got: '{task}'" ) return class HFInferenceFeatureExtractionTask(HFInferenceTask): def __init__(self): super().__init__("feature-extraction") def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: if isinstance(response, bytes): return _bytes_to_dict(response) return response huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/hyperbolic.py000066400000000000000000000037011500667546600274410ustar00rootroot00000000000000import base64 from typing import Any, Dict, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none class HyperbolicTextToImageTask(TaskProviderHelper): def __init__(self): super().__init__(provider="hyperbolic", base_url="https://api.hyperbolic.xyz", task="text-to-image") def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/images/generations" def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: mapped_model = provider_mapping_info.provider_id parameters = filter_none(parameters) if "num_inference_steps" in parameters: parameters["steps"] = parameters.pop("num_inference_steps") if "guidance_scale" in parameters: parameters["cfg_scale"] = parameters.pop("guidance_scale") # For Hyperbolic, the width and height are required parameters if "width" not in parameters: parameters["width"] = 512 if "height" not in parameters: parameters["height"] = 512 return {"prompt": inputs, "model_name": mapped_model, **parameters} def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: response_dict = _as_dict(response) return base64.b64decode(response_dict["images"][0]["image"]) class HyperbolicTextGenerationTask(BaseConversationalTask): """ Special case for Hyperbolic, where text-generation task is handled as a conversational task. 
""" def __init__(self, task: str): super().__init__( provider="hyperbolic", base_url="https://api.hyperbolic.xyz", ) self.task = task huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/nebius.py000066400000000000000000000041351500667546600265700ustar00rootroot00000000000000import base64 from typing import Any, Dict, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict from huggingface_hub.inference._providers._common import ( BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper, filter_none, ) class NebiusTextGenerationTask(BaseTextGenerationTask): def __init__(self): super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai") def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: output = _as_dict(response)["choices"][0] return { "generated_text": output["text"], "details": { "finish_reason": output.get("finish_reason"), "seed": output.get("seed"), }, } class NebiusConversationalTask(BaseConversationalTask): def __init__(self): super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai") class NebiusTextToImageTask(TaskProviderHelper): def __init__(self): super().__init__(task="text-to-image", provider="nebius", base_url="https://api.studio.nebius.ai") def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/images/generations" def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: mapped_model = provider_mapping_info.provider_id parameters = filter_none(parameters) if "guidance_scale" in parameters: parameters.pop("guidance_scale") if parameters.get("response_format") not in ("b64_json", "url"): parameters["response_format"] = "b64_json" return {"prompt": inputs, **parameters, "model": mapped_model} def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: response_dict = _as_dict(response) return base64.b64decode(response_dict["data"][0]["b64_json"]) huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/new_provider.md000066400000000000000000000113241500667546600277540ustar00rootroot00000000000000## How to add a new provider? Before adding a new provider to the `huggingface_hub` library, make sure it has already been added to `huggingface.js` and is working on the Hub. Support in the Python library comes as a second step. In this guide, we are considering that the first part is complete. ### 1. Implement the provider helper Create a new file under `src/huggingface_hub/inference/_providers/{provider_name}.py` and copy-paste the following snippet. Implement the methods that require custom handling. Check out the base implementation to check default behavior. If you don't need to override a method, just remove it. At least one of `_prepare_payload_as_dict` or `_prepare_payload_as_bytes` must be overwritten. If the provider supports multiple tasks that require different implementations, create dedicated subclasses for each task, following the pattern shown in `fal_ai.py`. For `text-generation` and `conversational` tasks, one can just inherit from `BaseTextGenerationTask` and `BaseConversationalTask` respectively (defined in `_common.py`) and override the methods if needed. Examples can be found in `fireworks_ai.py` and `together.py`. 
```py from typing import Any, Dict, Optional, Union from ._common import TaskProviderHelper class MyNewProviderTaskProviderHelper(TaskProviderHelper): def __init__(self): """Define high-level parameters.""" super().__init__(provider=..., base_url=..., task=...) def get_response( self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None, ) -> Any: """ Return the response in the expected format. Override this method in subclasses for customized response handling.""" return super().get_response(response) def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: """Return the headers to use for the request. Override this method in subclasses for customized headers. """ return super()._prepare_headers(headers, api_key) def _prepare_route(self, mapped_model: str, api_key: str) -> str: """Return the route to use for the request. Override this method in subclasses for customized routes. """ return super()._prepare_route(mapped_model) def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: """Return the payload to use for the request, as a dict. Override this method in subclasses for customized payloads. Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value. """ return super()._prepare_payload_as_dict(inputs, parameters, mapped_model) def _prepare_payload_as_bytes( self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict] ) -> Optional[bytes]: """Return the body to use for the request, as bytes. Override this method in subclasses for customized body data. Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value. """ return super()._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload) ``` ### 2. Register the provider helper in `__init__.py` Go to `src/huggingface_hub/inference/_providers/__init__.py` and add your provider to `PROVIDER_T` and `PROVIDERS`. Please try to respect alphabetical order. ### 3. Update docstring in `InferenceClient.__init__` to document your provider ### 4. Add static tests in `tests/test_inference_providers.py` You only have to add a test for overwritten methods. ### 5. Add VCR tests in `tests/test_inference_client.py` #### a. Add test model mapping Add an entry to `_RECOMMENDED_MODELS_FOR_VCR` at the top of the test module, It contains a mapping task <> test model. `model-id` must be the HF model id. ```python _RECOMMENDED_MODELS_FOR_VCR = { "your-provider": { "task": "model-id", ... }, ... } ``` #### b. Set up authentication To record VCR cassettes, you'll need authentication: - If you are a member of the provider organization (e.g., Replicate organization: https://huggingface.co/replicate), you can set the `HF_INFERENCE_TEST_TOKEN` environment variable with your HF token: ```bash export HF_INFERENCE_TEST_TOKEN="your-hf-token" ``` - If you're not a member but the provider is officially released on the Hub, you can set the `HF_INFERENCE_TEST_TOKEN` environment variable as above. If you don't have enough inference credits, we can help you record the VCR cassettes. #### c. Record and commit tests 1. Run the tests for your provider: ```bash pytest tests/test_inference_client.py -k ``` 2. 
Commit the generated VCR cassettes with your PR huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/novita.py000066400000000000000000000047221500667546600266050ustar00rootroot00000000000000from typing import Any, Dict, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict from huggingface_hub.inference._providers._common import ( BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper, filter_none, ) from huggingface_hub.utils import get_session _PROVIDER = "novita" _BASE_URL = "https://api.novita.ai" class NovitaTextGenerationTask(BaseTextGenerationTask): def __init__(self): super().__init__(provider=_PROVIDER, base_url=_BASE_URL) def _prepare_route(self, mapped_model: str, api_key: str) -> str: # there is no v1/ route for novita return "/v3/openai/completions" def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: output = _as_dict(response)["choices"][0] return { "generated_text": output["text"], "details": { "finish_reason": output.get("finish_reason"), "seed": output.get("seed"), }, } class NovitaConversationalTask(BaseConversationalTask): def __init__(self): super().__init__(provider=_PROVIDER, base_url=_BASE_URL) def _prepare_route(self, mapped_model: str, api_key: str) -> str: # there is no v1/ route for novita return "/v3/openai/chat/completions" class NovitaTextToVideoTask(TaskProviderHelper): def __init__(self): super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task="text-to-video") def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/v3/hf/{mapped_model}" def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return {"prompt": inputs, **filter_none(parameters)} def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: response_dict = _as_dict(response) if not ( isinstance(response_dict, dict) and "video" in response_dict and isinstance(response_dict["video"], dict) and "video_url" in response_dict["video"] ): raise ValueError("Expected response format: { 'video': { 'video_url': string } }") video_url = response_dict["video"]["video_url"] return get_session().get(video_url).content huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/openai.py000066400000000000000000000020301500667546600265460ustar00rootroot00000000000000from typing import Optional from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._providers._common import BaseConversationalTask class OpenAIConversationalTask(BaseConversationalTask): def __init__(self): super().__init__(provider="openai", base_url="https://api.openai.com") def _prepare_api_key(self, api_key: Optional[str]) -> str: if api_key is None: raise ValueError("You must provide an api_key to work with OpenAI API.") if api_key.startswith("hf_"): raise ValueError( "OpenAI provider is not available through Hugging Face routing, please use your own OpenAI API key." ) return api_key def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping: if model is None: raise ValueError("Please provide an OpenAI model ID, e.g. 
`gpt-4o` or `o1`.") return InferenceProviderMapping(providerId=model, task="conversational", status="live", hf_model_id=model) huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/replicate.py000066400000000000000000000061251500667546600272540ustar00rootroot00000000000000from typing import Any, Dict, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none from huggingface_hub.utils import get_session _PROVIDER = "replicate" _BASE_URL = "https://api.replicate.com" class ReplicateTask(TaskProviderHelper): def __init__(self, task: str): super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task) def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: headers = super()._prepare_headers(headers, api_key) headers["Prefer"] = "wait" return headers def _prepare_route(self, mapped_model: str, api_key: str) -> str: if ":" in mapped_model: return "/v1/predictions" return f"/v1/models/{mapped_model}/predictions" def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: mapped_model = provider_mapping_info.provider_id payload: Dict[str, Any] = {"input": {"prompt": inputs, **filter_none(parameters)}} if ":" in mapped_model: version = mapped_model.split(":", 1)[1] payload["version"] = version return payload def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: response_dict = _as_dict(response) if response_dict.get("output") is None: raise TimeoutError( f"Inference request timed out after 60 seconds. No output generated for model {response_dict.get('model')}" "The model might be in cold state or starting up. Please try again later." 
) output_url = ( response_dict["output"] if isinstance(response_dict["output"], str) else response_dict["output"][0] ) return get_session().get(output_url).content class ReplicateTextToImageTask(ReplicateTask): def __init__(self): super().__init__("text-to-image") def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment] if provider_mapping_info.adapter_weights_path is not None: payload["input"]["lora_weights"] = f"https://huggingface.co/{provider_mapping_info.hf_model_id}" return payload class ReplicateTextToSpeechTask(ReplicateTask): def __init__(self): super().__init__("text-to-speech") def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment] payload["input"]["text"] = payload["input"].pop("prompt") # rename "prompt" to "text" for TTS return payload huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/sambanova.py000066400000000000000000000024041500667546600272470ustar00rootroot00000000000000from typing import Any, Dict, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none class SambanovaConversationalTask(BaseConversationalTask): def __init__(self): super().__init__(provider="sambanova", base_url="https://api.sambanova.ai") class SambanovaFeatureExtractionTask(TaskProviderHelper): def __init__(self): super().__init__(provider="sambanova", base_url="https://api.sambanova.ai", task="feature-extraction") def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/embeddings" def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: parameters = filter_none(parameters) return {"input": inputs, "model": provider_mapping_info.provider_id, **parameters} def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: embeddings = _as_dict(response)["data"] return [embedding["embedding"] for embedding in embeddings] huggingface_hub-0.31.1/src/huggingface_hub/inference/_providers/together.py000066400000000000000000000051451500667546600271260ustar00rootroot00000000000000import base64 from abc import ABC from typing import Any, Dict, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict from huggingface_hub.inference._providers._common import ( BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper, filter_none, ) _PROVIDER = "together" _BASE_URL = "https://api.together.xyz" class TogetherTask(TaskProviderHelper, ABC): """Base class for Together API tasks.""" def __init__(self, task: str): super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task) def _prepare_route(self, mapped_model: str, api_key: str) -> str: if self.task == "text-to-image": return "/v1/images/generations" elif self.task == "conversational": return "/v1/chat/completions" elif self.task == "text-generation": return "/v1/completions" raise ValueError(f"Unsupported task 
'{self.task}' for Together API.") class TogetherTextGenerationTask(BaseTextGenerationTask): def __init__(self): super().__init__(provider=_PROVIDER, base_url=_BASE_URL) def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: output = _as_dict(response)["choices"][0] return { "generated_text": output["text"], "details": { "finish_reason": output.get("finish_reason"), "seed": output.get("seed"), }, } class TogetherConversationalTask(BaseConversationalTask): def __init__(self): super().__init__(provider=_PROVIDER, base_url=_BASE_URL) class TogetherTextToImageTask(TogetherTask): def __init__(self): super().__init__("text-to-image") def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: mapped_model = provider_mapping_info.provider_id parameters = filter_none(parameters) if "num_inference_steps" in parameters: parameters["steps"] = parameters.pop("num_inference_steps") if "guidance_scale" in parameters: parameters["guidance"] = parameters.pop("guidance_scale") return {"prompt": inputs, "response_format": "base64", **parameters, "model": mapped_model} def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: response_dict = _as_dict(response) return base64.b64decode(response_dict["data"][0]["b64_json"]) huggingface_hub-0.31.1/src/huggingface_hub/inference_api.py000066400000000000000000000202031500667546600237520ustar00rootroot00000000000000import io from typing import Any, Dict, List, Optional, Union from . import constants from .hf_api import HfApi from .utils import build_hf_headers, get_session, is_pillow_available, logging, validate_hf_hub_args from .utils._deprecation import _deprecate_method logger = logging.get_logger(__name__) ALL_TASKS = [ # NLP "text-classification", "token-classification", "table-question-answering", "question-answering", "zero-shot-classification", "translation", "summarization", "conversational", "feature-extraction", "text-generation", "text2text-generation", "fill-mask", "sentence-similarity", # Audio "text-to-speech", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "voice-activity-detection", # Computer vision "image-classification", "object-detection", "image-segmentation", "text-to-image", "image-to-image", # Others "tabular-classification", "tabular-regression", ] class InferenceApi: """Client to configure requests and make calls to the HuggingFace Inference API. Example: ```python >>> from huggingface_hub.inference_api import InferenceApi >>> # Mask-fill example >>> inference = InferenceApi("bert-base-uncased") >>> inference(inputs="The goal of life is [MASK].") [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}] >>> # Question Answering example >>> inference = InferenceApi("deepset/roberta-base-squad2") >>> inputs = { ... "question": "What's my name?", ... "context": "My name is Clara and I live in Berkeley.", ... } >>> inference(inputs) {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'} >>> # Zero-shot example >>> inference = InferenceApi("typeform/distilbert-base-uncased-mnli") >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!" 
>>> params = {"candidate_labels": ["refund", "legal", "faq"]} >>> inference(inputs, params) {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]} >>> # Overriding configured task >>> inference = InferenceApi("bert-base-uncased", task="feature-extraction") >>> # Text-to-image >>> inference = InferenceApi("stabilityai/stable-diffusion-2-1") >>> inference("cat") >>> # Return as raw response to parse the output yourself >>> inference = InferenceApi("mio/amadeus") >>> response = inference("hello world", raw_response=True) >>> response.headers {"Content-Type": "audio/flac", ...} >>> response.content # raw bytes from server b'(...)' ``` """ @validate_hf_hub_args @_deprecate_method( version="1.0", message=( "`InferenceApi` client is deprecated in favor of the more feature-complete `InferenceClient`. Check out" " this guide to learn how to convert your script to use it:" " https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client." ), ) def __init__( self, repo_id: str, task: Optional[str] = None, token: Optional[str] = None, gpu: bool = False, ): """Inits headers and API call information. Args: repo_id (``str``): Id of repository (e.g. `user/bert-base-uncased`). task (``str``, `optional`, defaults ``None``): Whether to force a task instead of using task specified in the repository. token (`str`, `optional`): The API token to use as HTTP bearer authorization. This is not the authentication token. You can find the token in https://huggingface.co/settings/token. Alternatively, you can find both your organizations and personal API tokens using `HfApi().whoami(token)`. gpu (`bool`, `optional`, defaults `False`): Whether to use GPU instead of CPU for inference(requires Startup plan at least). """ self.options = {"wait_for_model": True, "use_gpu": gpu} self.headers = build_hf_headers(token=token) # Configure task model_info = HfApi(token=token).model_info(repo_id=repo_id) if not model_info.pipeline_tag and not task: raise ValueError( "Task not specified in the repository. Please add it to the model card" " using pipeline_tag" " (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)" ) if task and task != model_info.pipeline_tag: if task not in ALL_TASKS: raise ValueError(f"Invalid task {task}. Make sure it's valid.") logger.warning( "You're using a different task than the one specified in the" " repository. Be sure to know what you're doing :)" ) self.task = task else: assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None" self.task = model_info.pipeline_tag self.api_url = f"{constants.INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}" def __repr__(self): # Do not add headers to repr to avoid leaking token. return f"InferenceAPI(api_url='{self.api_url}', task='{self.task}', options={self.options})" def __call__( self, inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None, params: Optional[Dict] = None, data: Optional[bytes] = None, raw_response: bool = False, ) -> Any: """Make a call to the Inference API. Args: inputs (`str` or `Dict` or `List[str]` or `List[List[str]]`, *optional*): Inputs for the prediction. params (`Dict`, *optional*): Additional parameters for the models. Will be sent as `parameters` in the payload. data (`bytes`, *optional*): Bytes content of the request. 
In this case, leave `inputs` and `params` empty. raw_response (`bool`, defaults to `False`): If `True`, the raw `Response` object is returned. You can parse its content as preferred. By default, the content is parsed into a more practical format (json dictionary or PIL Image for example). """ # Build payload payload: Dict[str, Any] = { "options": self.options, } if inputs: payload["inputs"] = inputs if params: payload["parameters"] = params # Make API call response = get_session().post(self.api_url, headers=self.headers, json=payload, data=data) # Let the user handle the response if raw_response: return response # By default, parse the response for the user. content_type = response.headers.get("Content-Type") or "" if content_type.startswith("image"): if not is_pillow_available(): raise ImportError( f"Task '{self.task}' returned as image but Pillow is not installed." " Please install it (`pip install Pillow`) or pass" " `raw_response=True` to get the raw `Response` object and parse" " the image by yourself." ) from PIL import Image return Image.open(io.BytesIO(response.content)) elif content_type == "application/json": return response.json() else: raise NotImplementedError( f"{content_type} output type is not implemented yet. You can pass" " `raw_response=True` to get the raw `Response` object and parse the" " output by yourself." ) huggingface_hub-0.31.1/src/huggingface_hub/keras_mixin.py000066400000000000000000000461661500667546600235140ustar00rootroot00000000000000import collections.abc as collections import json import os import warnings from functools import wraps from pathlib import Path from shutil import copytree from typing import Any, Dict, List, Optional, Union from huggingface_hub import ModelHubMixin, snapshot_download from huggingface_hub.utils import ( get_tf_version, is_graphviz_available, is_pydot_available, is_tf_available, yaml_dump, ) from . import constants from .hf_api import HfApi from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args from .utils._typing import CallableT logger = logging.get_logger(__name__) keras = None if is_tf_available(): # Depending on which version of TensorFlow is installed, we need to import # keras from the correct location. # See https://github.com/tensorflow/tensorflow/releases/tag/v2.16.1. # Note: saving a keras model only works with Keras<3.0. try: import tf_keras as keras # type: ignore except ImportError: import tensorflow as tf # type: ignore keras = tf.keras def _requires_keras_2_model(fn: CallableT) -> CallableT: # Wrapper to raise if user tries to save a Keras 3.x model @wraps(fn) def _inner(model, *args, **kwargs): if not hasattr(model, "history"): # hacky way to check if model is Keras 2.x raise NotImplementedError( f"Cannot use '{fn.__name__}': Keras 3.x is not supported." " Please save models manually and upload them using `upload_folder` or `huggingface-cli upload`." ) return fn(model, *args, **kwargs) return _inner # type: ignore [return-value] def _flatten_dict(dictionary, parent_key=""): """Flatten a nested dictionary. Reference: https://stackoverflow.com/a/6027615/10319735 Args: dictionary (`dict`): The nested dictionary to be flattened. parent_key (`str`): The parent key to be prefixed to the children keys. Necessary for recursing over the nested dictionary. Returns: The flattened dictionary. 
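Example (illustrative):
        >>> _flatten_dict({"a": {"b": 1}, "c": 2})
        {'a.b': 1, 'c': 2}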
""" items = [] for key, value in dictionary.items(): new_key = f"{parent_key}.{key}" if parent_key else key if isinstance(value, collections.MutableMapping): items.extend( _flatten_dict( value, new_key, ).items() ) else: items.append((new_key, value)) return dict(items) def _create_hyperparameter_table(model): """Parse hyperparameter dictionary into a markdown table.""" table = None if model.optimizer is not None: optimizer_params = model.optimizer.get_config() # flatten the configuration optimizer_params = _flatten_dict(optimizer_params) optimizer_params["training_precision"] = keras.mixed_precision.global_policy().name table = "| Hyperparameters | Value |\n| :-- | :-- |\n" for key, value in optimizer_params.items(): table += f"| {key} | {value} |\n" return table def _plot_network(model, save_directory): keras.utils.plot_model( model, to_file=f"{save_directory}/model.png", show_shapes=False, show_dtype=False, show_layer_names=True, rankdir="TB", expand_nested=False, dpi=96, layer_range=None, ) def _create_model_card( model, repo_dir: Path, plot_model: bool = True, metadata: Optional[dict] = None, ): """ Creates a model card for the repository. Do not overwrite an existing README.md file. """ readme_path = repo_dir / "README.md" if readme_path.exists(): return hyperparameters = _create_hyperparameter_table(model) if plot_model and is_graphviz_available() and is_pydot_available(): _plot_network(model, repo_dir) if metadata is None: metadata = {} metadata["library_name"] = "keras" model_card: str = "---\n" model_card += yaml_dump(metadata, default_flow_style=False) model_card += "---\n" model_card += "\n## Model description\n\nMore information needed\n" model_card += "\n## Intended uses & limitations\n\nMore information needed\n" model_card += "\n## Training and evaluation data\n\nMore information needed\n" if hyperparameters is not None: model_card += "\n## Training procedure\n" model_card += "\n### Training hyperparameters\n" model_card += "\nThe following hyperparameters were used during training:\n\n" model_card += hyperparameters model_card += "\n" if plot_model and os.path.exists(f"{repo_dir}/model.png"): model_card += "\n ## Model Plot\n" model_card += "\n
" model_card += "\nView Model Plot\n" path_to_plot = "./model.png" model_card += f"\n![Model Image]({path_to_plot})\n" model_card += "\n
" readme_path.write_text(model_card) @_requires_keras_2_model def save_pretrained_keras( model, save_directory: Union[str, Path], config: Optional[Dict[str, Any]] = None, include_optimizer: bool = False, plot_model: bool = True, tags: Optional[Union[list, str]] = None, **model_save_kwargs, ): """ Saves a Keras model to save_directory in SavedModel format. Use this if you're using the Functional or Sequential APIs. Args: model (`Keras.Model`): The [Keras model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) you'd like to save. The model must be compiled and built. save_directory (`str` or `Path`): Specify directory in which you want to save the Keras model. config (`dict`, *optional*): Configuration object to be saved alongside the model weights. include_optimizer(`bool`, *optional*, defaults to `False`): Whether or not to include optimizer in serialization. plot_model (`bool`, *optional*, defaults to `True`): Setting this to `True` will plot the model and put it in the model card. Requires graphviz and pydot to be installed. tags (Union[`str`,`list`], *optional*): List of tags that are related to model or string of a single tag. See example tags [here](https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1). model_save_kwargs(`dict`, *optional*): model_save_kwargs will be passed to [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model). """ if keras is None: raise ImportError("Called a Tensorflow-specific function but could not import it.") if not model.built: raise ValueError("Model should be built before trying to save") save_directory = Path(save_directory) save_directory.mkdir(parents=True, exist_ok=True) # saving config if config: if not isinstance(config, dict): raise RuntimeError(f"Provided config to save_pretrained_keras should be a dict. Got: '{type(config)}'") with (save_directory / constants.CONFIG_NAME).open("w") as f: json.dump(config, f) metadata = {} if isinstance(tags, list): metadata["tags"] = tags elif isinstance(tags, str): metadata["tags"] = [tags] task_name = model_save_kwargs.pop("task_name", None) if task_name is not None: warnings.warn( "`task_name` input argument is deprecated. Pass `tags` instead.", FutureWarning, ) if "tags" in metadata: metadata["tags"].append(task_name) else: metadata["tags"] = [task_name] if model.history is not None: if model.history.history != {}: path = save_directory / "history.json" if path.exists(): warnings.warn( "`history.json` file already exists, it will be overwritten by the history of this version.", UserWarning, ) with path.open("w", encoding="utf-8") as f: json.dump(model.history.history, f, indent=2, sort_keys=True) _create_model_card(model, save_directory, plot_model, metadata) keras.models.save_model(model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs) def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin": r""" Instantiate a pretrained Keras model from a pre-trained model from the Hub. The model is expected to be in `SavedModel` format. Args: pretrained_model_name_or_path (`str` or `os.PathLike`): Can be either: - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. 
- You can add `revision` by appending `@` at the end of model_id simply like this: `dbmdz/bert-base-german-cased@main` Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. - A path to a `directory` containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - `None` if you are both providing the configuration and state dictionary (resp. with keyword arguments `config` and `state_dict`). force_download (`bool`, *optional*, defaults to `False`): Whether to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `transformers-cli login` (stored in `~/.huggingface`). cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. local_files_only(`bool`, *optional*, defaults to `False`): Whether to only look at local files (i.e., do not try to download the model). model_kwargs (`Dict`, *optional*): model_kwargs will be passed to the model during initialization Passing `token=True` is required when you want to use a private model. """ return KerasModelHubMixin.from_pretrained(*args, **kwargs) @validate_hf_hub_args @_requires_keras_2_model def push_to_hub_keras( model, repo_id: str, *, config: Optional[dict] = None, commit_message: str = "Push Keras model using huggingface_hub.", private: Optional[bool] = None, api_endpoint: Optional[str] = None, token: Optional[str] = None, branch: Optional[str] = None, create_pr: Optional[bool] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, delete_patterns: Optional[Union[List[str], str]] = None, log_dir: Optional[str] = None, include_optimizer: bool = False, tags: Optional[Union[list, str]] = None, plot_model: bool = True, **model_save_kwargs, ): """ Upload model checkpoint to the Hub. Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more details. Args: model (`Keras.Model`): The [Keras model](`https://www.tensorflow.org/api_docs/python/tf/keras/Model`) you'd like to push to the Hub. The model must be compiled and built. repo_id (`str`): ID of the repository to push to (example: `"username/my-model"`). commit_message (`str`, *optional*, defaults to "Add Keras model"): Message to commit while pushing. private (`bool`, *optional*): Whether the repository created should be private. If `None` (default), the repo will be public unless the organization's default is private. api_endpoint (`str`, *optional*): The API endpoint to use when pushing the model to the hub. token (`str`, *optional*): The token to use as HTTP bearer authorization for remote files. 
If not set, will use the token set when logging in with `huggingface-cli login` (stored in `~/.huggingface`). branch (`str`, *optional*): The git branch on which to push the model. This defaults to the default branch as specified in your repository, which defaults to `"main"`. create_pr (`boolean`, *optional*): Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`. config (`dict`, *optional*): Configuration object to be saved alongside the model weights. allow_patterns (`List[str]` or `str`, *optional*): If provided, only files matching at least one pattern are pushed. ignore_patterns (`List[str]` or `str`, *optional*): If provided, files matching any of the patterns are not pushed. delete_patterns (`List[str]` or `str`, *optional*): If provided, remote files matching any of the patterns will be deleted from the repo. log_dir (`str`, *optional*): TensorBoard logging directory to be pushed. The Hub automatically hosts and displays a TensorBoard instance if log files are included in the repository. include_optimizer (`bool`, *optional*, defaults to `False`): Whether or not to include optimizer during serialization. tags (Union[`list`, `str`], *optional*): List of tags that are related to model or string of a single tag. See example tags [here](https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1). plot_model (`bool`, *optional*, defaults to `True`): Setting this to `True` will plot the model and put it in the model card. Requires graphviz and pydot to be installed. model_save_kwargs(`dict`, *optional*): model_save_kwargs will be passed to [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model). Returns: The url of the commit of your model in the given repository. """ api = HfApi(endpoint=api_endpoint) repo_id = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True).repo_id # Push the files to the repo in a single commit with SoftTemporaryDirectory() as tmp: saved_path = Path(tmp) / repo_id save_pretrained_keras( model, saved_path, config=config, include_optimizer=include_optimizer, tags=tags, plot_model=plot_model, **model_save_kwargs, ) # If `log_dir` provided, delete remote logs and upload new ones if log_dir is not None: delete_patterns = ( [] if delete_patterns is None else ( [delete_patterns] # convert `delete_patterns` to a list if isinstance(delete_patterns, str) else delete_patterns ) ) delete_patterns.append("logs/*") copytree(log_dir, saved_path / "logs") return api.upload_folder( repo_type="model", repo_id=repo_id, folder_path=saved_path, commit_message=commit_message, token=token, revision=branch, create_pr=create_pr, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, delete_patterns=delete_patterns, ) class KerasModelHubMixin(ModelHubMixin): """ Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to Keras models. ```python >>> import tensorflow as tf >>> from huggingface_hub import KerasModelHubMixin >>> class MyModel(tf.keras.Model, KerasModelHubMixin): ... def __init__(self, **kwargs): ... super().__init__() ... self.config = kwargs.pop("config", None) ... self.dummy_inputs = ... ... self.layer = ... ... def call(self, *args): ... return ... >>> # Initialize and compile the model as you normally would >>> model = MyModel() >>> model.compile(...) 
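>>> # e.g. model.compile(optimizer="adam", loss="mse"), with hypothetical optimizer/loss choices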
>>> # Build the graph by training it or passing dummy inputs >>> _ = model(model.dummy_inputs) >>> # Save model weights to local directory >>> model.save_pretrained("my-awesome-model") >>> # Push model weights to the Hub >>> model.push_to_hub("my-awesome-model") >>> # Download and initialize weights from the Hub >>> model = MyModel.from_pretrained("username/super-cool-model") ``` """ def _save_pretrained(self, save_directory): save_pretrained_keras(self, save_directory) @classmethod def _from_pretrained( cls, model_id, revision, cache_dir, force_download, proxies, resume_download, local_files_only, token, config: Optional[Dict[str, Any]] = None, **model_kwargs, ): """Here we just call [`from_pretrained_keras`] function so both the mixin and functional APIs stay in sync. TODO - Some args above aren't used since we are calling snapshot_download instead of hf_hub_download. """ if keras is None: raise ImportError("Called a TensorFlow-specific function but could not import it.") # Root is either a local filepath matching model_id or a cached snapshot if not os.path.isdir(model_id): storage_folder = snapshot_download( repo_id=model_id, revision=revision, cache_dir=cache_dir, library_name="keras", library_version=get_tf_version(), ) else: storage_folder = model_id # TODO: change this in a future PR. We are not returning a KerasModelHubMixin instance here... model = keras.models.load_model(storage_folder) # For now, we add a new attribute, config, to store the config loaded from the hub/a local dir. model.config = config return model huggingface_hub-0.31.1/src/huggingface_hub/lfs.py000066400000000000000000000404111500667546600217520ustar00rootroot00000000000000# coding=utf-8 # Copyright 2019-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
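# Informal sketch of the upload flow implemented in this module (not an exhaustive spec):
#   1. `post_lfs_batch_info` queries the LFS batch endpoint and returns per-object upload instructions.
#   2. `lfs_upload` consumes those instructions: it skips the upload if the object is already stored,
#      otherwise it performs a single PUT (`_upload_single_part`) or a chunked multipart upload
#      (`_upload_multi_part`, optionally accelerated by the external `hf_transfer` package).
#   3. If the server returned a "verify" action, the upload is confirmed with a final POST request.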
"""Git LFS related type definitions and utilities""" import inspect import io import re import warnings from dataclasses import dataclass from math import ceil from os.path import getsize from pathlib import Path from typing import TYPE_CHECKING, BinaryIO, Dict, Iterable, List, Optional, Tuple, TypedDict from urllib.parse import unquote from huggingface_hub import constants from .utils import ( build_hf_headers, fix_hf_endpoint_in_url, get_session, hf_raise_for_status, http_backoff, logging, tqdm, validate_hf_hub_args, ) from .utils._lfs import SliceFileObj from .utils.sha import sha256, sha_fileobj from .utils.tqdm import is_tqdm_disabled if TYPE_CHECKING: from ._commit_api import CommitOperationAdd logger = logging.get_logger(__name__) OID_REGEX = re.compile(r"^[0-9a-f]{40}$") LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload" LFS_HEADERS = { "Accept": "application/vnd.git-lfs+json", "Content-Type": "application/vnd.git-lfs+json", } @dataclass class UploadInfo: """ Dataclass holding required information to determine whether a blob should be uploaded to the hub using the LFS protocol or the regular protocol Args: sha256 (`bytes`): SHA256 hash of the blob size (`int`): Size in bytes of the blob sample (`bytes`): First 512 bytes of the blob """ sha256: bytes size: int sample: bytes @classmethod def from_path(cls, path: str): size = getsize(path) with io.open(path, "rb") as file: sample = file.peek(512)[:512] sha = sha_fileobj(file) return cls(size=size, sha256=sha, sample=sample) @classmethod def from_bytes(cls, data: bytes): sha = sha256(data).digest() return cls(size=len(data), sample=data[:512], sha256=sha) @classmethod def from_fileobj(cls, fileobj: BinaryIO): sample = fileobj.read(512) fileobj.seek(0, io.SEEK_SET) sha = sha_fileobj(fileobj) size = fileobj.tell() fileobj.seek(0, io.SEEK_SET) return cls(size=size, sha256=sha, sample=sample) @validate_hf_hub_args def post_lfs_batch_info( upload_infos: Iterable[UploadInfo], token: Optional[str], repo_type: str, repo_id: str, revision: Optional[str] = None, endpoint: Optional[str] = None, headers: Optional[Dict[str, str]] = None, ) -> Tuple[List[dict], List[dict]]: """ Requests the LFS batch endpoint to retrieve upload instructions Learn more: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md Args: upload_infos (`Iterable` of `UploadInfo`): `UploadInfo` for the files that are being uploaded, typically obtained from `CommitOperationAdd.upload_info` repo_type (`str`): Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. revision (`str`, *optional*): The git revision to upload to. headers (`dict`, *optional*): Additional headers to include in the request Returns: `LfsBatchInfo`: 2-tuple: - First element is the list of upload instructions from the server - Second element is an list of errors, if any Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If an argument is invalid or the server response is malformed. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) If the server returned an error. 
""" endpoint = endpoint if endpoint is not None else constants.ENDPOINT url_prefix = "" if repo_type in constants.REPO_TYPES_URL_PREFIXES: url_prefix = constants.REPO_TYPES_URL_PREFIXES[repo_type] batch_url = f"{endpoint}/{url_prefix}{repo_id}.git/info/lfs/objects/batch" payload: Dict = { "operation": "upload", "transfers": ["basic", "multipart"], "objects": [ { "oid": upload.sha256.hex(), "size": upload.size, } for upload in upload_infos ], "hash_algo": "sha256", } if revision is not None: payload["ref"] = {"name": unquote(revision)} # revision has been previously 'quoted' headers = { **LFS_HEADERS, **build_hf_headers(token=token), **(headers or {}), } resp = get_session().post(batch_url, headers=headers, json=payload) hf_raise_for_status(resp) batch_info = resp.json() objects = batch_info.get("objects", None) if not isinstance(objects, list): raise ValueError("Malformed response from server") return ( [_validate_batch_actions(obj) for obj in objects if "error" not in obj], [_validate_batch_error(obj) for obj in objects if "error" in obj], ) class PayloadPartT(TypedDict): partNumber: int etag: str class CompletionPayloadT(TypedDict): """Payload that will be sent to the Hub when uploading multi-part.""" oid: str parts: List[PayloadPartT] def lfs_upload( operation: "CommitOperationAdd", lfs_batch_action: Dict, token: Optional[str] = None, headers: Optional[Dict[str, str]] = None, endpoint: Optional[str] = None, ) -> None: """ Handles uploading a given object to the Hub with the LFS protocol. Can be a No-op if the content of the file is already present on the hub large file storage. Args: operation (`CommitOperationAdd`): The add operation triggering this upload. lfs_batch_action (`dict`): Upload instructions from the LFS batch endpoint for this object. See [`~utils.lfs.post_lfs_batch_info`] for more details. headers (`dict`, *optional*): Headers to include in the request, including authentication and user agent headers. Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If `lfs_batch_action` is improperly formatted [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) If the upload resulted in an error """ # 0. If LFS file is already present, skip upload _validate_batch_actions(lfs_batch_action) actions = lfs_batch_action.get("actions") if actions is None: # The file was already uploaded logger.debug(f"Content of file {operation.path_in_repo} is already present upstream - skipping upload") return # 1. Validate server response (check required keys in dict) upload_action = lfs_batch_action["actions"]["upload"] _validate_lfs_action(upload_action) verify_action = lfs_batch_action["actions"].get("verify") if verify_action is not None: _validate_lfs_action(verify_action) # 2. Upload file (either single part or multi-part) header = upload_action.get("header", {}) chunk_size = header.get("chunk_size") upload_url = fix_hf_endpoint_in_url(upload_action["href"], endpoint=endpoint) if chunk_size is not None: try: chunk_size = int(chunk_size) except (ValueError, TypeError): raise ValueError( f"Malformed response from LFS batch endpoint: `chunk_size` should be an integer. Got '{chunk_size}'." ) _upload_multi_part(operation=operation, header=header, chunk_size=chunk_size, upload_url=upload_url) else: _upload_single_part(operation=operation, upload_url=upload_url) # 3. 
Verify upload went well if verify_action is not None: _validate_lfs_action(verify_action) verify_url = fix_hf_endpoint_in_url(verify_action["href"], endpoint) verify_resp = get_session().post( verify_url, headers=build_hf_headers(token=token, headers=headers), json={"oid": operation.upload_info.sha256.hex(), "size": operation.upload_info.size}, ) hf_raise_for_status(verify_resp) logger.debug(f"{operation.path_in_repo}: Upload successful") def _validate_lfs_action(lfs_action: dict): """validates response from the LFS batch endpoint""" if not ( isinstance(lfs_action.get("href"), str) and (lfs_action.get("header") is None or isinstance(lfs_action.get("header"), dict)) ): raise ValueError("lfs_action is improperly formatted") return lfs_action def _validate_batch_actions(lfs_batch_actions: dict): """validates response from the LFS batch endpoint""" if not (isinstance(lfs_batch_actions.get("oid"), str) and isinstance(lfs_batch_actions.get("size"), int)): raise ValueError("lfs_batch_actions is improperly formatted") upload_action = lfs_batch_actions.get("actions", {}).get("upload") verify_action = lfs_batch_actions.get("actions", {}).get("verify") if upload_action is not None: _validate_lfs_action(upload_action) if verify_action is not None: _validate_lfs_action(verify_action) return lfs_batch_actions def _validate_batch_error(lfs_batch_error: dict): """validates response from the LFS batch endpoint""" if not (isinstance(lfs_batch_error.get("oid"), str) and isinstance(lfs_batch_error.get("size"), int)): raise ValueError("lfs_batch_error is improperly formatted") error_info = lfs_batch_error.get("error") if not ( isinstance(error_info, dict) and isinstance(error_info.get("message"), str) and isinstance(error_info.get("code"), int) ): raise ValueError("lfs_batch_error is improperly formatted") return lfs_batch_error def _upload_single_part(operation: "CommitOperationAdd", upload_url: str) -> None: """ Uploads `fileobj` as a single PUT HTTP request (basic LFS transfer protocol) Args: upload_url (`str`): The URL to PUT the file to. fileobj: The file-like object holding the data to upload. Returns: `requests.Response` Raises: [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) If the upload resulted in an error. """ with operation.as_file(with_tqdm=True) as fileobj: # S3 might raise a transient 500 error -> let's retry if that happens response = http_backoff("PUT", upload_url, data=fileobj, retry_on_status_codes=(500, 502, 503, 504)) hf_raise_for_status(response) def _upload_multi_part(operation: "CommitOperationAdd", header: Dict, chunk_size: int, upload_url: str) -> None: """ Uploads file using HF multipart LFS transfer protocol. """ # 1. Get upload URLs for each part sorted_parts_urls = _get_sorted_parts_urls(header=header, upload_info=operation.upload_info, chunk_size=chunk_size) # 2. Upload parts (either with hf_transfer or in pure Python) use_hf_transfer = constants.HF_HUB_ENABLE_HF_TRANSFER if ( constants.HF_HUB_ENABLE_HF_TRANSFER and not isinstance(operation.path_or_fileobj, str) and not isinstance(operation.path_or_fileobj, Path) ): warnings.warn( "hf_transfer is enabled but does not support uploading from bytes or BinaryIO, falling back to regular" " upload" ) use_hf_transfer = False response_headers = ( _upload_parts_hf_transfer(operation=operation, sorted_parts_urls=sorted_parts_urls, chunk_size=chunk_size) if use_hf_transfer else _upload_parts_iteratively(operation=operation, sorted_parts_urls=sorted_parts_urls, chunk_size=chunk_size) ) # 3. 
Send completion request completion_res = get_session().post( upload_url, json=_get_completion_payload(response_headers, operation.upload_info.sha256.hex()), headers=LFS_HEADERS, ) hf_raise_for_status(completion_res) def _get_sorted_parts_urls(header: Dict, upload_info: UploadInfo, chunk_size: int) -> List[str]: sorted_part_upload_urls = [ upload_url for _, upload_url in sorted( [ (int(part_num, 10), upload_url) for part_num, upload_url in header.items() if part_num.isdigit() and len(part_num) > 0 ], key=lambda t: t[0], ) ] num_parts = len(sorted_part_upload_urls) if num_parts != ceil(upload_info.size / chunk_size): raise ValueError("Invalid server response to upload large LFS file") return sorted_part_upload_urls def _get_completion_payload(response_headers: List[Dict], oid: str) -> CompletionPayloadT: parts: List[PayloadPartT] = [] for part_number, header in enumerate(response_headers): etag = header.get("etag") if etag is None or etag == "": raise ValueError(f"Invalid etag (`{etag}`) returned for part {part_number + 1}") parts.append( { "partNumber": part_number + 1, "etag": etag, } ) return {"oid": oid, "parts": parts} def _upload_parts_iteratively( operation: "CommitOperationAdd", sorted_parts_urls: List[str], chunk_size: int ) -> List[Dict]: headers = [] with operation.as_file(with_tqdm=True) as fileobj: for part_idx, part_upload_url in enumerate(sorted_parts_urls): with SliceFileObj( fileobj, seek_from=chunk_size * part_idx, read_limit=chunk_size, ) as fileobj_slice: # S3 might raise a transient 500 error -> let's retry if that happens part_upload_res = http_backoff( "PUT", part_upload_url, data=fileobj_slice, retry_on_status_codes=(500, 502, 503, 504) ) hf_raise_for_status(part_upload_res) headers.append(part_upload_res.headers) return headers # type: ignore def _upload_parts_hf_transfer( operation: "CommitOperationAdd", sorted_parts_urls: List[str], chunk_size: int ) -> List[Dict]: # Upload file using an external Rust-based package. Upload is faster but support less features (no progress bars). try: from hf_transfer import multipart_upload except ImportError: raise ValueError( "Fast uploading using 'hf_transfer' is enabled (HF_HUB_ENABLE_HF_TRANSFER=1) but 'hf_transfer' package is" " not available in your environment. Try `pip install hf_transfer`." ) supports_callback = "callback" in inspect.signature(multipart_upload).parameters if not supports_callback: warnings.warn( "You are using an outdated version of `hf_transfer`. Consider upgrading to latest version to enable progress bars using `pip install -U hf_transfer`." ) total = operation.upload_info.size desc = operation.path_in_repo if len(desc) > 40: desc = f"(…){desc[-40:]}" with tqdm( unit="B", unit_scale=True, total=total, initial=0, desc=desc, disable=is_tqdm_disabled(logger.getEffectiveLevel()), name="huggingface_hub.lfs_upload", ) as progress: try: output = multipart_upload( file_path=operation.path_or_fileobj, parts_urls=sorted_parts_urls, chunk_size=chunk_size, max_files=128, parallel_failures=127, # could be removed max_retries=5, **({"callback": progress.update} if supports_callback else {}), ) except Exception as e: raise RuntimeError( "An error occurred while uploading using `hf_transfer`. Consider disabling HF_HUB_ENABLE_HF_TRANSFER for" " better error handling." 
) from e if not supports_callback: progress.update(total) return output huggingface_hub-0.31.1/src/huggingface_hub/py.typed000066400000000000000000000000001500667546600223010ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/repocard.py000066400000000000000000001036551500667546600227770ustar00rootroot00000000000000import os import re from pathlib import Path from typing import Any, Dict, Literal, Optional, Type, Union import requests import yaml from huggingface_hub.file_download import hf_hub_download from huggingface_hub.hf_api import upload_file from huggingface_hub.repocard_data import ( CardData, DatasetCardData, EvalResult, ModelCardData, SpaceCardData, eval_results_to_model_index, model_index_to_eval_results, ) from huggingface_hub.utils import get_session, is_jinja_available, yaml_dump from . import constants from .errors import EntryNotFoundError from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args logger = logging.get_logger(__name__) TEMPLATE_MODELCARD_PATH = Path(__file__).parent / "templates" / "modelcard_template.md" TEMPLATE_DATASETCARD_PATH = Path(__file__).parent / "templates" / "datasetcard_template.md" # exact same regex as in the Hub server. Please keep in sync. # See https://github.com/huggingface/moon-landing/blob/main/server/lib/ViewMarkdown.ts#L18 REGEX_YAML_BLOCK = re.compile(r"^(\s*---[\r\n]+)([\S\s]*?)([\r\n]+---(\r\n|\n|$))") class RepoCard: card_data_class = CardData default_template_path = TEMPLATE_MODELCARD_PATH repo_type = "model" def __init__(self, content: str, ignore_metadata_errors: bool = False): """Initialize a RepoCard from string content. The content should be a Markdown file with a YAML block at the beginning and a Markdown body. Args: content (`str`): The content of the Markdown file. Example: ```python >>> from huggingface_hub.repocard import RepoCard >>> text = ''' ... --- ... language: en ... license: mit ... --- ... ... # My repo ... ''' >>> card = RepoCard(text) >>> card.data.to_dict() {'language': 'en', 'license': 'mit'} >>> card.text '\\n# My repo\\n' ``` Raises the following error: - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) when the content of the repo card metadata is not a dictionary. """ # Set the content of the RepoCard, as well as underlying .data and .text attributes. # See the `content` property setter for more details. self.ignore_metadata_errors = ignore_metadata_errors self.content = content @property def content(self): """The content of the RepoCard, including the YAML block and the Markdown body.""" line_break = _detect_line_ending(self._content) or "\n" return f"---{line_break}{self.data.to_yaml(line_break=line_break, original_order=self._original_order)}{line_break}---{line_break}{self.text}" @content.setter def content(self, content: str): """Set the content of the RepoCard.""" self._content = content match = REGEX_YAML_BLOCK.search(content) if match: # Metadata found in the YAML block yaml_block = match.group(2) self.text = content[match.end() :] data_dict = yaml.safe_load(yaml_block) if data_dict is None: data_dict = {} # The YAML block's data should be a dictionary if not isinstance(data_dict, dict): raise ValueError("repo card metadata block should be a dict") else: # Model card without metadata... create empty metadata logger.warning("Repo card metadata block was not found. 
Setting CardData to empty.") data_dict = {} self.text = content self.data = self.card_data_class(**data_dict, ignore_metadata_errors=self.ignore_metadata_errors) self._original_order = list(data_dict.keys()) def __str__(self): return self.content def save(self, filepath: Union[Path, str]): r"""Save a RepoCard to a file. Args: filepath (`Union[Path, str]`): Filepath to the markdown file to save. Example: ```python >>> from huggingface_hub.repocard import RepoCard >>> card = RepoCard("---\nlanguage: en\n---\n# This is a test repo card") >>> card.save("/tmp/test.md") ``` """ filepath = Path(filepath) filepath.parent.mkdir(parents=True, exist_ok=True) # Preserve newlines as in the existing file. with open(filepath, mode="w", newline="", encoding="utf-8") as f: f.write(str(self)) @classmethod def load( cls, repo_id_or_path: Union[str, Path], repo_type: Optional[str] = None, token: Optional[str] = None, ignore_metadata_errors: bool = False, ): """Initialize a RepoCard from a Hugging Face Hub repo's README.md or a local filepath. Args: repo_id_or_path (`Union[str, Path]`): The repo ID associated with a Hugging Face Hub repo or a local filepath. repo_type (`str`, *optional*): The type of Hugging Face repo to push to. Defaults to None, which will use use "model". Other options are "dataset" and "space". Not used when loading from a local filepath. If this is called from a child class, the default value will be the child class's `repo_type`. token (`str`, *optional*): Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token. ignore_metadata_errors (`str`): If True, errors while parsing the metadata section will be ignored. Some information might be lost during the process. Use it at your own risk. Returns: [`huggingface_hub.repocard.RepoCard`]: The RepoCard (or subclass) initialized from the repo's README.md file or filepath. Example: ```python >>> from huggingface_hub.repocard import RepoCard >>> card = RepoCard.load("nateraw/food") >>> assert card.data.tags == ["generated_from_trainer", "image-classification", "pytorch"] ``` """ if Path(repo_id_or_path).is_file(): card_path = Path(repo_id_or_path) elif isinstance(repo_id_or_path, str): card_path = Path( hf_hub_download( repo_id_or_path, constants.REPOCARD_NAME, repo_type=repo_type or cls.repo_type, token=token, ) ) else: raise ValueError(f"Cannot load RepoCard: path not found on disk ({repo_id_or_path}).") # Preserve newlines in the existing file. with card_path.open(mode="r", newline="", encoding="utf-8") as f: return cls(f.read(), ignore_metadata_errors=ignore_metadata_errors) def validate(self, repo_type: Optional[str] = None): """Validates card against Hugging Face Hub's card validation logic. Using this function requires access to the internet, so it is only called internally by [`huggingface_hub.repocard.RepoCard.push_to_hub`]. Args: repo_type (`str`, *optional*, defaults to "model"): The type of Hugging Face repo to push to. Options are "model", "dataset", and "space". If this function is called from a child class, the default will be the child class's `repo_type`. Raises the following errors: - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if the card fails validation checks. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) if the request to the Hub API fails for any other reason. """ # If repo type is provided, otherwise, use the repo type of the card. 
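# (i.e. fall back to the class-level default repo type, e.g. "model" for ModelCard or "dataset" for DatasetCard)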
repo_type = repo_type or self.repo_type body = { "repoType": repo_type, "content": str(self), } headers = {"Accept": "text/plain"} try: r = get_session().post("https://huggingface.co/api/validate-yaml", body, headers=headers) r.raise_for_status() except requests.exceptions.HTTPError as exc: if r.status_code == 400: raise ValueError(r.text) else: raise exc def push_to_hub( self, repo_id: str, token: Optional[str] = None, repo_type: Optional[str] = None, commit_message: Optional[str] = None, commit_description: Optional[str] = None, revision: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, ): """Push a RepoCard to a Hugging Face Hub repo. Args: repo_id (`str`): The repo ID of the Hugging Face Hub repo to push to. Example: "nateraw/food". token (`str`, *optional*): Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token. repo_type (`str`, *optional*, defaults to "model"): The type of Hugging Face repo to push to. Options are "model", "dataset", and "space". If this function is called by a child class, it will default to the child class's `repo_type`. commit_message (`str`, *optional*): The summary / title / first line of the generated commit. commit_description (`str`, *optional*) The description of the generated commit. revision (`str`, *optional*): The git revision to commit from. Defaults to the head of the `"main"` branch. create_pr (`bool`, *optional*): Whether or not to create a Pull Request with this commit. Defaults to `False`. parent_commit (`str`, *optional*): The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be especially useful if the repo is updated / committed to concurrently. Returns: `str`: URL of the commit which updated the card metadata. """ # If repo type is provided, otherwise, use the repo type of the card. repo_type = repo_type or self.repo_type # Validate card before pushing to hub self.validate(repo_type=repo_type) with SoftTemporaryDirectory() as tmpdir: tmp_path = Path(tmpdir) / constants.REPOCARD_NAME tmp_path.write_text(str(self)) url = upload_file( path_or_fileobj=str(tmp_path), path_in_repo=constants.REPOCARD_NAME, repo_id=repo_id, token=token, repo_type=repo_type, commit_message=commit_message, commit_description=commit_description, create_pr=create_pr, revision=revision, parent_commit=parent_commit, ) return url @classmethod def from_template( cls, card_data: CardData, template_path: Optional[str] = None, template_str: Optional[str] = None, **template_kwargs, ): """Initialize a RepoCard from a template. By default, it uses the default template. Templates are Jinja2 templates that can be customized by passing keyword arguments. Args: card_data (`huggingface_hub.CardData`): A huggingface_hub.CardData instance containing the metadata you want to include in the YAML header of the repo card on the Hugging Face Hub. template_path (`str`, *optional*): A path to a markdown file with optional Jinja template variables that can be filled in with `template_kwargs`. Defaults to the default template. 
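template_str (`str`, *optional*):
    A markdown template string with optional Jinja template variables that can be filled in with
    `template_kwargs`. Ignored when `template_path` is provided; if neither is given, the default
    template is used.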
Returns: [`huggingface_hub.repocard.RepoCard`]: A RepoCard instance with the specified card data and content from the template. """ if is_jinja_available(): import jinja2 else: raise ImportError( "Using RepoCard.from_template requires Jinja2 to be installed. Please" " install it with `pip install Jinja2`." ) kwargs = card_data.to_dict().copy() kwargs.update(template_kwargs) # Template_kwargs have priority if template_path is not None: template_str = Path(template_path).read_text() if template_str is None: template_str = Path(cls.default_template_path).read_text() template = jinja2.Template(template_str) content = template.render(card_data=card_data.to_yaml(), **kwargs) return cls(content) class ModelCard(RepoCard): card_data_class = ModelCardData default_template_path = TEMPLATE_MODELCARD_PATH repo_type = "model" @classmethod def from_template( # type: ignore # violates Liskov property but easier to use cls, card_data: ModelCardData, template_path: Optional[str] = None, template_str: Optional[str] = None, **template_kwargs, ): """Initialize a ModelCard from a template. By default, it uses the default template, which can be found here: https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md Templates are Jinja2 templates that can be customized by passing keyword arguments. Args: card_data (`huggingface_hub.ModelCardData`): A huggingface_hub.ModelCardData instance containing the metadata you want to include in the YAML header of the model card on the Hugging Face Hub. template_path (`str`, *optional*): A path to a markdown file with optional Jinja template variables that can be filled in with `template_kwargs`. Defaults to the default template. Returns: [`huggingface_hub.ModelCard`]: A ModelCard instance with the specified card data and content from the template. Example: ```python >>> from huggingface_hub import ModelCard, ModelCardData, EvalResult >>> # Using the Default Template >>> card_data = ModelCardData( ... language='en', ... license='mit', ... library_name='timm', ... tags=['image-classification', 'resnet'], ... datasets=['beans'], ... metrics=['accuracy'], ... ) >>> card = ModelCard.from_template( ... card_data, ... model_description='This model does x + y...' ... ) >>> # Including Evaluation Results >>> card_data = ModelCardData( ... language='en', ... tags=['image-classification', 'resnet'], ... eval_results=[ ... EvalResult( ... task_type='image-classification', ... dataset_type='beans', ... dataset_name='Beans', ... metric_type='accuracy', ... metric_value=0.9, ... ), ... ], ... model_name='my-cool-model', ... ) >>> card = ModelCard.from_template(card_data) >>> # Using a Custom Template >>> card_data = ModelCardData( ... language='en', ... tags=['image-classification', 'resnet'] ... ) >>> card = ModelCard.from_template( ... card_data=card_data, ... template_path='./src/huggingface_hub/templates/modelcard_template.md', ... custom_template_var='custom value', # will be replaced in template if it exists ... ) ``` """ return super().from_template(card_data, template_path, template_str, **template_kwargs) class DatasetCard(RepoCard): card_data_class = DatasetCardData default_template_path = TEMPLATE_DATASETCARD_PATH repo_type = "dataset" @classmethod def from_template( # type: ignore # violates Liskov property but easier to use cls, card_data: DatasetCardData, template_path: Optional[str] = None, template_str: Optional[str] = None, **template_kwargs, ): """Initialize a DatasetCard from a template. 
By default, it uses the default template, which can be found here: https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/datasetcard_template.md Templates are Jinja2 templates that can be customized by passing keyword arguments. Args: card_data (`huggingface_hub.DatasetCardData`): A huggingface_hub.DatasetCardData instance containing the metadata you want to include in the YAML header of the dataset card on the Hugging Face Hub. template_path (`str`, *optional*): A path to a markdown file with optional Jinja template variables that can be filled in with `template_kwargs`. Defaults to the default template. Returns: [`huggingface_hub.DatasetCard`]: A DatasetCard instance with the specified card data and content from the template. Example: ```python >>> from huggingface_hub import DatasetCard, DatasetCardData >>> # Using the Default Template >>> card_data = DatasetCardData( ... language='en', ... license='mit', ... annotations_creators='crowdsourced', ... task_categories=['text-classification'], ... task_ids=['sentiment-classification', 'text-scoring'], ... multilinguality='monolingual', ... pretty_name='My Text Classification Dataset', ... ) >>> card = DatasetCard.from_template( ... card_data, ... pretty_name=card_data.pretty_name, ... ) >>> # Using a Custom Template >>> card_data = DatasetCardData( ... language='en', ... license='mit', ... ) >>> card = DatasetCard.from_template( ... card_data=card_data, ... template_path='./src/huggingface_hub/templates/datasetcard_template.md', ... custom_template_var='custom value', # will be replaced in template if it exists ... ) ``` """ return super().from_template(card_data, template_path, template_str, **template_kwargs) class SpaceCard(RepoCard): card_data_class = SpaceCardData default_template_path = TEMPLATE_MODELCARD_PATH repo_type = "space" def _detect_line_ending(content: str) -> Literal["\r", "\n", "\r\n", None]: # noqa: F722 """Detect the line ending of a string. Used by RepoCard to avoid making huge diff on newlines. Uses same implementation as in Hub server, keep it in sync. Returns: str: The detected line ending of the string. """ cr = content.count("\r") lf = content.count("\n") crlf = content.count("\r\n") if cr + lf == 0: return None if crlf == cr and crlf == lf: return "\r\n" if cr > lf: return "\r" else: return "\n" def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]: content = Path(local_path).read_text() match = REGEX_YAML_BLOCK.search(content) if match: yaml_block = match.group(2) data = yaml.safe_load(yaml_block) if data is None or isinstance(data, dict): return data raise ValueError("repo card metadata block should be a dict") else: return None def metadata_save(local_path: Union[str, Path], data: Dict) -> None: """ Save the metadata dict in the upper YAML part Trying to preserve newlines as in the existing file. 
Docs about open() with newline="" parameter: https://docs.python.org/3/library/functions.html?highlight=open#open Does not work with "^M" linebreaks, which are replaced by \n """ line_break = "\n" content = "" # try to detect existing newline character if os.path.exists(local_path): with open(local_path, "r", newline="", encoding="utf8") as readme: content = readme.read() if isinstance(readme.newlines, tuple): line_break = readme.newlines[0] elif isinstance(readme.newlines, str): line_break = readme.newlines # creates a new file if it not with open(local_path, "w", newline="", encoding="utf8") as readme: data_yaml = yaml_dump(data, sort_keys=False, line_break=line_break) # sort_keys: keep dict order match = REGEX_YAML_BLOCK.search(content) if match: output = content[: match.start()] + f"---{line_break}{data_yaml}---{line_break}" + content[match.end() :] else: output = f"---{line_break}{data_yaml}---{line_break}{content}" readme.write(output) readme.close() def metadata_eval_result( *, model_pretty_name: str, task_pretty_name: str, task_id: str, metrics_pretty_name: str, metrics_id: str, metrics_value: Any, dataset_pretty_name: str, dataset_id: str, metrics_config: Optional[str] = None, metrics_verified: bool = False, dataset_config: Optional[str] = None, dataset_split: Optional[str] = None, dataset_revision: Optional[str] = None, metrics_verification_token: Optional[str] = None, ) -> Dict: """ Creates a metadata dict with the result from a model evaluated on a dataset. Args: model_pretty_name (`str`): The name of the model in natural language. task_pretty_name (`str`): The name of a task in natural language. task_id (`str`): Example: automatic-speech-recognition. A task id. metrics_pretty_name (`str`): A name for the metric in natural language. Example: Test WER. metrics_id (`str`): Example: wer. A metric id from https://hf.co/metrics. metrics_value (`Any`): The value from the metric. Example: 20.0 or "20.0 ± 1.2". dataset_pretty_name (`str`): The name of the dataset in natural language. dataset_id (`str`): Example: common_voice. A dataset id from https://hf.co/datasets. metrics_config (`str`, *optional*): The name of the metric configuration used in `load_metric()`. Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. metrics_verified (`bool`, *optional*, defaults to `False`): Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. dataset_config (`str`, *optional*): Example: fr. The name of the dataset configuration used in `load_dataset()`. dataset_split (`str`, *optional*): Example: test. The name of the dataset split used in `load_dataset()`. dataset_revision (`str`, *optional*): Example: 5503434ddd753f426f4b38109466949a1217c2bb. The name of the dataset dataset revision used in `load_dataset()`. metrics_verification_token (`bool`, *optional*): A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Returns: `dict`: a metadata dict with the result from a model evaluated on a dataset. Example: ```python >>> from huggingface_hub import metadata_eval_result >>> results = metadata_eval_result( ... model_pretty_name="RoBERTa fine-tuned on ReactionGIF", ... task_pretty_name="Text Classification", ... task_id="text-classification", ... metrics_pretty_name="Accuracy", ... metrics_id="accuracy", ... 
metrics_value=0.2662102282047272, ... dataset_pretty_name="ReactionJPEG", ... dataset_id="julien-c/reactionjpeg", ... dataset_config="default", ... dataset_split="test", ... ) >>> results == { ... 'model-index': [ ... { ... 'name': 'RoBERTa fine-tuned on ReactionGIF', ... 'results': [ ... { ... 'task': { ... 'type': 'text-classification', ... 'name': 'Text Classification' ... }, ... 'dataset': { ... 'name': 'ReactionJPEG', ... 'type': 'julien-c/reactionjpeg', ... 'config': 'default', ... 'split': 'test' ... }, ... 'metrics': [ ... { ... 'type': 'accuracy', ... 'value': 0.2662102282047272, ... 'name': 'Accuracy', ... 'verified': False ... } ... ] ... } ... ] ... } ... ] ... } True ``` """ return { "model-index": eval_results_to_model_index( model_name=model_pretty_name, eval_results=[ EvalResult( task_name=task_pretty_name, task_type=task_id, metric_name=metrics_pretty_name, metric_type=metrics_id, metric_value=metrics_value, dataset_name=dataset_pretty_name, dataset_type=dataset_id, metric_config=metrics_config, verified=metrics_verified, verify_token=metrics_verification_token, dataset_config=dataset_config, dataset_split=dataset_split, dataset_revision=dataset_revision, ) ], ) } @validate_hf_hub_args def metadata_update( repo_id: str, metadata: Dict, *, repo_type: Optional[str] = None, overwrite: bool = False, token: Optional[str] = None, commit_message: Optional[str] = None, commit_description: Optional[str] = None, revision: Optional[str] = None, create_pr: bool = False, parent_commit: Optional[str] = None, ) -> str: """ Updates the metadata in the README.md of a repository on the Hugging Face Hub. If the README.md file doesn't exist yet, a new one is created with metadata and an the default ModelCard or DatasetCard template. For `space` repo, an error is thrown as a Space cannot exist without a `README.md` file. Args: repo_id (`str`): The name of the repository. metadata (`dict`): A dictionary containing the metadata to be updated. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if updating to a dataset or space, `None` or `"model"` if updating to a model. Default is `None`. overwrite (`bool`, *optional*, defaults to `False`): If set to `True` an existing field can be overwritten, otherwise attempting to overwrite an existing field will cause an error. token (`str`, *optional*): The Hugging Face authentication token. commit_message (`str`, *optional*): The summary / title / first line of the generated commit. Defaults to `f"Update metadata with huggingface_hub"` commit_description (`str` *optional*) The description of the generated commit revision (`str`, *optional*): The git revision to commit from. Defaults to the head of the `"main"` branch. create_pr (`boolean`, *optional*): Whether or not to create a Pull Request from `revision` with that commit. Defaults to `False`. parent_commit (`str`, *optional*): The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be especially useful if the repo is updated / committed to concurrently. Returns: `str`: URL of the commit which updated the card metadata. 
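Raises the following error:
    - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
      if `repo_type` is not one of `None`, `"model"`, `"dataset"` or `"space"`, if the repo is a Space
      without an existing `README.md` file, or if a metadata field already has a different value and
      `overwrite` is `False`.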
Example: ```python >>> from huggingface_hub import metadata_update >>> metadata = {'model-index': [{'name': 'RoBERTa fine-tuned on ReactionGIF', ... 'results': [{'dataset': {'name': 'ReactionGIF', ... 'type': 'julien-c/reactiongif'}, ... 'metrics': [{'name': 'Recall', ... 'type': 'recall', ... 'value': 0.7762102282047272}], ... 'task': {'name': 'Text Classification', ... 'type': 'text-classification'}}]}]} >>> url = metadata_update("hf-internal-testing/reactiongif-roberta-card", metadata) ``` """ commit_message = commit_message if commit_message is not None else "Update metadata with huggingface_hub" # Card class given repo_type card_class: Type[RepoCard] if repo_type is None or repo_type == "model": card_class = ModelCard elif repo_type == "dataset": card_class = DatasetCard elif repo_type == "space": card_class = RepoCard else: raise ValueError(f"Unknown repo_type: {repo_type}") # Either load repo_card from the Hub or create an empty one. # NOTE: Will not create the repo if it doesn't exist. try: card = card_class.load(repo_id, token=token, repo_type=repo_type) except EntryNotFoundError: if repo_type == "space": raise ValueError("Cannot update metadata on a Space that doesn't contain a `README.md` file.") # Initialize a ModelCard or DatasetCard from default template and no data. card = card_class.from_template(CardData()) for key, value in metadata.items(): if key == "model-index": # if the new metadata doesn't include a name, either use existing one or repo name if "name" not in value[0]: value[0]["name"] = getattr(card, "model_name", repo_id) model_name, new_results = model_index_to_eval_results(value) if card.data.eval_results is None: card.data.eval_results = new_results card.data.model_name = model_name else: existing_results = card.data.eval_results # Iterate over new results # Iterate over existing results # If both results describe the same metric but value is different: # If overwrite=True: overwrite the metric value # Else: raise ValueError # Else: append new result to existing ones. for new_result in new_results: result_found = False for existing_result in existing_results: if new_result.is_equal_except_value(existing_result): if new_result != existing_result and not overwrite: raise ValueError( "You passed a new value for the existing metric" f" 'name: {new_result.metric_name}, type: " f"{new_result.metric_type}'. Set `overwrite=True`" " to overwrite existing metrics." ) result_found = True existing_result.metric_value = new_result.metric_value if existing_result.verified is True: existing_result.verify_token = new_result.verify_token if not result_found: card.data.eval_results.append(new_result) else: # Any metadata that is not a result metric if card.data.get(key) is not None and not overwrite and card.data.get(key) != value: raise ValueError( f"You passed a new value for the existing meta data field '{key}'." " Set `overwrite=True` to overwrite existing metadata." 
) else: card.data[key] = value return card.push_to_hub( repo_id, token=token, repo_type=repo_type, commit_message=commit_message, commit_description=commit_description, create_pr=create_pr, revision=revision, parent_commit=parent_commit, ) huggingface_hub-0.31.1/src/huggingface_hub/repocard_data.py000066400000000000000000001024421500667546600237610ustar00rootroot00000000000000import copy from collections import defaultdict from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union from huggingface_hub.utils import logging, yaml_dump logger = logging.get_logger(__name__) @dataclass class EvalResult: """ Flattened representation of individual evaluation results found in model-index of Model Cards. For more information on the model-index spec, see https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1. Args: task_type (`str`): The task identifier. Example: "image-classification". dataset_type (`str`): The dataset identifier. Example: "common_voice". Use dataset id from https://hf.co/datasets. dataset_name (`str`): A pretty name for the dataset. Example: "Common Voice (French)". metric_type (`str`): The metric identifier. Example: "wer". Use metric id from https://hf.co/metrics. metric_value (`Any`): The metric value. Example: 0.9 or "20.0 ± 1.2". task_name (`str`, *optional*): A pretty name for the task. Example: "Speech Recognition". dataset_config (`str`, *optional*): The name of the dataset configuration used in `load_dataset()`. Example: fr in `load_dataset("common_voice", "fr")`. See the `datasets` docs for more info: https://hf.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name dataset_split (`str`, *optional*): The split used in `load_dataset()`. Example: "test". dataset_revision (`str`, *optional*): The revision (AKA Git Sha) of the dataset used in `load_dataset()`. Example: 5503434ddd753f426f4b38109466949a1217c2bb dataset_args (`Dict[str, Any]`, *optional*): The arguments passed during `Metric.compute()`. Example for `bleu`: `{"max_order": 4}` metric_name (`str`, *optional*): A pretty name for the metric. Example: "Test WER". metric_config (`str`, *optional*): The name of the metric configuration used in `load_metric()`. Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations metric_args (`Dict[str, Any]`, *optional*): The arguments passed during `Metric.compute()`. Example for `bleu`: max_order: 4 verified (`bool`, *optional*): Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. verify_token (`str`, *optional*): A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. source_name (`str`, *optional*): The name of the source of the evaluation result. Example: "Open LLM Leaderboard". source_url (`str`, *optional*): The URL of the source of the evaluation result. Example: "https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard". """ # Required # The task identifier # Example: automatic-speech-recognition task_type: str # The dataset identifier # Example: common_voice. Use dataset id from https://hf.co/datasets dataset_type: str # A pretty name for the dataset. 
# Example: Common Voice (French) dataset_name: str # The metric identifier # Example: wer. Use metric id from https://hf.co/metrics metric_type: str # Value of the metric. # Example: 20.0 or "20.0 ± 1.2" metric_value: Any # Optional # A pretty name for the task. # Example: Speech Recognition task_name: Optional[str] = None # The name of the dataset configuration used in `load_dataset()`. # Example: fr in `load_dataset("common_voice", "fr")`. # See the `datasets` docs for more info: # https://huggingface.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name dataset_config: Optional[str] = None # The split used in `load_dataset()`. # Example: test dataset_split: Optional[str] = None # The revision (AKA Git Sha) of the dataset used in `load_dataset()`. # Example: 5503434ddd753f426f4b38109466949a1217c2bb dataset_revision: Optional[str] = None # The arguments passed during `Metric.compute()`. # Example for `bleu`: max_order: 4 dataset_args: Optional[Dict[str, Any]] = None # A pretty name for the metric. # Example: Test WER metric_name: Optional[str] = None # The name of the metric configuration used in `load_metric()`. # Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. # See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations metric_config: Optional[str] = None # The arguments passed during `Metric.compute()`. # Example for `bleu`: max_order: 4 metric_args: Optional[Dict[str, Any]] = None # Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. verified: Optional[bool] = None # A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. verify_token: Optional[str] = None # The name of the source of the evaluation result. # Example: Open LLM Leaderboard source_name: Optional[str] = None # The URL of the source of the evaluation result. # Example: https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard source_url: Optional[str] = None @property def unique_identifier(self) -> tuple: """Returns a tuple that uniquely identifies this evaluation.""" return ( self.task_type, self.dataset_type, self.dataset_config, self.dataset_split, self.dataset_revision, ) def is_equal_except_value(self, other: "EvalResult") -> bool: """ Return True if `self` and `other` describe exactly the same metric but with a different value. """ for key, _ in self.__dict__.items(): if key == "metric_value": continue # For metrics computed by Hugging Face's evaluation service, `verify_token` is derived from `metric_value`, # so we exclude it here in the comparison. if key != "verify_token" and getattr(self, key) != getattr(other, key): return False return True def __post_init__(self) -> None: if self.source_name is not None and self.source_url is None: raise ValueError("If `source_name` is provided, `source_url` must also be provided.") @dataclass class CardData: """Structure containing metadata from a RepoCard. [`CardData`] is the parent class of [`ModelCardData`] and [`DatasetCardData`]. Metadata can be exported as a dictionary or YAML. Export can be customized to alter the representation of the data (example: flatten evaluation results). 
`CardData` behaves as a dictionary (can get, pop, set values) but do not inherit from `dict` to allow this export step. """ def __init__(self, ignore_metadata_errors: bool = False, **kwargs): self.__dict__.update(kwargs) def to_dict(self): """Converts CardData to a dict. Returns: `dict`: CardData represented as a dictionary ready to be dumped to a YAML block for inclusion in a README.md file. """ data_dict = copy.deepcopy(self.__dict__) self._to_dict(data_dict) return {key: value for key, value in data_dict.items() if value is not None} def _to_dict(self, data_dict): """Use this method in child classes to alter the dict representation of the data. Alter the dict in-place. Args: data_dict (`dict`): The raw dict representation of the card data. """ pass def to_yaml(self, line_break=None, original_order: Optional[List[str]] = None) -> str: """Dumps CardData to a YAML block for inclusion in a README.md file. Args: line_break (str, *optional*): The line break to use when dumping to yaml. Returns: `str`: CardData represented as a YAML block. """ if original_order: self.__dict__ = { k: self.__dict__[k] for k in original_order + list(set(self.__dict__.keys()) - set(original_order)) if k in self.__dict__ } return yaml_dump(self.to_dict(), sort_keys=False, line_break=line_break).strip() def __repr__(self): return repr(self.__dict__) def __str__(self): return self.to_yaml() def get(self, key: str, default: Any = None) -> Any: """Get value for a given metadata key.""" value = self.__dict__.get(key) return default if value is None else value def pop(self, key: str, default: Any = None) -> Any: """Pop value for a given metadata key.""" return self.__dict__.pop(key, default) def __getitem__(self, key: str) -> Any: """Get value for a given metadata key.""" return self.__dict__[key] def __setitem__(self, key: str, value: Any) -> None: """Set value for a given metadata key.""" self.__dict__[key] = value def __contains__(self, key: str) -> bool: """Check if a given metadata key is set.""" return key in self.__dict__ def __len__(self) -> int: """Return the number of metadata keys set.""" return len(self.__dict__) def _validate_eval_results( eval_results: Optional[Union[EvalResult, List[EvalResult]]], model_name: Optional[str], ) -> List[EvalResult]: if eval_results is None: return [] if isinstance(eval_results, EvalResult): eval_results = [eval_results] if not isinstance(eval_results, list) or not all(isinstance(r, EvalResult) for r in eval_results): raise ValueError( f"`eval_results` should be of type `EvalResult` or a list of `EvalResult`, got {type(eval_results)}." ) if model_name is None: raise ValueError("Passing `eval_results` requires `model_name` to be set.") return eval_results class ModelCardData(CardData): """Model Card Metadata that is used by Hugging Face Hub when included at the top of your README.md Args: base_model (`str` or `List[str]`, *optional*): The identifier of the base model from which the model derives. This is applicable for example if your model is a fine-tune or adapter of an existing model. The value must be the ID of a model on the Hub (or a list of IDs if your model derives from multiple models). Defaults to None. datasets (`Union[str, List[str]]`, *optional*): Dataset or list of datasets that were used to train this model. Should be a dataset ID found on https://hf.co/datasets. Defaults to None. eval_results (`Union[List[EvalResult], EvalResult]`, *optional*): List of `huggingface_hub.EvalResult` that define evaluation results of the model. 
If provided, `model_name` is used to as a name on PapersWithCode's leaderboards. Defaults to `None`. language (`Union[str, List[str]]`, *optional*): Language of model's training data or metadata. It must be an ISO 639-1, 639-2 or 639-3 code (two/three letters), or a special value like "code", "multilingual". Defaults to `None`. library_name (`str`, *optional*): Name of library used by this model. Example: keras or any library from https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries.ts. Defaults to None. license (`str`, *optional*): License of this model. Example: apache-2.0 or any license from https://huggingface.co/docs/hub/repositories-licenses. Defaults to None. license_name (`str`, *optional*): Name of the license of this model. Defaults to None. To be used in conjunction with `license_link`. Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a name. In that case, use `license` instead. license_link (`str`, *optional*): Link to the license of this model. Defaults to None. To be used in conjunction with `license_name`. Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a link. In that case, use `license` instead. metrics (`List[str]`, *optional*): List of metrics used to evaluate this model. Should be a metric name that can be found at https://hf.co/metrics. Example: 'accuracy'. Defaults to None. model_name (`str`, *optional*): A name for this model. It is used along with `eval_results` to construct the `model-index` within the card's metadata. The name you supply here is what will be used on PapersWithCode's leaderboards. If None is provided then the repo name is used as a default. Defaults to None. pipeline_tag (`str`, *optional*): The pipeline tag associated with the model. Example: "text-classification". tags (`List[str]`, *optional*): List of tags to add to your model that can be used when filtering on the Hugging Face Hub. Defaults to None. ignore_metadata_errors (`str`): If True, errors while parsing the metadata section will be ignored. Some information might be lost during the process. Use it at your own risk. kwargs (`dict`, *optional*): Additional metadata that will be added to the model card. Defaults to None. Example: ```python >>> from huggingface_hub import ModelCardData >>> card_data = ModelCardData( ... language="en", ... license="mit", ... library_name="timm", ... tags=['image-classification', 'resnet'], ... 
) >>> card_data.to_dict() {'language': 'en', 'license': 'mit', 'library_name': 'timm', 'tags': ['image-classification', 'resnet']} ``` """ def __init__( self, *, base_model: Optional[Union[str, List[str]]] = None, datasets: Optional[Union[str, List[str]]] = None, eval_results: Optional[List[EvalResult]] = None, language: Optional[Union[str, List[str]]] = None, library_name: Optional[str] = None, license: Optional[str] = None, license_name: Optional[str] = None, license_link: Optional[str] = None, metrics: Optional[List[str]] = None, model_name: Optional[str] = None, pipeline_tag: Optional[str] = None, tags: Optional[List[str]] = None, ignore_metadata_errors: bool = False, **kwargs, ): self.base_model = base_model self.datasets = datasets self.eval_results = eval_results self.language = language self.library_name = library_name self.license = license self.license_name = license_name self.license_link = license_link self.metrics = metrics self.model_name = model_name self.pipeline_tag = pipeline_tag self.tags = _to_unique_list(tags) model_index = kwargs.pop("model-index", None) if model_index: try: model_name, eval_results = model_index_to_eval_results(model_index) self.model_name = model_name self.eval_results = eval_results except (KeyError, TypeError) as error: if ignore_metadata_errors: logger.warning("Invalid model-index. Not loading eval results into CardData.") else: raise ValueError( f"Invalid `model_index` in metadata cannot be parsed: {error.__class__} {error}. Pass" " `ignore_metadata_errors=True` to ignore this error while loading a Model Card. Warning:" " some information will be lost. Use it at your own risk." ) super().__init__(**kwargs) if self.eval_results: try: self.eval_results = _validate_eval_results(self.eval_results, self.model_name) except Exception as e: if ignore_metadata_errors: logger.warning(f"Failed to validate eval_results: {e}. Not loading eval results into CardData.") else: raise ValueError(f"Failed to validate eval_results: {e}") from e def _to_dict(self, data_dict): """Format the internal data dict. In this case, we convert eval results to a valid model index""" if self.eval_results is not None: data_dict["model-index"] = eval_results_to_model_index(self.model_name, self.eval_results) del data_dict["eval_results"], data_dict["model_name"] class DatasetCardData(CardData): """Dataset Card Metadata that is used by Hugging Face Hub when included at the top of your README.md Args: language (`List[str]`, *optional*): Language of dataset's data or metadata. It must be an ISO 639-1, 639-2 or 639-3 code (two/three letters), or a special value like "code", "multilingual". license (`Union[str, List[str]]`, *optional*): License(s) of this dataset. Example: apache-2.0 or any license from https://huggingface.co/docs/hub/repositories-licenses. annotations_creators (`Union[str, List[str]]`, *optional*): How the annotations for the dataset were created. Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'no-annotation', 'other'. language_creators (`Union[str, List[str]]`, *optional*): How the text-based data in the dataset was created. Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'other' multilinguality (`Union[str, List[str]]`, *optional*): Whether the dataset is multilingual. Options are: 'monolingual', 'multilingual', 'translation', 'other'. size_categories (`Union[str, List[str]]`, *optional*): The number of examples in the dataset. Options are: 'n<1K', '1K1T', and 'other'. 
source_datasets (`List[str]]`, *optional*): Indicates whether the dataset is an original dataset or extended from another existing dataset. Options are: 'original' and 'extended'. task_categories (`Union[str, List[str]]`, *optional*): What categories of task does the dataset support? task_ids (`Union[str, List[str]]`, *optional*): What specific tasks does the dataset support? paperswithcode_id (`str`, *optional*): ID of the dataset on PapersWithCode. pretty_name (`str`, *optional*): A more human-readable name for the dataset. (ex. "Cats vs. Dogs") train_eval_index (`Dict`, *optional*): A dictionary that describes the necessary spec for doing evaluation on the Hub. If not provided, it will be gathered from the 'train-eval-index' key of the kwargs. config_names (`Union[str, List[str]]`, *optional*): A list of the available dataset configs for the dataset. """ def __init__( self, *, language: Optional[Union[str, List[str]]] = None, license: Optional[Union[str, List[str]]] = None, annotations_creators: Optional[Union[str, List[str]]] = None, language_creators: Optional[Union[str, List[str]]] = None, multilinguality: Optional[Union[str, List[str]]] = None, size_categories: Optional[Union[str, List[str]]] = None, source_datasets: Optional[List[str]] = None, task_categories: Optional[Union[str, List[str]]] = None, task_ids: Optional[Union[str, List[str]]] = None, paperswithcode_id: Optional[str] = None, pretty_name: Optional[str] = None, train_eval_index: Optional[Dict] = None, config_names: Optional[Union[str, List[str]]] = None, ignore_metadata_errors: bool = False, **kwargs, ): self.annotations_creators = annotations_creators self.language_creators = language_creators self.language = language self.license = license self.multilinguality = multilinguality self.size_categories = size_categories self.source_datasets = source_datasets self.task_categories = task_categories self.task_ids = task_ids self.paperswithcode_id = paperswithcode_id self.pretty_name = pretty_name self.config_names = config_names # TODO - maybe handle this similarly to EvalResult? self.train_eval_index = train_eval_index or kwargs.pop("train-eval-index", None) super().__init__(**kwargs) def _to_dict(self, data_dict): data_dict["train-eval-index"] = data_dict.pop("train_eval_index") class SpaceCardData(CardData): """Space Card Metadata that is used by Hugging Face Hub when included at the top of your README.md To get an exhaustive reference of Spaces configuration, please visit https://huggingface.co/docs/hub/spaces-config-reference#spaces-configuration-reference. Args: title (`str`, *optional*) Title of the Space. sdk (`str`, *optional*) SDK of the Space (one of `gradio`, `streamlit`, `docker`, or `static`). sdk_version (`str`, *optional*) Version of the used SDK (if Gradio/Streamlit sdk). python_version (`str`, *optional*) Python version used in the Space (if Gradio/Streamlit sdk). app_file (`str`, *optional*) Path to your main application file (which contains either gradio or streamlit Python code, or static html code). Path is relative to the root of the repository. app_port (`str`, *optional*) Port on which your application is running. Used only if sdk is `docker`. license (`str`, *optional*) License of this model. Example: apache-2.0 or any license from https://huggingface.co/docs/hub/repositories-licenses. duplicated_from (`str`, *optional*) ID of the original Space if this is a duplicated Space. models (List[`str`], *optional*) List of models related to this Space. Should be a dataset ID found on https://hf.co/models. 
datasets (`List[str]`, *optional*) List of datasets related to this Space. Should be a dataset ID found on https://hf.co/datasets. tags (`List[str]`, *optional*) List of tags to add to your Space that can be used when filtering on the Hub. ignore_metadata_errors (`str`): If True, errors while parsing the metadata section will be ignored. Some information might be lost during the process. Use it at your own risk. kwargs (`dict`, *optional*): Additional metadata that will be added to the space card. Example: ```python >>> from huggingface_hub import SpaceCardData >>> card_data = SpaceCardData( ... title="Dreambooth Training", ... license="mit", ... sdk="gradio", ... duplicated_from="multimodalart/dreambooth-training" ... ) >>> card_data.to_dict() {'title': 'Dreambooth Training', 'sdk': 'gradio', 'license': 'mit', 'duplicated_from': 'multimodalart/dreambooth-training'} ``` """ def __init__( self, *, title: Optional[str] = None, sdk: Optional[str] = None, sdk_version: Optional[str] = None, python_version: Optional[str] = None, app_file: Optional[str] = None, app_port: Optional[int] = None, license: Optional[str] = None, duplicated_from: Optional[str] = None, models: Optional[List[str]] = None, datasets: Optional[List[str]] = None, tags: Optional[List[str]] = None, ignore_metadata_errors: bool = False, **kwargs, ): self.title = title self.sdk = sdk self.sdk_version = sdk_version self.python_version = python_version self.app_file = app_file self.app_port = app_port self.license = license self.duplicated_from = duplicated_from self.models = models self.datasets = datasets self.tags = _to_unique_list(tags) super().__init__(**kwargs) def model_index_to_eval_results(model_index: List[Dict[str, Any]]) -> Tuple[str, List[EvalResult]]: """Takes in a model index and returns the model name and a list of `huggingface_hub.EvalResult` objects. A detailed spec of the model index can be found here: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 Args: model_index (`List[Dict[str, Any]]`): A model index data structure, likely coming from a README.md file on the Hugging Face Hub. Returns: model_name (`str`): The name of the model as found in the model index. This is used as the identifier for the model on leaderboards like PapersWithCode. eval_results (`List[EvalResult]`): A list of `huggingface_hub.EvalResult` objects containing the metrics reported in the provided model_index. Example: ```python >>> from huggingface_hub.repocard_data import model_index_to_eval_results >>> # Define a minimal model index >>> model_index = [ ... { ... "name": "my-cool-model", ... "results": [ ... { ... "task": { ... "type": "image-classification" ... }, ... "dataset": { ... "type": "beans", ... "name": "Beans" ... }, ... "metrics": [ ... { ... "type": "accuracy", ... "value": 0.9 ... } ... ] ... } ... ] ... } ... 
] >>> model_name, eval_results = model_index_to_eval_results(model_index) >>> model_name 'my-cool-model' >>> eval_results[0].task_type 'image-classification' >>> eval_results[0].metric_type 'accuracy' ``` """ eval_results = [] for elem in model_index: name = elem["name"] results = elem["results"] for result in results: task_type = result["task"]["type"] task_name = result["task"].get("name") dataset_type = result["dataset"]["type"] dataset_name = result["dataset"]["name"] dataset_config = result["dataset"].get("config") dataset_split = result["dataset"].get("split") dataset_revision = result["dataset"].get("revision") dataset_args = result["dataset"].get("args") source_name = result.get("source", {}).get("name") source_url = result.get("source", {}).get("url") for metric in result["metrics"]: metric_type = metric["type"] metric_value = metric["value"] metric_name = metric.get("name") metric_args = metric.get("args") metric_config = metric.get("config") verified = metric.get("verified") verify_token = metric.get("verifyToken") eval_result = EvalResult( task_type=task_type, # Required dataset_type=dataset_type, # Required dataset_name=dataset_name, # Required metric_type=metric_type, # Required metric_value=metric_value, # Required task_name=task_name, dataset_config=dataset_config, dataset_split=dataset_split, dataset_revision=dataset_revision, dataset_args=dataset_args, metric_name=metric_name, metric_args=metric_args, metric_config=metric_config, verified=verified, verify_token=verify_token, source_name=source_name, source_url=source_url, ) eval_results.append(eval_result) return name, eval_results def _remove_none(obj): """ Recursively remove `None` values from a dict. Borrowed from: https://stackoverflow.com/a/20558778 """ if isinstance(obj, (list, tuple, set)): return type(obj)(_remove_none(x) for x in obj if x is not None) elif isinstance(obj, dict): return type(obj)((_remove_none(k), _remove_none(v)) for k, v in obj.items() if k is not None and v is not None) else: return obj def eval_results_to_model_index(model_name: str, eval_results: List[EvalResult]) -> List[Dict[str, Any]]: """Takes in given model name and list of `huggingface_hub.EvalResult` and returns a valid model-index that will be compatible with the format expected by the Hugging Face Hub. Args: model_name (`str`): Name of the model (ex. "my-cool-model"). This is used as the identifier for the model on leaderboards like PapersWithCode. eval_results (`List[EvalResult]`): List of `huggingface_hub.EvalResult` objects containing the metrics to be reported in the model-index. Returns: model_index (`List[Dict[str, Any]]`): The eval_results converted to a model-index. Example: ```python >>> from huggingface_hub.repocard_data import eval_results_to_model_index, EvalResult >>> # Define minimal eval_results >>> eval_results = [ ... EvalResult( ... task_type="image-classification", # Required ... dataset_type="beans", # Required ... dataset_name="Beans", # Required ... metric_type="accuracy", # Required ... metric_value=0.9, # Required ... ) ... ] >>> eval_results_to_model_index("my-cool-model", eval_results) [{'name': 'my-cool-model', 'results': [{'task': {'type': 'image-classification'}, 'dataset': {'name': 'Beans', 'type': 'beans'}, 'metrics': [{'type': 'accuracy', 'value': 0.9}]}]}] ``` """ # Metrics are reported on a unique task-and-dataset basis. # Here, we make a map of those pairs and the associated EvalResults. 
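    # Illustration (comments only, values are hypothetical): two results such as
    #   EvalResult(task_type="text-classification", dataset_type="imdb", dataset_name="IMDb",
    #              metric_type="accuracy", metric_value=0.91)
    #   EvalResult(task_type="text-classification", dataset_type="imdb", dataset_name="IMDb",
    #              metric_type="f1", metric_value=0.89)
    # share the same `unique_identifier` (task_type, dataset_type, dataset_config,
    # dataset_split, dataset_revision), so they are grouped into a single model-index
    # entry containing both metrics.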
task_and_ds_types_map: Dict[Any, List[EvalResult]] = defaultdict(list) for eval_result in eval_results: task_and_ds_types_map[eval_result.unique_identifier].append(eval_result) # Use the map from above to generate the model index data. model_index_data = [] for results in task_and_ds_types_map.values(): # All items from `results` share same metadata sample_result = results[0] data = { "task": { "type": sample_result.task_type, "name": sample_result.task_name, }, "dataset": { "name": sample_result.dataset_name, "type": sample_result.dataset_type, "config": sample_result.dataset_config, "split": sample_result.dataset_split, "revision": sample_result.dataset_revision, "args": sample_result.dataset_args, }, "metrics": [ { "type": result.metric_type, "value": result.metric_value, "name": result.metric_name, "config": result.metric_config, "args": result.metric_args, "verified": result.verified, "verifyToken": result.verify_token, } for result in results ], } if sample_result.source_url is not None: source = { "url": sample_result.source_url, } if sample_result.source_name is not None: source["name"] = sample_result.source_name data["source"] = source model_index_data.append(data) # TODO - Check if there cases where this list is longer than one? # Finally, the model index itself is list of dicts. model_index = [ { "name": model_name, "results": model_index_data, } ] return _remove_none(model_index) def _to_unique_list(tags: Optional[List[str]]) -> Optional[List[str]]: if tags is None: return tags unique_tags = [] # make tags unique + keep order explicitly for tag in tags: if tag not in unique_tags: unique_tags.append(tag) return unique_tags huggingface_hub-0.31.1/src/huggingface_hub/repository.py000066400000000000000000001524351500667546600234170ustar00rootroot00000000000000import atexit import os import re import subprocess import threading import time from contextlib import contextmanager from pathlib import Path from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypedDict, Union from urllib.parse import urlparse from huggingface_hub import constants from huggingface_hub.repocard import metadata_load, metadata_save from .hf_api import HfApi, repo_type_and_id_from_hf_id from .lfs import LFS_MULTIPART_UPLOAD_COMMAND from .utils import ( SoftTemporaryDirectory, get_token, logging, run_subprocess, tqdm, validate_hf_hub_args, ) from .utils._deprecation import _deprecate_method logger = logging.get_logger(__name__) class CommandInProgress: """ Utility to follow commands launched asynchronously. """ def __init__( self, title: str, is_done_method: Callable, status_method: Callable, process: subprocess.Popen, post_method: Optional[Callable] = None, ): self.title = title self._is_done = is_done_method self._status = status_method self._process = process self._stderr = "" self._stdout = "" self._post_method = post_method @property def is_done(self) -> bool: """ Whether the process is done. """ result = self._is_done() if result and self._post_method is not None: self._post_method() self._post_method = None return result @property def status(self) -> int: """ The exit code/status of the current action. Will return `0` if the command has completed successfully, and a number between 1 and 255 if the process errored-out. Will return -1 if the command is still ongoing. """ return self._status() @property def failed(self) -> bool: """ Whether the process errored-out. """ return self.status > 0 @property def stderr(self) -> str: """ The current output message on the standard error. 
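        Example (illustrative; assumes `command` was returned by e.g. `Repository.git_push(blocking=False)`):

        ```python
        >>> if command.is_done and command.failed:
        ...     print(command.stderr)
        ```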
""" if self._process.stderr is not None: self._stderr += self._process.stderr.read() return self._stderr @property def stdout(self) -> str: """ The current output message on the standard output. """ if self._process.stdout is not None: self._stdout += self._process.stdout.read() return self._stdout def __repr__(self): status = self.status if status == -1: status = "running" return ( f"[{self.title} command, status code: {status}," f" {'in progress.' if not self.is_done else 'finished.'} PID:" f" {self._process.pid}]" ) def is_git_repo(folder: Union[str, Path]) -> bool: """ Check if the folder is the root or part of a git repository Args: folder (`str`): The folder in which to run the command. Returns: `bool`: `True` if the repository is part of a repository, `False` otherwise. """ folder_exists = os.path.exists(os.path.join(folder, ".git")) git_branch = subprocess.run("git branch".split(), cwd=folder, stdout=subprocess.PIPE, stderr=subprocess.PIPE) return folder_exists and git_branch.returncode == 0 def is_local_clone(folder: Union[str, Path], remote_url: str) -> bool: """ Check if the folder is a local clone of the remote_url Args: folder (`str` or `Path`): The folder in which to run the command. remote_url (`str`): The url of a git repository. Returns: `bool`: `True` if the repository is a local clone of the remote repository specified, `False` otherwise. """ if not is_git_repo(folder): return False remotes = run_subprocess("git remote -v", folder).stdout # Remove token for the test with remotes. remote_url = re.sub(r"https://.*@", "https://", remote_url) remotes = [re.sub(r"https://.*@", "https://", remote) for remote in remotes.split()] return remote_url in remotes def is_tracked_with_lfs(filename: Union[str, Path]) -> bool: """ Check if the file passed is tracked with git-lfs. Args: filename (`str` or `Path`): The filename to check. Returns: `bool`: `True` if the file passed is tracked with git-lfs, `False` otherwise. """ folder = Path(filename).parent filename = Path(filename).name try: p = run_subprocess("git check-attr -a".split() + [filename], folder) attributes = p.stdout.strip() except subprocess.CalledProcessError as exc: if not is_git_repo(folder): return False else: raise OSError(exc.stderr) if len(attributes) == 0: return False found_lfs_tag = {"diff": False, "merge": False, "filter": False} for attribute in attributes.split("\n"): for tag in found_lfs_tag.keys(): if tag in attribute and "lfs" in attribute: found_lfs_tag[tag] = True return all(found_lfs_tag.values()) def is_git_ignored(filename: Union[str, Path]) -> bool: """ Check if file is git-ignored. Supports nested .gitignore files. Args: filename (`str` or `Path`): The filename to check. Returns: `bool`: `True` if the file passed is ignored by `git`, `False` otherwise. """ folder = Path(filename).parent filename = Path(filename).name try: p = run_subprocess("git check-ignore".split() + [filename], folder, check=False) # Will return exit code 1 if not gitignored is_ignored = not bool(p.returncode) except subprocess.CalledProcessError as exc: raise OSError(exc.stderr) return is_ignored def is_binary_file(filename: Union[str, Path]) -> bool: """ Check if file is a binary file. Args: filename (`str` or `Path`): The filename to check. Returns: `bool`: `True` if the file passed is a binary file, `False` otherwise. 
""" try: with open(filename, "rb") as f: content = f.read(10 * (1024**2)) # Read a maximum of 10MB # Code sample taken from the following stack overflow thread # https://stackoverflow.com/questions/898669/how-can-i-detect-if-a-file-is-binary-non-text-in-python/7392391#7392391 text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F}) return bool(content.translate(None, text_chars)) except UnicodeDecodeError: return True def files_to_be_staged(pattern: str = ".", folder: Union[str, Path, None] = None) -> List[str]: """ Returns a list of filenames that are to be staged. Args: pattern (`str` or `Path`): The pattern of filenames to check. Put `.` to get all files. folder (`str` or `Path`): The folder in which to run the command. Returns: `List[str]`: List of files that are to be staged. """ try: p = run_subprocess("git ls-files --exclude-standard -mo".split() + [pattern], folder) if len(p.stdout.strip()): files = p.stdout.strip().split("\n") else: files = [] except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) return files def is_tracked_upstream(folder: Union[str, Path]) -> bool: """ Check if the current checked-out branch is tracked upstream. Args: folder (`str` or `Path`): The folder in which to run the command. Returns: `bool`: `True` if the current checked-out branch is tracked upstream, `False` otherwise. """ try: run_subprocess("git rev-parse --symbolic-full-name --abbrev-ref @{u}", folder) return True except subprocess.CalledProcessError as exc: if "HEAD" in exc.stderr: raise OSError("No branch checked out") return False def commits_to_push(folder: Union[str, Path], upstream: Optional[str] = None) -> int: """ Check the number of commits that would be pushed upstream Args: folder (`str` or `Path`): The folder in which to run the command. upstream (`str`, *optional*): The name of the upstream repository with which the comparison should be made. Returns: `int`: Number of commits that would be pushed upstream were a `git push` to proceed. """ try: result = run_subprocess(f"git cherry -v {upstream or ''}", folder) return len(result.stdout.split("\n")) - 1 except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) class PbarT(TypedDict): # Used to store an opened progress bar in `_lfs_log_progress` bar: tqdm past_bytes: int @contextmanager def _lfs_log_progress(): """ This is a context manager that will log the Git LFS progress of cleaning, smudging, pulling and pushing. """ if logger.getEffectiveLevel() >= logging.ERROR: try: yield except Exception: pass return def output_progress(stopping_event: threading.Event): """ To be launched as a separate thread with an event meaning it should stop the tail. """ # Key is tuple(state, filename), value is a dict(tqdm bar and a previous value) pbars: Dict[Tuple[str, str], PbarT] = {} def close_pbars(): for pbar in pbars.values(): pbar["bar"].update(pbar["bar"].total - pbar["past_bytes"]) pbar["bar"].refresh() pbar["bar"].close() def tail_file(filename) -> Iterator[str]: """ Creates a generator to be iterated through, which will return each line one by one. Will stop tailing the file if the stopping_event is set. 
""" with open(filename, "r") as file: current_line = "" while True: if stopping_event.is_set(): close_pbars() break line_bit = file.readline() if line_bit is not None and not len(line_bit.strip()) == 0: current_line += line_bit if current_line.endswith("\n"): yield current_line current_line = "" else: time.sleep(1) # If the file isn't created yet, wait for a few seconds before trying again. # Can be interrupted with the stopping_event. while not os.path.exists(os.environ["GIT_LFS_PROGRESS"]): if stopping_event.is_set(): close_pbars() return time.sleep(2) for line in tail_file(os.environ["GIT_LFS_PROGRESS"]): try: state, file_progress, byte_progress, filename = line.split() except ValueError as error: # Try/except to ease debugging. See https://github.com/huggingface/huggingface_hub/issues/1373. raise ValueError(f"Cannot unpack LFS progress line:\n{line}") from error description = f"{state.capitalize()} file {filename}" current_bytes, total_bytes = byte_progress.split("/") current_bytes_int = int(current_bytes) total_bytes_int = int(total_bytes) pbar = pbars.get((state, filename)) if pbar is None: # Initialize progress bar pbars[(state, filename)] = { "bar": tqdm( desc=description, initial=current_bytes_int, total=total_bytes_int, unit="B", unit_scale=True, unit_divisor=1024, name="huggingface_hub.lfs_upload", ), "past_bytes": int(current_bytes), } else: # Update progress bar pbar["bar"].update(current_bytes_int - pbar["past_bytes"]) pbar["past_bytes"] = current_bytes_int current_lfs_progress_value = os.environ.get("GIT_LFS_PROGRESS", "") with SoftTemporaryDirectory() as tmpdir: os.environ["GIT_LFS_PROGRESS"] = os.path.join(tmpdir, "lfs_progress") logger.debug(f"Following progress in {os.environ['GIT_LFS_PROGRESS']}") exit_event = threading.Event() x = threading.Thread(target=output_progress, args=(exit_event,), daemon=True) x.start() try: yield finally: exit_event.set() x.join() os.environ["GIT_LFS_PROGRESS"] = current_lfs_progress_value class Repository: """ Helper class to wrap the git and git-lfs commands. The aim is to facilitate interacting with huggingface.co hosted model or dataset repos, though not a lot here (if any) is actually specific to huggingface.co. [`Repository`] is deprecated in favor of the http-based alternatives implemented in [`HfApi`]. Given its large adoption in legacy code, the complete removal of [`Repository`] will only happen in release `v1.0`. For more details, please read https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http. """ command_queue: List[CommandInProgress] @validate_hf_hub_args @_deprecate_method( version="1.0", message=( "Please prefer the http-based alternatives instead. Given its large adoption in legacy code, the complete" " removal is only planned on next major release.\nFor more details, please read" " https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http." ), ) def __init__( self, local_dir: Union[str, Path], clone_from: Optional[str] = None, repo_type: Optional[str] = None, token: Union[bool, str] = True, git_user: Optional[str] = None, git_email: Optional[str] = None, revision: Optional[str] = None, skip_lfs_files: bool = False, client: Optional[HfApi] = None, ): """ Instantiate a local clone of a git repo. If `clone_from` is set, the repo will be cloned from an existing remote repository. If the remote repo does not exist, a `EnvironmentError` exception will be thrown. Please create the remote repo first using [`create_repo`]. `Repository` uses the local git credentials by default. 
If explicitly set, the `token` or the `git_user`/`git_email` pair will be used instead. Args: local_dir (`str` or `Path`): path (e.g. `'my_trained_model/'`) to the local directory, where the `Repository` will be initialized. clone_from (`str`, *optional*): Either a repository url or `repo_id`. Example: - `"https://huggingface.co/philschmid/playground-tests"` - `"philschmid/playground-tests"` repo_type (`str`, *optional*): To set when cloning a repo from a repo_id. Default is model. token (`bool` or `str`, *optional*): A valid authentication token (see https://huggingface.co/settings/token). If `None` or `True` and machine is logged in (through `huggingface-cli login` or [`~huggingface_hub.login`]), token will be retrieved from the cache. If `False`, token is not sent in the request header. git_user (`str`, *optional*): will override the `git config user.name` for committing and pushing files to the hub. git_email (`str`, *optional*): will override the `git config user.email` for committing and pushing files to the hub. revision (`str`, *optional*): Revision to checkout after initializing the repository. If the revision doesn't exist, a branch will be created with that revision name from the default branch's current HEAD. skip_lfs_files (`bool`, *optional*, defaults to `False`): whether to skip git-LFS files or not. client (`HfApi`, *optional*): Instance of [`HfApi`] to use when calling the HF Hub API. A new instance will be created if this is left to `None`. Raises: [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) If the remote repository set in `clone_from` does not exist. """ if isinstance(local_dir, Path): local_dir = str(local_dir) os.makedirs(local_dir, exist_ok=True) self.local_dir = os.path.join(os.getcwd(), local_dir) self._repo_type = repo_type self.command_queue = [] self.skip_lfs_files = skip_lfs_files self.client = client if client is not None else HfApi() self.check_git_versions() if isinstance(token, str): self.huggingface_token: Optional[str] = token elif token is False: self.huggingface_token = None else: # if `True` -> explicit use of the cached token # if `None` -> implicit use of the cached token self.huggingface_token = get_token() if clone_from is not None: self.clone_from(repo_url=clone_from) else: if is_git_repo(self.local_dir): logger.debug("[Repository] is a valid git repo") else: raise ValueError("If not specifying `clone_from`, you need to pass Repository a valid git clone.") if self.huggingface_token is not None and (git_email is None or git_user is None): user = self.client.whoami(self.huggingface_token) if git_email is None: git_email = user.get("email") if git_user is None: git_user = user.get("fullname") if git_user is not None or git_email is not None: self.git_config_username_and_email(git_user, git_email) self.lfs_enable_largefiles() self.git_credential_helper_store() if revision is not None: self.git_checkout(revision, create_branch_ok=True) # This ensures that all commands exit before exiting the Python runtime. # This will ensure all pushes register on the hub, even if other errors happen in subsequent operations. atexit.register(self.wait_for_commands) @property def current_branch(self) -> str: """ Returns the current checked out branch. Returns: `str`: Current checked out branch. 
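        Example (illustrative; assumes `repo` is an initialized `Repository`):

        ```python
        >>> repo.current_branch
        'main'
        ```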
""" try: result = run_subprocess("git rev-parse --abbrev-ref HEAD", self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) return result def check_git_versions(self): """ Checks that `git` and `git-lfs` can be run. Raises: [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) If `git` or `git-lfs` are not installed. """ try: git_version = run_subprocess("git --version", self.local_dir).stdout.strip() except FileNotFoundError: raise EnvironmentError("Looks like you do not have git installed, please install.") try: lfs_version = run_subprocess("git-lfs --version", self.local_dir).stdout.strip() except FileNotFoundError: raise EnvironmentError( "Looks like you do not have git-lfs installed, please install." " You can install from https://git-lfs.github.com/." " Then run `git lfs install` (you only have to do this once)." ) logger.info(git_version + "\n" + lfs_version) @validate_hf_hub_args def clone_from(self, repo_url: str, token: Union[bool, str, None] = None): """ Clone from a remote. If the folder already exists, will try to clone the repository within it. If this folder is a git repository with linked history, will try to update the repository. Args: repo_url (`str`): The URL from which to clone the repository token (`Union[str, bool]`, *optional*): Whether to use the authentication token. It can be: - a string which is the token itself - `False`, which would not use the authentication token - `True`, which would fetch the authentication token from the local folder and use it (you should be logged in for this to work). - `None`, which would retrieve the value of `self.huggingface_token`. Raises the following error: - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if an organization token (starts with "api_org") is passed. Use must use your own personal access token (see https://hf.co/settings/tokens). - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) if you are trying to clone the repository in a non-empty folder, or if the `git` operations raise errors. """ token = ( token # str -> use it if isinstance(token, str) else ( None # `False` -> explicit no token if token is False else self.huggingface_token # `None` or `True` -> use default ) ) if token is not None and token.startswith("api_org"): raise ValueError( "You must use your personal access token, not an Organization token" " (see https://hf.co/settings/tokens)." ) hub_url = self.client.endpoint if hub_url in repo_url or ("http" not in repo_url and len(repo_url.split("/")) <= 2): repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(repo_url, hub_url=hub_url) repo_id = f"{namespace}/{repo_name}" if namespace is not None else repo_name if repo_type is not None: self._repo_type = repo_type repo_url = hub_url + "/" if self._repo_type in constants.REPO_TYPES_URL_PREFIXES: repo_url += constants.REPO_TYPES_URL_PREFIXES[self._repo_type] if token is not None: # Add token in git url when provided scheme = urlparse(repo_url).scheme repo_url = repo_url.replace(f"{scheme}://", f"{scheme}://user:{token}@") repo_url += repo_id # For error messages, it's cleaner to show the repo url without the token. 
clean_repo_url = re.sub(r"(https?)://.*@", r"\1://", repo_url) try: run_subprocess("git lfs install", self.local_dir) # checks if repository is initialized in a empty repository or in one with files if len(os.listdir(self.local_dir)) == 0: logger.warning(f"Cloning {clean_repo_url} into local empty directory.") with _lfs_log_progress(): env = os.environ.copy() if self.skip_lfs_files: env.update({"GIT_LFS_SKIP_SMUDGE": "1"}) run_subprocess( # 'git lfs clone' is deprecated (will display a warning in the terminal) # but we still use it as it provides a nicer UX when downloading large # files (shows progress). f"{'git clone' if self.skip_lfs_files else 'git lfs clone'} {repo_url} .", self.local_dir, env=env, ) else: # Check if the folder is the root of a git repository if not is_git_repo(self.local_dir): raise EnvironmentError( "Tried to clone a repository in a non-empty folder that isn't" f" a git repository ('{self.local_dir}'). If you really want to" f" do this, do it manually:\n cd {self.local_dir} && git init" " && git remote add origin && git pull origin main\n or clone" " repo to a new folder and move your existing files there" " afterwards." ) if is_local_clone(self.local_dir, repo_url): logger.warning( f"{self.local_dir} is already a clone of {clean_repo_url}." " Make sure you pull the latest changes with" " `repo.git_pull()`." ) else: output = run_subprocess("git remote get-url origin", self.local_dir, check=False) error_msg = ( f"Tried to clone {clean_repo_url} in an unrelated git" " repository.\nIf you believe this is an error, please add" f" a remote with the following URL: {clean_repo_url}." ) if output.returncode == 0: clean_local_remote_url = re.sub(r"https://.*@", "https://", output.stdout) error_msg += f"\nLocal path has its origin defined as: {clean_local_remote_url}" raise EnvironmentError(error_msg) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def git_config_username_and_email(self, git_user: Optional[str] = None, git_email: Optional[str] = None): """ Sets git username and email (only in the current repo). Args: git_user (`str`, *optional*): The username to register through `git`. git_email (`str`, *optional*): The email to register through `git`. """ try: if git_user is not None: run_subprocess("git config user.name".split() + [git_user], self.local_dir) if git_email is not None: run_subprocess(f"git config user.email {git_email}".split(), self.local_dir) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def git_credential_helper_store(self): """ Sets the git credential helper to `store` """ try: run_subprocess("git config credential.helper store", self.local_dir) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def git_head_hash(self) -> str: """ Get commit sha on top of HEAD. Returns: `str`: The current checked out commit SHA. """ try: p = run_subprocess("git rev-parse HEAD", self.local_dir) return p.stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def git_remote_url(self) -> str: """ Get URL to origin remote. Returns: `str`: The URL of the `origin` remote. """ try: p = run_subprocess("git config --get remote.origin.url", self.local_dir) url = p.stdout.strip() # Strip basic auth info. return re.sub(r"https://.*@", "https://", url) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def git_head_commit_url(self) -> str: """ Get URL to last commit on HEAD. 
We assume it's been pushed, and the url scheme is the same one as for GitHub or HuggingFace. Returns: `str`: The URL to the current checked-out commit. """ sha = self.git_head_hash() url = self.git_remote_url() if url.endswith("/"): url = url[:-1] return f"{url}/commit/{sha}" def list_deleted_files(self) -> List[str]: """ Returns a list of the files that are deleted in the working directory or index. Returns: `List[str]`: A list of files that have been deleted in the working directory or index. """ try: git_status = run_subprocess("git status -s", self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) if len(git_status) == 0: return [] # Receives a status like the following # D .gitignore # D new_file.json # AD new_file1.json # ?? new_file2.json # ?? new_file4.json # Strip each line of whitespaces modified_files_statuses = [status.strip() for status in git_status.split("\n")] # Only keep files that are deleted using the D prefix deleted_files_statuses = [status for status in modified_files_statuses if "D" in status.split()[0]] # Remove the D prefix and strip to keep only the relevant filename deleted_files = [status.split()[-1].strip() for status in deleted_files_statuses] return deleted_files def lfs_track(self, patterns: Union[str, List[str]], filename: bool = False): """ Tell git-lfs to track files according to a pattern. Setting the `filename` argument to `True` will treat the arguments as literal filenames, not as patterns. Any special glob characters in the filename will be escaped when writing to the `.gitattributes` file. Args: patterns (`Union[str, List[str]]`): The pattern, or list of patterns, to track with git-lfs. filename (`bool`, *optional*, defaults to `False`): Whether to use the patterns as literal filenames. """ if isinstance(patterns, str): patterns = [patterns] try: for pattern in patterns: run_subprocess( f"git lfs track {'--filename' if filename else ''} {pattern}", self.local_dir, ) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def lfs_untrack(self, patterns: Union[str, List[str]]): """ Tell git-lfs to untrack those files. Args: patterns (`Union[str, List[str]]`): The pattern, or list of patterns, to untrack with git-lfs. """ if isinstance(patterns, str): patterns = [patterns] try: for pattern in patterns: run_subprocess("git lfs untrack".split() + [pattern], self.local_dir) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def lfs_enable_largefiles(self): """ HF-specific. This enables upload support of files >5GB. """ try: lfs_config = "git config lfs.customtransfer.multipart" run_subprocess(f"{lfs_config}.path huggingface-cli", self.local_dir) run_subprocess( f"{lfs_config}.args {LFS_MULTIPART_UPLOAD_COMMAND}", self.local_dir, ) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def auto_track_binary_files(self, pattern: str = ".") -> List[str]: """ Automatically track binary files with git-lfs. Args: pattern (`str`, *optional*, defaults to "."): The pattern with which to track files that are binary. 
Returns: `List[str]`: List of filenames that are now tracked due to being binary files """ files_to_be_tracked_with_lfs = [] deleted_files = self.list_deleted_files() for filename in files_to_be_staged(pattern, folder=self.local_dir): if filename in deleted_files: continue path_to_file = os.path.join(os.getcwd(), self.local_dir, filename) if not (is_tracked_with_lfs(path_to_file) or is_git_ignored(path_to_file)): size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024) if size_in_mb >= 10: logger.warning( "Parsing a large file to check if binary or not. Tracking large" " files using `repository.auto_track_large_files` is" " recommended so as to not load the full file in memory." ) is_binary = is_binary_file(path_to_file) if is_binary: self.lfs_track(filename) files_to_be_tracked_with_lfs.append(filename) # Cleanup the .gitattributes if files were deleted self.lfs_untrack(deleted_files) return files_to_be_tracked_with_lfs def auto_track_large_files(self, pattern: str = ".") -> List[str]: """ Automatically track large files (files that weigh more than 10MBs) with git-lfs. Args: pattern (`str`, *optional*, defaults to "."): The pattern with which to track files that are above 10MBs. Returns: `List[str]`: List of filenames that are now tracked due to their size. """ files_to_be_tracked_with_lfs = [] deleted_files = self.list_deleted_files() for filename in files_to_be_staged(pattern, folder=self.local_dir): if filename in deleted_files: continue path_to_file = os.path.join(os.getcwd(), self.local_dir, filename) size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024) if size_in_mb >= 10 and not is_tracked_with_lfs(path_to_file) and not is_git_ignored(path_to_file): self.lfs_track(filename) files_to_be_tracked_with_lfs.append(filename) # Cleanup the .gitattributes if files were deleted self.lfs_untrack(deleted_files) return files_to_be_tracked_with_lfs def lfs_prune(self, recent=False): """ git lfs prune Args: recent (`bool`, *optional*, defaults to `False`): Whether to prune files even if they were referenced by recent commits. See the following [link](https://github.com/git-lfs/git-lfs/blob/f3d43f0428a84fc4f1e5405b76b5a73ec2437e65/docs/man/git-lfs-prune.1.ronn#recent-files) for more information. """ try: with _lfs_log_progress(): result = run_subprocess(f"git lfs prune {'--recent' if recent else ''}", self.local_dir) logger.info(result.stdout) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def git_pull(self, rebase: bool = False, lfs: bool = False): """ git pull Args: rebase (`bool`, *optional*, defaults to `False`): Whether to rebase the current branch on top of the upstream branch after fetching. lfs (`bool`, *optional*, defaults to `False`): Whether to fetch the LFS files too. This option only changes the behavior when a repository was cloned without fetching the LFS files; calling `repo.git_pull(lfs=True)` will then fetch the LFS file from the remote repository. """ command = "git pull" if not lfs else "git lfs pull" if rebase: command += " --rebase" try: with _lfs_log_progress(): result = run_subprocess(command, self.local_dir) logger.info(result.stdout) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def git_add(self, pattern: str = ".", auto_lfs_track: bool = False): """ git add Setting the `auto_lfs_track` parameter to `True` will automatically track files that are larger than 10MB with `git-lfs`. Args: pattern (`str`, *optional*, defaults to "."): The pattern with which to add files to staging. 
auto_lfs_track (`bool`, *optional*, defaults to `False`): Whether to automatically track large and binary files with git-lfs. Any file over 10MB in size, or in binary format, will be automatically tracked. """ if auto_lfs_track: # Track files according to their size (>=10MB) tracked_files = self.auto_track_large_files(pattern) # Read the remaining files and track them if they're binary tracked_files.extend(self.auto_track_binary_files(pattern)) if tracked_files: logger.warning( f"Adding files tracked by Git LFS: {tracked_files}. This may take a" " bit of time if the files are large." ) try: result = run_subprocess("git add -v".split() + [pattern], self.local_dir) logger.info(f"Adding to index:\n{result.stdout}\n") except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def git_commit(self, commit_message: str = "commit files to HF hub"): """ git commit Args: commit_message (`str`, *optional*, defaults to "commit files to HF hub"): The message attributed to the commit. """ try: result = run_subprocess("git commit -v -m".split() + [commit_message], self.local_dir) logger.info(f"Committed:\n{result.stdout}\n") except subprocess.CalledProcessError as exc: if len(exc.stderr) > 0: raise EnvironmentError(exc.stderr) else: raise EnvironmentError(exc.stdout) def git_push( self, upstream: Optional[str] = None, blocking: bool = True, auto_lfs_prune: bool = False, ) -> Union[str, Tuple[str, CommandInProgress]]: """ git push If used without setting `blocking`, will return url to commit on remote repo. If used with `blocking=True`, will return a tuple containing the url to commit and the command object to follow for information about the process. Args: upstream (`str`, *optional*): Upstream to which this should push. If not specified, will push to the lastly defined upstream or to the default one (`origin main`). blocking (`bool`, *optional*, defaults to `True`): Whether the function should return only when the push has finished. Setting this to `False` will return an `CommandInProgress` object which has an `is_done` property. This property will be set to `True` when the push is finished. auto_lfs_prune (`bool`, *optional*, defaults to `False`): Whether to automatically prune files once they have been pushed to the remote. 
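        Example (illustrative; assumes `repo` is an initialized `Repository` with pending commits):

        ```python
        # Blocking push (default): returns the URL of the commit on the remote.
        >>> url = repo.git_push()

        # Non-blocking push: returns the commit URL and a `CommandInProgress` object
        # whose `is_done` / `status` / `failed` properties can be polled.
        >>> url, push_command = repo.git_push(blocking=False)
        ```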
""" command = "git push" if upstream: command += f" --set-upstream {upstream}" number_of_commits = commits_to_push(self.local_dir, upstream) if number_of_commits > 1: logger.warning(f"Several commits ({number_of_commits}) will be pushed upstream.") if blocking: logger.warning("The progress bars may be unreliable.") try: with _lfs_log_progress(): process = subprocess.Popen( command.split(), stderr=subprocess.PIPE, stdout=subprocess.PIPE, encoding="utf-8", cwd=self.local_dir, ) if blocking: stdout, stderr = process.communicate() return_code = process.poll() process.kill() if len(stderr): logger.warning(stderr) if return_code: raise subprocess.CalledProcessError(return_code, process.args, output=stdout, stderr=stderr) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) if not blocking: def status_method(): status = process.poll() if status is None: return -1 else: return status command_in_progress = CommandInProgress( "push", is_done_method=lambda: process.poll() is not None, status_method=status_method, process=process, post_method=self.lfs_prune if auto_lfs_prune else None, ) self.command_queue.append(command_in_progress) return self.git_head_commit_url(), command_in_progress if auto_lfs_prune: self.lfs_prune() return self.git_head_commit_url() def git_checkout(self, revision: str, create_branch_ok: bool = False): """ git checkout a given revision Specifying `create_branch_ok` to `True` will create the branch to the given revision if that revision doesn't exist. Args: revision (`str`): The revision to checkout. create_branch_ok (`str`, *optional*, defaults to `False`): Whether creating a branch named with the `revision` passed at the current checked-out reference if `revision` isn't an existing revision is allowed. """ try: result = run_subprocess(f"git checkout {revision}", self.local_dir) logger.warning(f"Checked out {revision} from {self.current_branch}.") logger.warning(result.stdout) except subprocess.CalledProcessError as exc: if not create_branch_ok: raise EnvironmentError(exc.stderr) else: try: result = run_subprocess(f"git checkout -b {revision}", self.local_dir) logger.warning( f"Revision `{revision}` does not exist. Created and checked out branch `{revision}`." ) logger.warning(result.stdout) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def tag_exists(self, tag_name: str, remote: Optional[str] = None) -> bool: """ Check if a tag exists or not. Args: tag_name (`str`): The name of the tag to check. remote (`str`, *optional*): Whether to check if the tag exists on a remote. This parameter should be the identifier of the remote. Returns: `bool`: Whether the tag exists. """ if remote: try: result = run_subprocess(f"git ls-remote origin refs/tags/{tag_name}", self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) return len(result) != 0 else: try: git_tags = run_subprocess("git tag", self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) git_tags = git_tags.split("\n") return tag_name in git_tags def delete_tag(self, tag_name: str, remote: Optional[str] = None) -> bool: """ Delete a tag, both local and remote, if it exists Args: tag_name (`str`): The tag name to delete. remote (`str`, *optional*): The remote on which to delete the tag. Returns: `bool`: `True` if deleted, `False` if the tag didn't exist. 
If remote is not passed, will just be updated locally """ delete_locally = True delete_remotely = True if not self.tag_exists(tag_name): delete_locally = False if not self.tag_exists(tag_name, remote=remote): delete_remotely = False if delete_locally: try: run_subprocess(["git", "tag", "-d", tag_name], self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) if remote and delete_remotely: try: run_subprocess(f"git push {remote} --delete {tag_name}", self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) return True def add_tag(self, tag_name: str, message: Optional[str] = None, remote: Optional[str] = None): """ Add a tag at the current head and push it If remote is None, will just be updated locally If no message is provided, the tag will be lightweight. if a message is provided, the tag will be annotated. Args: tag_name (`str`): The name of the tag to be added. message (`str`, *optional*): The message that accompanies the tag. The tag will turn into an annotated tag if a message is passed. remote (`str`, *optional*): The remote on which to add the tag. """ if message: tag_args = ["git", "tag", "-a", tag_name, "-m", message] else: tag_args = ["git", "tag", tag_name] try: run_subprocess(tag_args, self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) if remote: try: run_subprocess(f"git push {remote} {tag_name}", self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def is_repo_clean(self) -> bool: """ Return whether or not the git status is clean or not Returns: `bool`: `True` if the git status is clean, `False` otherwise. """ try: git_status = run_subprocess("git status --porcelain", self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) return len(git_status) == 0 def push_to_hub( self, commit_message: str = "commit files to HF hub", blocking: bool = True, clean_ok: bool = True, auto_lfs_prune: bool = False, ) -> Union[None, str, Tuple[str, CommandInProgress]]: """ Helper to add, commit, and push files to remote repository on the HuggingFace Hub. Will automatically track large files (>10MB). Args: commit_message (`str`): Message to use for the commit. blocking (`bool`, *optional*, defaults to `True`): Whether the function should return only when the `git push` has finished. clean_ok (`bool`, *optional*, defaults to `True`): If True, this function will return None if the repo is untouched. Default behavior is to fail because the git command fails. auto_lfs_prune (`bool`, *optional*, defaults to `False`): Whether to automatically prune files once they have been pushed to the remote. """ if clean_ok and self.is_repo_clean(): logger.info("Repo currently clean. Ignoring push_to_hub") return None self.git_add(auto_lfs_track=True) self.git_commit(commit_message) return self.git_push( upstream=f"origin {self.current_branch}", blocking=blocking, auto_lfs_prune=auto_lfs_prune, ) @contextmanager def commit( self, commit_message: str, branch: Optional[str] = None, track_large_files: bool = True, blocking: bool = True, auto_lfs_prune: bool = False, ): """ Context manager utility to handle committing to a repository. This automatically tracks large files (>10Mb) with git-lfs. Set the `track_large_files` argument to `False` if you wish to ignore that behavior. Args: commit_message (`str`): Message to use for the commit. 
branch (`str`, *optional*): The branch on which the commit will appear. This branch will be checked-out before any operation. track_large_files (`bool`, *optional*, defaults to `True`): Whether to automatically track large files or not. Will do so by default. blocking (`bool`, *optional*, defaults to `True`): Whether the function should return only when the `git push` has finished. auto_lfs_prune (`bool`, defaults to `True`): Whether to automatically prune files once they have been pushed to the remote. Examples: ```python >>> with Repository( ... "text-files", ... clone_from="/text-files", ... token=True, >>> ).commit("My first file :)"): ... with open("file.txt", "w+") as f: ... f.write(json.dumps({"hey": 8})) >>> import torch >>> model = torch.nn.Transformer() >>> with Repository( ... "torch-model", ... clone_from="/torch-model", ... token=True, >>> ).commit("My cool model :)"): ... torch.save(model.state_dict(), "model.pt") ``` """ files_to_stage = files_to_be_staged(".", folder=self.local_dir) if len(files_to_stage): files_in_msg = str(files_to_stage[:5])[:-1] + ", ...]" if len(files_to_stage) > 5 else str(files_to_stage) logger.error( "There exists some updated files in the local repository that are not" f" committed: {files_in_msg}. This may lead to errors if checking out" " a branch. These files and their modifications will be added to the" " current commit." ) if branch is not None: self.git_checkout(branch, create_branch_ok=True) if is_tracked_upstream(self.local_dir): logger.warning("Pulling changes ...") self.git_pull(rebase=True) else: logger.warning(f"The current branch has no upstream branch. Will push to 'origin {self.current_branch}'") current_working_directory = os.getcwd() os.chdir(os.path.join(current_working_directory, self.local_dir)) try: yield self finally: self.git_add(auto_lfs_track=track_large_files) try: self.git_commit(commit_message) except OSError as e: # If no changes are detected, there is nothing to commit. if "nothing to commit" not in str(e): raise e try: self.git_push( upstream=f"origin {self.current_branch}", blocking=blocking, auto_lfs_prune=auto_lfs_prune, ) except OSError as e: # If no changes are detected, there is nothing to commit. if "could not read Username" in str(e): raise OSError("Couldn't authenticate user for push. Did you set `token` to `True`?") from e else: raise e os.chdir(current_working_directory) def repocard_metadata_load(self) -> Optional[Dict]: filepath = os.path.join(self.local_dir, constants.REPOCARD_NAME) if os.path.isfile(filepath): return metadata_load(filepath) return None def repocard_metadata_save(self, data: Dict) -> None: return metadata_save(os.path.join(self.local_dir, constants.REPOCARD_NAME), data) @property def commands_failed(self): """ Returns the asynchronous commands that failed. """ return [c for c in self.command_queue if c.status > 0] @property def commands_in_progress(self): """ Returns the asynchronous commands that are currently in progress. """ return [c for c in self.command_queue if not c.is_done] def wait_for_commands(self): """ Blocking method: blocks all subsequent execution until all commands have been processed. """ index = 0 for command_failed in self.commands_failed: logger.error(f"The {command_failed.title} command with PID {command_failed._process.pid} failed.") logger.error(command_failed.stderr) while self.commands_in_progress: if index % 10 == 0: logger.warning( f"Waiting for the following commands to finish before shutting down: {self.commands_in_progress}." 
) index += 1 time.sleep(1) huggingface_hub-0.31.1/src/huggingface_hub/serialization/000077500000000000000000000000001500667546600234715ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/serialization/__init__.py000066400000000000000000000020211500667546600255750ustar00rootroot00000000000000# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ruff: noqa: F401 """Contains helpers to serialize tensors.""" from ._base import StateDictSplit, split_state_dict_into_shards_factory from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards from ._torch import ( get_torch_storage_id, get_torch_storage_size, load_state_dict_from_file, load_torch_model, save_torch_model, save_torch_state_dict, split_torch_state_dict_into_shards, ) huggingface_hub-0.31.1/src/huggingface_hub/serialization/_base.py000066400000000000000000000176761500667546600251350ustar00rootroot00000000000000# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains helpers to split tensors into shards.""" from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, TypeVar, Union from .. import logging TensorT = TypeVar("TensorT") TensorSizeFn_T = Callable[[TensorT], int] StorageIDFn_T = Callable[[TensorT], Optional[Any]] MAX_SHARD_SIZE = "5GB" SIZE_UNITS = { "TB": 10**12, "GB": 10**9, "MB": 10**6, "KB": 10**3, } logger = logging.get_logger(__file__) @dataclass class StateDictSplit: is_sharded: bool = field(init=False) metadata: Dict[str, Any] filename_to_tensors: Dict[str, List[str]] tensor_to_filename: Dict[str, str] def __post_init__(self): self.is_sharded = len(self.filename_to_tensors) > 1 def split_state_dict_into_shards_factory( state_dict: Dict[str, TensorT], *, get_storage_size: TensorSizeFn_T, filename_pattern: str, get_storage_id: StorageIDFn_T = lambda tensor: None, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, ) -> StateDictSplit: """ Split a model state dictionary in shards so that each shard is smaller than a given size. The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. 
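    Walking through that example with a 10GB limit (illustrative, derived from the rule above): the first 6GB tensor opens shard 1; the second 6GB tensor does not fit next to it, so shard 1 is closed holding only [6GB] and shard 2 starts; the following 2GB tensor still fits in shard 2 ([6+2GB]); the next 6GB tensor overflows shard 2, which is closed, and that tensor starts shard 3, where the two remaining 2GB tensors join it ([6+2+2GB]).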
If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a size greater than `max_shard_size`. Args: state_dict (`Dict[str, Tensor]`): The state dictionary to save. get_storage_size (`Callable[[Tensor], int]`): A function that returns the size of a tensor when saved on disk in bytes. get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*): A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id. filename_pattern (`str`, *optional*): The pattern to generate the files names in which the model will be saved. Pattern must be a string that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` max_shard_size (`int` or `str`, *optional*): The maximum size of each shard, in bytes. Defaults to 5GB. Returns: [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. """ storage_id_to_tensors: Dict[Any, List[str]] = {} shard_list: List[Dict[str, TensorT]] = [] current_shard: Dict[str, TensorT] = {} current_shard_size = 0 total_size = 0 if isinstance(max_shard_size, str): max_shard_size = parse_size_to_int(max_shard_size) for key, tensor in state_dict.items(): # when bnb serialization is used the weights in the state dict can be strings # check: https://github.com/huggingface/transformers/pull/24416 for more details if isinstance(tensor, str): logger.info("Skipping tensor %s as it is a string (bnb serialization)", key) continue # If a `tensor` shares the same underlying storage as another tensor, we put `tensor` in the same `block` storage_id = get_storage_id(tensor) if storage_id is not None: if storage_id in storage_id_to_tensors: # We skip this tensor for now and will reassign to correct shard later storage_id_to_tensors[storage_id].append(key) continue else: # This is the first tensor with this storage_id, we create a new entry # in the storage_id_to_tensors dict => we will assign the shard id later storage_id_to_tensors[storage_id] = [key] # Compute tensor size tensor_size = get_storage_size(tensor) # If this tensor is bigger than the maximal size, we put it in its own shard if tensor_size > max_shard_size: total_size += tensor_size shard_list.append({key: tensor}) continue # If this tensor is going to tip up over the maximal size, we split. # Current shard already has some tensors, we add it to the list of shards and create a new one. 
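        # Note (added for clarity): the tensor that triggers the overflow is not added to the full shard;
        # after the flush below it becomes the first entry of the freshly started shard.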
if current_shard_size + tensor_size > max_shard_size: shard_list.append(current_shard) current_shard = {} current_shard_size = 0 # Add the tensor to the current shard current_shard[key] = tensor current_shard_size += tensor_size total_size += tensor_size # Add the last shard if len(current_shard) > 0: shard_list.append(current_shard) nb_shards = len(shard_list) # Loop over the tensors that share the same storage and assign them together for storage_id, keys in storage_id_to_tensors.items(): # Let's try to find the shard where the first tensor of this storage is and put all tensors in the same shard for shard in shard_list: if keys[0] in shard: for key in keys: shard[key] = state_dict[key] break # If we only have one shard, we return it => no need to build the index if nb_shards == 1: filename = filename_pattern.format(suffix="") return StateDictSplit( metadata={"total_size": total_size}, filename_to_tensors={filename: list(state_dict.keys())}, tensor_to_filename={key: filename for key in state_dict.keys()}, ) # Now that each tensor is assigned to a shard, let's assign a filename to each shard tensor_name_to_filename = {} filename_to_tensors = {} for idx, shard in enumerate(shard_list): filename = filename_pattern.format(suffix=f"-{idx + 1:05d}-of-{nb_shards:05d}") for key in shard: tensor_name_to_filename[key] = filename filename_to_tensors[filename] = list(shard.keys()) # Build the index and return return StateDictSplit( metadata={"total_size": total_size}, filename_to_tensors=filename_to_tensors, tensor_to_filename=tensor_name_to_filename, ) def parse_size_to_int(size_as_str: str) -> int: """ Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes). Supported units are "TB", "GB", "MB", "KB". Args: size_as_str (`str`): The size to convert. Will be directly returned if an `int`. Example: ```py >>> parse_size_to_int("5MB") 5000000 ``` """ size_as_str = size_as_str.strip() # Parse unit unit = size_as_str[-2:].upper() if unit not in SIZE_UNITS: raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.") multiplier = SIZE_UNITS[unit] # Parse value try: value = float(size_as_str[:-2].strip()) except ValueError as e: raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e return int(value * multiplier) huggingface_hub-0.31.1/src/huggingface_hub/serialization/_dduf.py000066400000000000000000000361001500667546600251240ustar00rootroot00000000000000import json import logging import mmap import os import shutil import zipfile from contextlib import contextmanager from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Generator, Iterable, Tuple, Union from ..errors import DDUFCorruptedFileError, DDUFExportError, DDUFInvalidEntryNameError logger = logging.getLogger(__name__) DDUF_ALLOWED_ENTRIES = { # Allowed file extensions in a DDUF file ".json", ".model", ".safetensors", ".txt", } DDUF_FOLDER_REQUIRED_ENTRIES = { # Each folder must contain at least one of these entries "config.json", "tokenizer_config.json", "preprocessor_config.json", "scheduler_config.json", } @dataclass class DDUFEntry: """Object representing a file entry in a DDUF file. See [`read_dduf_file`] for how to read a DDUF file. Attributes: filename (str): The name of the file in the DDUF archive. offset (int): The offset of the file in the DDUF archive. length (int): The length of the file in the DDUF archive. dduf_path (str): The path to the DDUF archive (for internal use). 
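    Example (illustrative; assumes a DDUF archive such as `"FLUX.1-dev.dduf"` is available locally, with the
    same `model_index.json` entry as in the [`read_dduf_file`] example):

    ```py
    >>> from huggingface_hub import read_dduf_file
    >>> entries = read_dduf_file("FLUX.1-dev.dduf")
    >>> entry = entries["model_index.json"]
    >>> entry.filename, entry.offset, entry.length
    ('model_index.json', 66, 587)
    ```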
""" filename: str length: int offset: int dduf_path: Path = field(repr=False) @contextmanager def as_mmap(self) -> Generator[bytes, None, None]: """Open the file as a memory-mapped file. Useful to load safetensors directly from the file. Example: ```py >>> import safetensors.torch >>> with entry.as_mmap() as mm: ... tensors = safetensors.torch.load(mm) ``` """ with self.dduf_path.open("rb") as f: with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mm: yield mm[self.offset : self.offset + self.length] def read_text(self, encoding: str = "utf-8") -> str: """Read the file as text. Useful for '.txt' and '.json' entries. Example: ```py >>> import json >>> index = json.loads(entry.read_text()) ``` """ with self.dduf_path.open("rb") as f: f.seek(self.offset) return f.read(self.length).decode(encoding=encoding) def read_dduf_file(dduf_path: Union[os.PathLike, str]) -> Dict[str, DDUFEntry]: """ Read a DDUF file and return a dictionary of entries. Only the metadata is read, the data is not loaded in memory. Args: dduf_path (`str` or `os.PathLike`): The path to the DDUF file to read. Returns: `Dict[str, DDUFEntry]`: A dictionary of [`DDUFEntry`] indexed by filename. Raises: - [`DDUFCorruptedFileError`]: If the DDUF file is corrupted (i.e. doesn't follow the DDUF format). Example: ```python >>> import json >>> import safetensors.torch >>> from huggingface_hub import read_dduf_file # Read DDUF metadata >>> dduf_entries = read_dduf_file("FLUX.1-dev.dduf") # Returns a mapping filename <> DDUFEntry >>> dduf_entries["model_index.json"] DDUFEntry(filename='model_index.json', offset=66, length=587) # Load model index as JSON >>> json.loads(dduf_entries["model_index.json"].read_text()) {'_class_name': 'FluxPipeline', '_diffusers_version': '0.32.0.dev0', '_name_or_path': 'black-forest-labs/FLUX.1-dev', ... # Load VAE weights using safetensors >>> with dduf_entries["vae/diffusion_pytorch_model.safetensors"].as_mmap() as mm: ... state_dict = safetensors.torch.load(mm) ``` """ entries = {} dduf_path = Path(dduf_path) logger.info(f"Reading DDUF file {dduf_path}") with zipfile.ZipFile(str(dduf_path), "r") as zf: for info in zf.infolist(): logger.debug(f"Reading entry {info.filename}") if info.compress_type != zipfile.ZIP_STORED: raise DDUFCorruptedFileError("Data must not be compressed in DDUF file.") try: _validate_dduf_entry_name(info.filename) except DDUFInvalidEntryNameError as e: raise DDUFCorruptedFileError(f"Invalid entry name in DDUF file: {info.filename}") from e offset = _get_data_offset(zf, info) entries[info.filename] = DDUFEntry( filename=info.filename, offset=offset, length=info.file_size, dduf_path=dduf_path ) # Consistency checks on the DDUF file if "model_index.json" not in entries: raise DDUFCorruptedFileError("Missing required 'model_index.json' entry in DDUF file.") index = json.loads(entries["model_index.json"].read_text()) _validate_dduf_structure(index, entries.keys()) logger.info(f"Done reading DDUF file {dduf_path}. Found {len(entries)} entries") return entries def export_entries_as_dduf( dduf_path: Union[str, os.PathLike], entries: Iterable[Tuple[str, Union[str, Path, bytes]]] ) -> None: """Write a DDUF file from an iterable of entries. This is a lower-level helper than [`export_folder_as_dduf`] that allows more flexibility when serializing data. In particular, you don't need to save the data on disk before exporting it in the DDUF file. Args: dduf_path (`str` or `os.PathLike`): The path to the DDUF file to write. 
entries (`Iterable[Tuple[str, Union[str, Path, bytes]]]`): An iterable of entries to write in the DDUF file. Each entry is a tuple with the filename and the content. The filename should be the path to the file in the DDUF archive. The content can be a string or a pathlib.Path representing a path to a file on the local disk or directly the content as bytes. Raises: - [`DDUFExportError`]: If anything goes wrong during the export (e.g. invalid entry name, missing 'model_index.json', etc.). Example: ```python # Export specific files from the local disk. >>> from huggingface_hub import export_entries_as_dduf >>> export_entries_as_dduf( ... dduf_path="stable-diffusion-v1-4-FP16.dduf", ... entries=[ # List entries to add to the DDUF file (here, only FP16 weights) ... ("model_index.json", "path/to/model_index.json"), ... ("vae/config.json", "path/to/vae/config.json"), ... ("vae/diffusion_pytorch_model.fp16.safetensors", "path/to/vae/diffusion_pytorch_model.fp16.safetensors"), ... ("text_encoder/config.json", "path/to/text_encoder/config.json"), ... ("text_encoder/model.fp16.safetensors", "path/to/text_encoder/model.fp16.safetensors"), ... # ... add more entries here ... ] ... ) ``` ```python # Export state_dicts one by one from a loaded pipeline >>> from diffusers import DiffusionPipeline >>> from typing import Generator, Tuple >>> import safetensors.torch >>> from huggingface_hub import export_entries_as_dduf >>> pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") ... # ... do some work with the pipeline >>> def as_entries(pipe: DiffusionPipeline) -> Generator[Tuple[str, bytes], None, None]: ... # Build an generator that yields the entries to add to the DDUF file. ... # The first element of the tuple is the filename in the DDUF archive (must use UNIX separator!). The second element is the content of the file. ... # Entries will be evaluated lazily when the DDUF file is created (only 1 entry is loaded in memory at a time) ... yield "vae/config.json", pipe.vae.to_json_string().encode() ... yield "vae/diffusion_pytorch_model.safetensors", safetensors.torch.save(pipe.vae.state_dict()) ... yield "text_encoder/config.json", pipe.text_encoder.config.to_json_string().encode() ... yield "text_encoder/model.safetensors", safetensors.torch.save(pipe.text_encoder.state_dict()) ... # ... 
add more entries here >>> export_entries_as_dduf(dduf_path="stable-diffusion-v1-4.dduf", entries=as_entries(pipe)) ``` """ logger.info(f"Exporting DDUF file '{dduf_path}'") filenames = set() index = None with zipfile.ZipFile(str(dduf_path), "w", zipfile.ZIP_STORED) as archive: for filename, content in entries: if filename in filenames: raise DDUFExportError(f"Can't add duplicate entry: {filename}") filenames.add(filename) if filename == "model_index.json": try: index = json.loads(_load_content(content).decode()) except json.JSONDecodeError as e: raise DDUFExportError("Failed to parse 'model_index.json'.") from e try: filename = _validate_dduf_entry_name(filename) except DDUFInvalidEntryNameError as e: raise DDUFExportError(f"Invalid entry name: {filename}") from e logger.debug(f"Adding entry '{filename}' to DDUF file") _dump_content_in_archive(archive, filename, content) # Consistency checks on the DDUF file if index is None: raise DDUFExportError("Missing required 'model_index.json' entry in DDUF file.") try: _validate_dduf_structure(index, filenames) except DDUFCorruptedFileError as e: raise DDUFExportError("Invalid DDUF file structure.") from e logger.info(f"Done writing DDUF file {dduf_path}") def export_folder_as_dduf(dduf_path: Union[str, os.PathLike], folder_path: Union[str, os.PathLike]) -> None: """ Export a folder as a DDUF file. AUses [`export_entries_as_dduf`] under the hood. Args: dduf_path (`str` or `os.PathLike`): The path to the DDUF file to write. folder_path (`str` or `os.PathLike`): The path to the folder containing the diffusion model. Example: ```python >>> from huggingface_hub import export_folder_as_dduf >>> export_folder_as_dduf(dduf_path="FLUX.1-dev.dduf", folder_path="path/to/FLUX.1-dev") ``` """ folder_path = Path(folder_path) def _iterate_over_folder() -> Iterable[Tuple[str, Path]]: for path in Path(folder_path).glob("**/*"): if not path.is_file(): continue if path.suffix not in DDUF_ALLOWED_ENTRIES: logger.debug(f"Skipping file '{path}' (file type not allowed)") continue path_in_archive = path.relative_to(folder_path) if len(path_in_archive.parts) >= 3: logger.debug(f"Skipping file '{path}' (nested directories not allowed)") continue yield path_in_archive.as_posix(), path export_entries_as_dduf(dduf_path, _iterate_over_folder()) def _dump_content_in_archive(archive: zipfile.ZipFile, filename: str, content: Union[str, os.PathLike, bytes]) -> None: with archive.open(filename, "w", force_zip64=True) as archive_fh: if isinstance(content, (str, Path)): content_path = Path(content) with content_path.open("rb") as content_fh: shutil.copyfileobj(content_fh, archive_fh, 1024 * 1024 * 8) # type: ignore[misc] elif isinstance(content, bytes): archive_fh.write(content) else: raise DDUFExportError(f"Invalid content type for {filename}. Must be str, Path or bytes.") def _load_content(content: Union[str, Path, bytes]) -> bytes: """Load the content of an entry as bytes. Used only for small checks (not to dump content into archive). """ if isinstance(content, (str, Path)): return Path(content).read_bytes() elif isinstance(content, bytes): return content else: raise DDUFExportError(f"Invalid content type. Must be str, Path or bytes. Got {type(content)}.") def _validate_dduf_entry_name(entry_name: str) -> str: if "." + entry_name.split(".")[-1] not in DDUF_ALLOWED_ENTRIES: raise DDUFInvalidEntryNameError(f"File type not allowed: {entry_name}") if "\\" in entry_name: raise DDUFInvalidEntryNameError(f"Entry names must use UNIX separators ('/'). 
Got {entry_name}.") entry_name = entry_name.strip("/") if entry_name.count("/") > 1: raise DDUFInvalidEntryNameError(f"DDUF only supports 1 level of directory. Got {entry_name}.") return entry_name def _validate_dduf_structure(index: Any, entry_names: Iterable[str]) -> None: """ Consistency checks on the DDUF file structure. Rules: - The 'model_index.json' entry is required and must contain a dictionary. - Each folder name must correspond to an entry in 'model_index.json'. - Each folder must contain at least a config file ('config.json', 'tokenizer_config.json', 'preprocessor_config.json', 'scheduler_config.json'). Args: index (Any): The content of the 'model_index.json' entry. entry_names (Iterable[str]): The list of entry names in the DDUF file. Raises: - [`DDUFCorruptedFileError`]: If the DDUF file is corrupted (i.e. doesn't follow the DDUF format). """ if not isinstance(index, dict): raise DDUFCorruptedFileError(f"Invalid 'model_index.json' content. Must be a dictionary. Got {type(index)}.") dduf_folders = {entry.split("/")[0] for entry in entry_names if "/" in entry} for folder in dduf_folders: if folder not in index: raise DDUFCorruptedFileError(f"Missing required entry '{folder}' in 'model_index.json'.") if not any(f"{folder}/{required_entry}" in entry_names for required_entry in DDUF_FOLDER_REQUIRED_ENTRIES): raise DDUFCorruptedFileError( f"Missing required file in folder '{folder}'. Must contains at least one of {DDUF_FOLDER_REQUIRED_ENTRIES}." ) def _get_data_offset(zf: zipfile.ZipFile, info: zipfile.ZipInfo) -> int: """ Calculate the data offset for a file in a ZIP archive. Args: zf (`zipfile.ZipFile`): The opened ZIP file. Must be opened in read mode. info (`zipfile.ZipInfo`): The file info. Returns: int: The offset of the file data in the ZIP archive. """ if zf.fp is None: raise DDUFCorruptedFileError("ZipFile object must be opened in read mode.") # Step 1: Get the local file header offset header_offset = info.header_offset # Step 2: Read the local file header zf.fp.seek(header_offset) local_file_header = zf.fp.read(30) # Fixed-size part of the local header if len(local_file_header) < 30: raise DDUFCorruptedFileError("Incomplete local file header.") # Step 3: Parse the header fields to calculate the start of file data # Local file header: https://en.wikipedia.org/wiki/ZIP_(file_format)#File_headers filename_len = int.from_bytes(local_file_header[26:28], "little") extra_field_len = int.from_bytes(local_file_header[28:30], "little") # Data offset is after the fixed header, filename, and extra fields data_offset = header_offset + 30 + filename_len + extra_field_len return data_offset huggingface_hub-0.31.1/src/huggingface_hub/serialization/_tensorflow.py000066400000000000000000000070511500667546600264070ustar00rootroot00000000000000# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains tensorflow-specific helpers.""" import math import re from typing import TYPE_CHECKING, Dict, Union from .. 
import constants from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory if TYPE_CHECKING: import tensorflow as tf def split_tf_state_dict_into_shards( state_dict: Dict[str, "tf.Tensor"], *, filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, ) -> StateDictSplit: """ Split a model state dictionary in shards so that each shard is smaller than a given size. The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a size greater than `max_shard_size`. Args: state_dict (`Dict[str, Tensor]`): The state dictionary to save. filename_pattern (`str`, *optional*): The pattern to generate the files names in which the model will be saved. Pattern must be a string that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` Defaults to `"tf_model{suffix}.h5"`. max_shard_size (`int` or `str`, *optional*): The maximum size of each shard, in bytes. Defaults to 5GB. Returns: [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. """ return split_state_dict_into_shards_factory( state_dict, max_shard_size=max_shard_size, filename_pattern=filename_pattern, get_storage_size=get_tf_storage_size, ) def get_tf_storage_size(tensor: "tf.Tensor") -> int: # Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool). # Better to overestimate than underestimate. return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype)) def _dtype_byte_size_tf(dtype) -> float: """ Returns the size (in bytes) occupied by one parameter of type `dtype`. Taken from https://github.com/huggingface/transformers/blob/74d9d0cebb0263a3f8ab9c280569170cc74651d0/src/transformers/modeling_tf_utils.py#L608. NOTE: why not `tensor.numpy().nbytes`? Example: ```py >>> _dtype_byte_size(tf.float32) 4 ``` """ import tensorflow as tf if dtype == tf.bool: return 1 / 8 bit_search = re.search(r"[^\d](\d+)$", dtype.name) if bit_search is None: raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") bit_size = int(bit_search.groups()[0]) return bit_size // 8 huggingface_hub-0.31.1/src/huggingface_hub/serialization/_torch.py000066400000000000000000001272521500667546600253320ustar00rootroot00000000000000# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Contains pytorch-specific helpers.""" import importlib import json import os import re from collections import defaultdict, namedtuple from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union from packaging import version from .. import constants, logging from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory logger = logging.get_logger(__file__) if TYPE_CHECKING: import torch # SAVING def save_torch_model( model: "torch.nn.Module", save_directory: Union[str, Path], *, filename_pattern: Optional[str] = None, force_contiguous: bool = True, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, metadata: Optional[Dict[str, str]] = None, safe_serialization: bool = True, is_main_process: bool = True, shared_tensors_to_discard: Optional[List[str]] = None, ): """ Saves a given torch model to disk, handling sharding and shared tensors issues. See also [`save_torch_state_dict`] to save a state dict with more flexibility. For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors). The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard, an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as safetensors (the default). Otherwise, the shards are saved as pickle. Before saving the model, the `save_directory` is cleaned from any previous shard files. If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a size greater than `max_shard_size`. If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving. Args: model (`torch.nn.Module`): The model to save on disk. save_directory (`str` or `Path`): The directory in which the model will be saved. filename_pattern (`str`, *optional*): The pattern to generate the files names in which the model will be saved. Pattern must be a string that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization` parameter. force_contiguous (`boolean`, *optional*): Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the model, but it could potentially change performance if the layout of the tensor was chosen specifically for that reason. Defaults to `True`. max_shard_size (`int` or `str`, *optional*): The maximum size of each shard, in bytes. Defaults to 5GB. metadata (`Dict[str, str]`, *optional*): Extra information to save along with the model. Some metadata will be added for each dropped tensors. This information will not be enough to recover the entire shared structure but might help understanding things. safe_serialization (`bool`, *optional*): Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle. Safe serialization is recommended for security reasons. 
Saving as pickle is deprecated and will be removed in a future version. is_main_process (`bool`, *optional*): Whether the process calling this is the main process or not. Useful when in distributed training like TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. Defaults to True. shared_tensors_to_discard (`List[str]`, *optional*): List of tensor names to drop when saving shared tensors. If not provided and shared tensors are detected, it will drop the first name alphabetically. Example: ```py >>> from huggingface_hub import save_torch_model >>> model = ... # A PyTorch model # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors. >>> save_torch_model(model, "path/to/folder") # Load model back >>> from huggingface_hub import load_torch_model # TODO >>> load_torch_model(model, "path/to/folder") >>> ``` """ save_torch_state_dict( state_dict=model.state_dict(), filename_pattern=filename_pattern, force_contiguous=force_contiguous, max_shard_size=max_shard_size, metadata=metadata, safe_serialization=safe_serialization, save_directory=save_directory, is_main_process=is_main_process, shared_tensors_to_discard=shared_tensors_to_discard, ) def save_torch_state_dict( state_dict: Dict[str, "torch.Tensor"], save_directory: Union[str, Path], *, filename_pattern: Optional[str] = None, force_contiguous: bool = True, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, metadata: Optional[Dict[str, str]] = None, safe_serialization: bool = True, is_main_process: bool = True, shared_tensors_to_discard: Optional[List[str]] = None, ) -> None: """ Save a model state dictionary to the disk, handling sharding and shared tensors issues. See also [`save_torch_model`] to directly save a PyTorch model. For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors). The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard, an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as safetensors (the default). Otherwise, the shards are saved as pickle. Before saving the model, the `save_directory` is cleaned from any previous shard files. If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a size greater than `max_shard_size`. If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving. Args: state_dict (`Dict[str, torch.Tensor]`): The state dictionary to save. save_directory (`str` or `Path`): The directory in which the model will be saved. filename_pattern (`str`, *optional*): The pattern to generate the files names in which the model will be saved. Pattern must be a string that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization` parameter. 
force_contiguous (`boolean`, *optional*): Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the model, but it could potentially change performance if the layout of the tensor was chosen specifically for that reason. Defaults to `True`. max_shard_size (`int` or `str`, *optional*): The maximum size of each shard, in bytes. Defaults to 5GB. metadata (`Dict[str, str]`, *optional*): Extra information to save along with the model. Some metadata will be added for each dropped tensors. This information will not be enough to recover the entire shared structure but might help understanding things. safe_serialization (`bool`, *optional*): Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle. Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed in a future version. is_main_process (`bool`, *optional*): Whether the process calling this is the main process or not. Useful when in distributed training like TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. Defaults to True. shared_tensors_to_discard (`List[str]`, *optional*): List of tensor names to drop when saving shared tensors. If not provided and shared tensors are detected, it will drop the first name alphabetically. Example: ```py >>> from huggingface_hub import save_torch_state_dict >>> model = ... # A PyTorch model # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors. >>> state_dict = model_to_save.state_dict() >>> save_torch_state_dict(state_dict, "path/to/folder") ``` """ save_directory = str(save_directory) if filename_pattern is None: filename_pattern = ( constants.SAFETENSORS_WEIGHTS_FILE_PATTERN if safe_serialization else constants.PYTORCH_WEIGHTS_FILE_PATTERN ) if metadata is None: metadata = {} if safe_serialization: try: from safetensors.torch import save_file as save_file_fn except ImportError as e: raise ImportError( "Please install `safetensors` to use safe serialization. " "You can install it with `pip install safetensors`." ) from e # Clean state dict for safetensors state_dict = _clean_state_dict_for_safetensors( state_dict, metadata, force_contiguous=force_contiguous, shared_tensors_to_discard=shared_tensors_to_discard, ) else: from torch import save as save_file_fn # type: ignore[assignment] logger.warning( "You are using unsafe serialization. Due to security reasons, it is recommended not to load " "pickled models from untrusted sources. If you intend to share your model, we strongly recommend " "using safe serialization by installing `safetensors` with `pip install safetensors`." ) # Split dict state_dict_split = split_torch_state_dict_into_shards( state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size ) # Only main process should clean up existing files to avoid race conditions in distributed environment if is_main_process: existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?") for filename in os.listdir(save_directory): if existing_files_regex.match(filename): try: logger.debug(f"Removing existing file '{filename}' from folder.") os.remove(os.path.join(save_directory, filename)) except Exception as e: logger.warning( f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..." 
) # Save each shard per_file_metadata = {"format": "pt"} if not state_dict_split.is_sharded: per_file_metadata.update(metadata) safe_file_kwargs = {"metadata": per_file_metadata} if safe_serialization else {} for filename, tensors in state_dict_split.filename_to_tensors.items(): shard = {tensor: state_dict[tensor] for tensor in tensors} save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs) logger.debug(f"Shard saved to {filename}") # Save the index (if any) if state_dict_split.is_sharded: index_path = filename_pattern.format(suffix="") + ".index.json" index = { "metadata": {**state_dict_split.metadata, **metadata}, "weight_map": state_dict_split.tensor_to_filename, } with open(os.path.join(save_directory, index_path), "w") as f: json.dump(index, f, indent=2) logger.info( f"The model is bigger than the maximum size per checkpoint ({max_shard_size}). " f"Model weighs have been saved in {len(state_dict_split.filename_to_tensors)} checkpoint shards. " f"You can find where each parameters has been saved in the index located at {index_path}." ) logger.info(f"Model weights successfully saved to {save_directory}!") def split_torch_state_dict_into_shards( state_dict: Dict[str, "torch.Tensor"], *, filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, ) -> StateDictSplit: """ Split a model state dictionary in shards so that each shard is smaller than a given size. The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses `split_torch_state_dict_into_shards` under the hood. If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a size greater than `max_shard_size`. Args: state_dict (`Dict[str, torch.Tensor]`): The state dictionary to save. filename_pattern (`str`, *optional*): The pattern to generate the files names in which the model will be saved. Pattern must be a string that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` Defaults to `"model{suffix}.safetensors"`. max_shard_size (`int` or `str`, *optional*): The maximum size of each shard, in bytes. Defaults to 5GB. Returns: [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. Example: ```py >>> import json >>> import os >>> from safetensors.torch import save_file as safe_save_file >>> from huggingface_hub import split_torch_state_dict_into_shards >>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str): ... state_dict_split = split_torch_state_dict_into_shards(state_dict) ... for filename, tensors in state_dict_split.filename_to_tensors.items(): ... shard = {tensor: state_dict[tensor] for tensor in tensors} ... safe_save_file( ... shard, ... os.path.join(save_directory, filename), ... metadata={"format": "pt"}, ... ) ... if state_dict_split.is_sharded: ... index = { ... "metadata": state_dict_split.metadata, ... "weight_map": state_dict_split.tensor_to_filename, ... } ... with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f: ... 
f.write(json.dumps(index, indent=2)) ``` """ return split_state_dict_into_shards_factory( state_dict, max_shard_size=max_shard_size, filename_pattern=filename_pattern, get_storage_size=get_torch_storage_size, get_storage_id=get_torch_storage_id, ) # LOADING def load_torch_model( model: "torch.nn.Module", checkpoint_path: Union[str, os.PathLike], *, strict: bool = False, safe: bool = True, weights_only: bool = False, map_location: Optional[Union[str, "torch.device"]] = None, mmap: bool = False, filename_pattern: Optional[str] = None, ) -> NamedTuple: """ Load a checkpoint into a model, handling both sharded and non-sharded checkpoints. Args: model (`torch.nn.Module`): The model in which to load the checkpoint. checkpoint_path (`str` or `os.PathLike`): Path to either the checkpoint file or directory containing the checkpoint(s). strict (`bool`, *optional*, defaults to `False`): Whether to strictly enforce that the keys in the model state dict match the keys in the checkpoint. safe (`bool`, *optional*, defaults to `True`): If `safe` is True, the safetensors files will be loaded. If `safe` is False, the function will first attempt to load safetensors files if they are available, otherwise it will fall back to loading pickle files. `filename_pattern` parameter takes precedence over `safe` parameter. weights_only (`bool`, *optional*, defaults to `False`): If True, only loads the model weights without optimizer states and other metadata. Only supported in PyTorch >= 1.13. map_location (`str` or `torch.device`, *optional*): A `torch.device` object, string or a dict specifying how to remap storage locations. It indicates the location where all tensors should be loaded. mmap (`bool`, *optional*, defaults to `False`): Whether to use memory-mapped file loading. Memory mapping can improve loading performance for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. filename_pattern (`str`, *optional*): The pattern to look for the index file. Pattern must be a string that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` Defaults to `"model{suffix}.safetensors"`. Returns: `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields. - `missing_keys` is a list of str containing the missing keys, i.e. keys that are in the model but not in the checkpoint. - `unexpected_keys` is a list of str containing the unexpected keys, i.e. keys that are in the checkpoint but not in the model. Raises: [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError) If the checkpoint file or directory does not exist. [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError) If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If the checkpoint path is invalid or if the checkpoint format cannot be determined. Example: ```python >>> from huggingface_hub import load_torch_model >>> model = ... # A PyTorch model >>> load_torch_model(model, "path/to/checkpoint") ``` """ checkpoint_path = Path(checkpoint_path) if not checkpoint_path.exists(): raise ValueError(f"Checkpoint path {checkpoint_path} does not exist") # 1. 
Check if checkpoint is a single file if checkpoint_path.is_file(): state_dict = load_state_dict_from_file( checkpoint_file=checkpoint_path, map_location=map_location, weights_only=weights_only, ) return model.load_state_dict(state_dict, strict=strict) # 2. If not, checkpoint_path is a directory if filename_pattern is None: filename_pattern = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json") # Only fallback to pickle format if safetensors index is not found and safe is False. if not index_path.is_file() and not safe: filename_pattern = constants.PYTORCH_WEIGHTS_FILE_PATTERN index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json") if index_path.is_file(): return _load_sharded_checkpoint( model=model, save_directory=checkpoint_path, strict=strict, weights_only=weights_only, filename_pattern=filename_pattern, ) # Look for single model file model_files = list(checkpoint_path.glob("*.safetensors" if safe else "*.bin")) if len(model_files) == 1: state_dict = load_state_dict_from_file( checkpoint_file=model_files[0], map_location=map_location, weights_only=weights_only, mmap=mmap, ) return model.load_state_dict(state_dict, strict=strict) raise ValueError( f"Directory '{checkpoint_path}' does not contain a valid checkpoint. " "Expected either a sharded checkpoint with an index file, or a single model file." ) def _load_sharded_checkpoint( model: "torch.nn.Module", save_directory: os.PathLike, *, strict: bool = False, weights_only: bool = False, filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN, ) -> NamedTuple: """ Loads a sharded checkpoint into a model. This is the same as [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict) but for a sharded checkpoint. Each shard is loaded one by one and removed from memory after being loaded into the model. Args: model (`torch.nn.Module`): The model in which to load the checkpoint. save_directory (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint. strict (`bool`, *optional*, defaults to `False`): Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. weights_only (`bool`, *optional*, defaults to `False`): If True, only loads the model weights without optimizer states and other metadata. Only supported in PyTorch >= 1.13. filename_pattern (`str`, *optional*, defaults to `"model{suffix}.safetensors"`): The pattern to look for the index file. Pattern must be a string that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` Defaults to `"model{suffix}.safetensors"`. Returns: `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields, - `missing_keys` is a list of str containing the missing keys - `unexpected_keys` is a list of str containing the unexpected keys """ # 1. Load and validate index file # The index file contains mapping of parameter names to shard files index_path = filename_pattern.format(suffix="") + ".index.json" index_file = os.path.join(save_directory, index_path) with open(index_file, "r", encoding="utf-8") as f: index = json.load(f) # 2. Validate keys if in strict mode # This is done before loading any shards to fail fast if strict: _validate_keys_for_strict_loading(model, index["weight_map"].keys()) # 3. 
Load each shard using `load_state_dict` # Get unique shard files (multiple parameters can be in same shard) shard_files = list(set(index["weight_map"].values())) for shard_file in shard_files: # Load shard into memory shard_path = os.path.join(save_directory, shard_file) state_dict = load_state_dict_from_file( shard_path, map_location="cpu", weights_only=weights_only, ) # Update model with parameters from this shard model.load_state_dict(state_dict, strict=strict) # Explicitly remove the state dict from memory del state_dict # 4. Return compatibility info loaded_keys = set(index["weight_map"].keys()) model_keys = set(model.state_dict().keys()) return _IncompatibleKeys( missing_keys=list(model_keys - loaded_keys), unexpected_keys=list(loaded_keys - model_keys) ) def load_state_dict_from_file( checkpoint_file: Union[str, os.PathLike], map_location: Optional[Union[str, "torch.device"]] = None, weights_only: bool = False, mmap: bool = False, ) -> Union[Dict[str, "torch.Tensor"], Any]: """ Loads a checkpoint file, handling both safetensors and pickle checkpoint formats. Args: checkpoint_file (`str` or `os.PathLike`): Path to the checkpoint file to load. Can be either a safetensors or pickle (`.bin`) checkpoint. map_location (`str` or `torch.device`, *optional*): A `torch.device` object, string or a dict specifying how to remap storage locations. It indicates the location where all tensors should be loaded. weights_only (`bool`, *optional*, defaults to `False`): If True, only loads the model weights without optimizer states and other metadata. Only supported for pickle (`.bin`) checkpoints with PyTorch >= 1.13. Has no effect when loading safetensors files. mmap (`bool`, *optional*, defaults to `False`): Whether to use memory-mapped file loading. Memory mapping can improve loading performance for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Has no effect when loading safetensors files, as the `safetensors` library uses memory mapping by default. Returns: `Union[Dict[str, "torch.Tensor"], Any]`: The loaded checkpoint. - For safetensors files: always returns a dictionary mapping parameter names to tensors. - For pickle files: returns any Python object that was pickled (commonly a state dict, but could be an entire model, optimizer state, or any other Python object). Raises: [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError) If the checkpoint file does not exist. [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError) If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively. [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) If the checkpoint file format is invalid or if git-lfs files are not properly downloaded. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If the checkpoint file path is empty or invalid. Example: ```python >>> from huggingface_hub import load_state_dict_from_file # Load a PyTorch checkpoint >>> state_dict = load_state_dict_from_file("path/to/model.bin", map_location="cpu") >>> model.load_state_dict(state_dict) # Load a safetensors checkpoint >>> state_dict = load_state_dict_from_file("path/to/model.safetensors") >>> model.load_state_dict(state_dict) ``` """ checkpoint_path = Path(checkpoint_file) # Check if file exists and is a regular file (not a directory) if not checkpoint_path.is_file(): raise FileNotFoundError( f"No checkpoint file found at '{checkpoint_path}'. 
Please verify the path is correct and " "the file has been properly downloaded." ) # Load safetensors checkpoint if checkpoint_path.suffix == ".safetensors": try: from safetensors import safe_open from safetensors.torch import load_file except ImportError as e: raise ImportError( "Please install `safetensors` to load safetensors checkpoint. " "You can install it with `pip install safetensors`." ) from e # Check format of the archive with safe_open(checkpoint_file, framework="pt") as f: # type: ignore[attr-defined] metadata = f.metadata() # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966 if metadata is not None and metadata.get("format") not in ["pt", "mlx"]: raise OSError( f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " "you save your model with the `save_torch_model` method." ) device = str(map_location.type) if map_location is not None and hasattr(map_location, "type") else map_location # meta device is not supported with safetensors, falling back to CPU if device == "meta": logger.warning("Meta device is not supported with safetensors. Falling back to CPU device.") device = "cpu" return load_file(checkpoint_file, device=device) # type: ignore[arg-type] # Otherwise, load from pickle try: import torch from torch import load except ImportError as e: raise ImportError( "Please install `torch` to load torch tensors. You can install it with `pip install torch`." ) from e # Add additional kwargs, mmap is only supported in torch >= 2.1.0 additional_kwargs = {} if version.parse(torch.__version__) >= version.parse("2.1.0"): additional_kwargs["mmap"] = mmap # weights_only is only supported in torch >= 1.13.0 if version.parse(torch.__version__) >= version.parse("1.13.0"): additional_kwargs["weights_only"] = weights_only return load( checkpoint_file, map_location=map_location, **additional_kwargs, ) # HELPERS def _validate_keys_for_strict_loading( model: "torch.nn.Module", loaded_keys: Iterable[str], ) -> None: """ Validate that model keys match loaded keys when strict loading is enabled. Args: model: The PyTorch model being loaded loaded_keys: The keys present in the checkpoint Raises: RuntimeError: If there are missing or unexpected keys in strict mode """ loaded_keys_set = set(loaded_keys) model_keys = set(model.state_dict().keys()) missing_keys = model_keys - loaded_keys_set # Keys in model but not in checkpoint unexpected_keys = loaded_keys_set - model_keys # Keys in checkpoint but not in model if missing_keys or unexpected_keys: error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" if missing_keys: str_missing_keys = ",".join([f'"{k}"' for k in sorted(missing_keys)]) error_message += f"\nMissing key(s): {str_missing_keys}." if unexpected_keys: str_unexpected_keys = ",".join([f'"{k}"' for k in sorted(unexpected_keys)]) error_message += f"\nUnexpected key(s): {str_unexpected_keys}." 
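        # Raise a single error listing both the missing and the unexpected keys, so the caller sees every mismatch at once.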
raise RuntimeError(error_message) def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]: """Returns a unique id for plain tensor or a (potentially nested) Tuple of unique id for the flattened Tensor if the input is a wrapper tensor subclass Tensor """ try: # for torch 2.1 and above we can also handle tensor subclasses from torch.utils._python_dispatch import is_traceable_wrapper_subclass if is_traceable_wrapper_subclass(tensor): attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined] return tuple(_get_unique_id(getattr(tensor, attr)) for attr in attrs) except ImportError: # for torch version less than 2.1, we can fallback to original implementation pass if tensor.device.type == "xla" and is_torch_tpu_available(): # NOTE: xla tensors dont have storage # use some other unique id to distinguish. # this is a XLA tensor, it must be created using torch_xla's # device. So the following import is safe: import torch_xla # type: ignore[import] unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor) else: unique_id = storage_ptr(tensor) return unique_id def get_torch_storage_id(tensor: "torch.Tensor") -> Optional[Tuple["torch.device", Union[int, Tuple[Any, ...]], int]]: """ Return unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id. In the case of meta tensors, we return None since we can't tell if they share the same storage. Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278. """ if tensor.device.type == "meta": return None else: return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor) def get_torch_storage_size(tensor: "torch.Tensor") -> int: """ Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59 """ try: # for torch 2.1 and above we can also handle tensor subclasses from torch.utils._python_dispatch import is_traceable_wrapper_subclass if is_traceable_wrapper_subclass(tensor): attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined] return sum(get_torch_storage_size(getattr(tensor, attr)) for attr in attrs) except ImportError: # for torch version less than 2.1, we can fallback to original implementation pass try: return tensor.untyped_storage().nbytes() except AttributeError: # Fallback for torch==1.10 try: return tensor.storage().size() * _get_dtype_size(tensor.dtype) except NotImplementedError: # Fallback for meta storage # On torch >=2.0 this is the tensor size return tensor.nelement() * _get_dtype_size(tensor.dtype) @lru_cache() def is_torch_tpu_available(check_device=True): """ Checks if `torch_xla` is installed and potentially if a TPU is in the environment Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/utils/import_utils.py#L463. 
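    Example (illustrative; the result depends on the environment and is `False` whenever `torch_xla` is not
    installed or no XLA device can be initialized):

    ```py
    >>> is_torch_tpu_available()
    False
    ```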
""" if importlib.util.find_spec("torch_xla") is not None: if check_device: # We need to check if `xla_device` can be found, will raise a RuntimeError if not try: import torch_xla.core.xla_model as xm # type: ignore[import] _ = xm.xla_device() return True except RuntimeError: return False return True return False def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]: """ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11. """ try: # for torch 2.1 and above we can also handle tensor subclasses from torch.utils._python_dispatch import is_traceable_wrapper_subclass if is_traceable_wrapper_subclass(tensor): return _get_unique_id(tensor) # type: ignore except ImportError: # for torch version less than 2.1, we can fallback to original implementation pass try: return tensor.untyped_storage().data_ptr() except Exception: # Fallback for torch==1.10 try: return tensor.storage().data_ptr() except NotImplementedError: # Fallback for meta storage return 0 def _clean_state_dict_for_safetensors( state_dict: Dict[str, "torch.Tensor"], metadata: Dict[str, str], force_contiguous: bool = True, shared_tensors_to_discard: Optional[List[str]] = None, ): """Remove shared tensors from state_dict and update metadata accordingly (for reloading). Warning: `state_dict` and `metadata` are mutated in-place! Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155. """ to_removes = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard) for kept_name, to_remove_group in to_removes.items(): for to_remove in to_remove_group: if metadata is None: metadata = {} if to_remove not in metadata: # Do not override user data metadata[to_remove] = kept_name del state_dict[to_remove] if force_contiguous: state_dict = {k: v.contiguous() for k, v in state_dict.items()} return state_dict def _end_ptr(tensor: "torch.Tensor") -> int: """ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L23. """ if tensor.nelement(): stop = tensor.view(-1)[-1].data_ptr() + _get_dtype_size(tensor.dtype) else: stop = tensor.data_ptr() return stop def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]: """ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L44 """ filtered_tensors = [] for shared in tensors: if len(shared) < 2: filtered_tensors.append(shared) continue areas = [] for name in shared: tensor = state_dict[name] areas.append((tensor.data_ptr(), _end_ptr(tensor), name)) areas.sort() _, last_stop, last_name = areas[0] filtered_tensors.append({last_name}) for start, stop, name in areas[1:]: if start >= last_stop: filtered_tensors.append({name}) else: filtered_tensors[-1].add(name) last_stop = stop return filtered_tensors def _find_shared_tensors(state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]: """ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L69. """ import torch tensors_dict = defaultdict(set) for k, v in state_dict.items(): if v.device != torch.device("meta") and storage_ptr(v) != 0 and get_torch_storage_size(v) != 0: # Need to add device as key because of multiple GPU. 
tensors_dict[(v.device, storage_ptr(v), get_torch_storage_size(v))].add(k) tensors = list(sorted(tensors_dict.values())) tensors = _filter_shared_not_shared(tensors, state_dict) return tensors def _is_complete(tensor: "torch.Tensor") -> bool: """ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80 """ try: # for torch 2.1 and above we can also handle tensor subclasses from torch.utils._python_dispatch import is_traceable_wrapper_subclass if is_traceable_wrapper_subclass(tensor): attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined] return all(_is_complete(getattr(tensor, attr)) for attr in attrs) except ImportError: # for torch version less than 2.1, we can fallback to original implementation pass return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size( tensor.dtype ) == get_torch_storage_size(tensor) def _remove_duplicate_names( state_dict: Dict[str, "torch.Tensor"], *, preferred_names: Optional[List[str]] = None, discard_names: Optional[List[str]] = None, ) -> Dict[str, List[str]]: """ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80 """ if preferred_names is None: preferred_names = [] unique_preferred_names = set(preferred_names) if discard_names is None: discard_names = [] unique_discard_names = set(discard_names) shareds = _find_shared_tensors(state_dict) to_remove = defaultdict(list) for shared in shareds: complete_names = set([name for name in shared if _is_complete(state_dict[name])]) if not complete_names: raise RuntimeError( "Error while trying to find names to remove to save state dict, but found no suitable name to keep" f" for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model" " since you could be storing much more memory than needed. Please refer to" " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an" " issue." ) keep_name = sorted(list(complete_names))[0] # Mechanism to preferentially select keys to keep # coming from the on-disk file to allow # loading models saved with a different choice # of keep_name preferred = complete_names.difference(unique_discard_names) if preferred: keep_name = sorted(list(preferred))[0] if unique_preferred_names: preferred = unique_preferred_names.intersection(complete_names) if preferred: keep_name = sorted(list(preferred))[0] for name in sorted(shared): if name != keep_name: to_remove[keep_name].append(name) return to_remove @lru_cache() def _get_dtype_size(dtype: "torch.dtype") -> int: """ Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L344 """ import torch # torch.float8 formats require 2.1; we do not support these dtypes on earlier versions _float8_e4m3fn = getattr(torch, "float8_e4m3fn", None) _float8_e5m2 = getattr(torch, "float8_e5m2", None) _SIZE = { torch.int64: 8, torch.float32: 4, torch.int32: 4, torch.bfloat16: 2, torch.float16: 2, torch.int16: 2, torch.uint8: 1, torch.int8: 1, torch.bool: 1, torch.float64: 8, _float8_e4m3fn: 1, _float8_e5m2: 1, } return _SIZE[dtype] class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])): """ This is used to report missing and unexpected keys in the state dict. 
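For illustration, the custom `__repr__` collapses to an empty string when there is nothing to report (the key name below is a made-up placeholder):

```py
>>> repr(_IncompatibleKeys(missing_keys=[], unexpected_keys=[]))
''
>>> repr(_IncompatibleKeys(missing_keys=["lm_head.weight"], unexpected_keys=[]))
"IncompatibleKeys(missing_keys=['lm_head.weight'], unexpected_keys=[])"
```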
Taken from https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L52. """ def __repr__(self) -> str: if not self.missing_keys and not self.unexpected_keys: return "" return super().__repr__() __str__ = __repr__ huggingface_hub-0.31.1/src/huggingface_hub/templates/000077500000000000000000000000001500667546600226125ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/templates/datasetcard_template.md000066400000000000000000000125771500667546600273220ustar00rootroot00000000000000--- # For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1 # Doc / guide: https://huggingface.co/docs/hub/datasets-cards {{ card_data }} --- # Dataset Card for {{ pretty_name | default("Dataset Name", true) }} {{ dataset_summary | default("", true) }} ## Dataset Details ### Dataset Description {{ dataset_description | default("", true) }} - **Curated by:** {{ curators | default("[More Information Needed]", true)}} - **Funded by [optional]:** {{ funded_by | default("[More Information Needed]", true)}} - **Shared by [optional]:** {{ shared_by | default("[More Information Needed]", true)}} - **Language(s) (NLP):** {{ language | default("[More Information Needed]", true)}} - **License:** {{ license | default("[More Information Needed]", true)}} ### Dataset Sources [optional] - **Repository:** {{ repo | default("[More Information Needed]", true)}} - **Paper [optional]:** {{ paper | default("[More Information Needed]", true)}} - **Demo [optional]:** {{ demo | default("[More Information Needed]", true)}} ## Uses ### Direct Use {{ direct_use | default("[More Information Needed]", true)}} ### Out-of-Scope Use {{ out_of_scope_use | default("[More Information Needed]", true)}} ## Dataset Structure {{ dataset_structure | default("[More Information Needed]", true)}} ## Dataset Creation ### Curation Rationale {{ curation_rationale_section | default("[More Information Needed]", true)}} ### Source Data #### Data Collection and Processing {{ data_collection_and_processing_section | default("[More Information Needed]", true)}} #### Who are the source data producers? {{ source_data_producers_section | default("[More Information Needed]", true)}} ### Annotations [optional] #### Annotation process {{ annotation_process_section | default("[More Information Needed]", true)}} #### Who are the annotators? {{ who_are_annotators_section | default("[More Information Needed]", true)}} #### Personal and Sensitive Information {{ personal_and_sensitive_information | default("[More Information Needed]", true)}} ## Bias, Risks, and Limitations {{ bias_risks_limitations | default("[More Information Needed]", true)}} ### Recommendations {{ bias_recommendations | default("Users should be made aware of the risks, biases and limitations of the dataset. 
More information needed for further recommendations.", true)}} ## Citation [optional] **BibTeX:** {{ citation_bibtex | default("[More Information Needed]", true)}} **APA:** {{ citation_apa | default("[More Information Needed]", true)}} ## Glossary [optional] {{ glossary | default("[More Information Needed]", true)}} ## More Information [optional] {{ more_information | default("[More Information Needed]", true)}} ## Dataset Card Authors [optional] {{ dataset_card_authors | default("[More Information Needed]", true)}} ## Dataset Card Contact {{ dataset_card_contact | default("[More Information Needed]", true)}} huggingface_hub-0.31.1/src/huggingface_hub/templates/modelcard_template.md000066400000000000000000000153261500667546600267700ustar00rootroot00000000000000--- # For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 # Doc / guide: https://huggingface.co/docs/hub/model-cards {{ card_data }} --- # Model Card for {{ model_id | default("Model ID", true) }} {{ model_summary | default("", true) }} ## Model Details ### Model Description {{ model_description | default("", true) }} - **Developed by:** {{ developers | default("[More Information Needed]", true)}} - **Funded by [optional]:** {{ funded_by | default("[More Information Needed]", true)}} - **Shared by [optional]:** {{ shared_by | default("[More Information Needed]", true)}} - **Model type:** {{ model_type | default("[More Information Needed]", true)}} - **Language(s) (NLP):** {{ language | default("[More Information Needed]", true)}} - **License:** {{ license | default("[More Information Needed]", true)}} - **Finetuned from model [optional]:** {{ base_model | default("[More Information Needed]", true)}} ### Model Sources [optional] - **Repository:** {{ repo | default("[More Information Needed]", true)}} - **Paper [optional]:** {{ paper | default("[More Information Needed]", true)}} - **Demo [optional]:** {{ demo | default("[More Information Needed]", true)}} ## Uses ### Direct Use {{ direct_use | default("[More Information Needed]", true)}} ### Downstream Use [optional] {{ downstream_use | default("[More Information Needed]", true)}} ### Out-of-Scope Use {{ out_of_scope_use | default("[More Information Needed]", true)}} ## Bias, Risks, and Limitations {{ bias_risks_limitations | default("[More Information Needed]", true)}} ### Recommendations {{ bias_recommendations | default("Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.", true)}} ## How to Get Started with the Model Use the code below to get started with the model. 
{{ get_started_code | default("[More Information Needed]", true)}} ## Training Details ### Training Data {{ training_data | default("[More Information Needed]", true)}} ### Training Procedure #### Preprocessing [optional] {{ preprocessing | default("[More Information Needed]", true)}} #### Training Hyperparameters - **Training regime:** {{ training_regime | default("[More Information Needed]", true)}} #### Speeds, Sizes, Times [optional] {{ speeds_sizes_times | default("[More Information Needed]", true)}} ## Evaluation ### Testing Data, Factors & Metrics #### Testing Data {{ testing_data | default("[More Information Needed]", true)}} #### Factors {{ testing_factors | default("[More Information Needed]", true)}} #### Metrics {{ testing_metrics | default("[More Information Needed]", true)}} ### Results {{ results | default("[More Information Needed]", true)}} #### Summary {{ results_summary | default("", true) }} ## Model Examination [optional] {{ model_examination | default("[More Information Needed]", true)}} ## Environmental Impact Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). - **Hardware Type:** {{ hardware_type | default("[More Information Needed]", true)}} - **Hours used:** {{ hours_used | default("[More Information Needed]", true)}} - **Cloud Provider:** {{ cloud_provider | default("[More Information Needed]", true)}} - **Compute Region:** {{ cloud_region | default("[More Information Needed]", true)}} - **Carbon Emitted:** {{ co2_emitted | default("[More Information Needed]", true)}} ## Technical Specifications [optional] ### Model Architecture and Objective {{ model_specs | default("[More Information Needed]", true)}} ### Compute Infrastructure {{ compute_infrastructure | default("[More Information Needed]", true)}} #### Hardware {{ hardware_requirements | default("[More Information Needed]", true)}} #### Software {{ software | default("[More Information Needed]", true)}} ## Citation [optional] **BibTeX:** {{ citation_bibtex | default("[More Information Needed]", true)}} **APA:** {{ citation_apa | default("[More Information Needed]", true)}} ## Glossary [optional] {{ glossary | default("[More Information Needed]", true)}} ## More Information [optional] {{ more_information | default("[More Information Needed]", true)}} ## Model Card Authors [optional] {{ model_card_authors | default("[More Information Needed]", true)}} ## Model Card Contact {{ model_card_contact | default("[More Information Needed]", true)}} huggingface_hub-0.31.1/src/huggingface_hub/utils/000077500000000000000000000000001500667546600217545ustar00rootroot00000000000000huggingface_hub-0.31.1/src/huggingface_hub/utils/__init__.py000066400000000000000000000072131500667546600240700ustar00rootroot00000000000000# coding=utf-8 # Copyright 2021 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License # ruff: noqa: F401 from huggingface_hub.errors import ( BadRequestError, CacheNotFound, CorruptedCacheException, DisabledRepoError, EntryNotFoundError, FileMetadataError, GatedRepoError, HfHubHTTPError, HFValidationError, LocalEntryNotFoundError, LocalTokenNotFoundError, NotASafetensorsRepoError, OfflineModeIsEnabled, RepositoryNotFoundError, RevisionNotFoundError, SafetensorsParsingError, ) from . import tqdm as _tqdm # _tqdm is the module from ._auth import get_stored_tokens, get_token from ._cache_assets import cached_assets_path from ._cache_manager import ( CachedFileInfo, CachedRepoInfo, CachedRevisionInfo, DeleteCacheStrategy, HFCacheInfo, scan_cache_dir, ) from ._chunk_utils import chunk_iterable from ._datetime import parse_datetime from ._experimental import experimental from ._fixes import SoftTemporaryDirectory, WeakFileLock, yaml_dump from ._git_credential import list_credential_helpers, set_git_credential, unset_git_credential from ._headers import build_hf_headers, get_token_to_send from ._hf_folder import HfFolder from ._http import ( configure_http_backend, fix_hf_endpoint_in_url, get_session, hf_raise_for_status, http_backoff, reset_sessions, ) from ._pagination import paginate from ._paths import DEFAULT_IGNORE_PATTERNS, FORBIDDEN_FOLDERS, filter_repo_objects from ._runtime import ( dump_environment_info, get_aiohttp_version, get_fastai_version, get_fastapi_version, get_fastcore_version, get_gradio_version, get_graphviz_version, get_hf_hub_version, get_hf_transfer_version, get_jinja_version, get_numpy_version, get_pillow_version, get_pydantic_version, get_pydot_version, get_python_version, get_tensorboard_version, get_tf_version, get_torch_version, is_aiohttp_available, is_colab_enterprise, is_fastai_available, is_fastapi_available, is_fastcore_available, is_google_colab, is_gradio_available, is_graphviz_available, is_hf_transfer_available, is_jinja_available, is_notebook, is_numpy_available, is_package_available, is_pillow_available, is_pydantic_available, is_pydot_available, is_safetensors_available, is_tensorboard_available, is_tf_available, is_torch_available, ) from ._safetensors import SafetensorsFileMetadata, SafetensorsRepoMetadata, TensorInfo from ._subprocess import capture_output, run_interactive_subprocess, run_subprocess from ._telemetry import send_telemetry from ._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type from ._validators import smoothly_deprecate_use_auth_token, validate_hf_hub_args, validate_repo_id from ._xet import ( XetConnectionInfo, XetFileData, XetTokenType, fetch_xet_connection_info_from_repo_info, parse_xet_file_data_from_response, refresh_xet_connection_info, ) from .tqdm import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars, tqdm, tqdm_stream_file huggingface_hub-0.31.1/src/huggingface_hub/utils/_auth.py000066400000000000000000000201461500667546600234310ustar00rootroot00000000000000# Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Contains an helper to get the token from machine (env variable, secret or config file).""" import configparser import logging import os import warnings from pathlib import Path from threading import Lock from typing import Dict, Optional from .. import constants from ._runtime import is_colab_enterprise, is_google_colab _IS_GOOGLE_COLAB_CHECKED = False _GOOGLE_COLAB_SECRET_LOCK = Lock() _GOOGLE_COLAB_SECRET: Optional[str] = None logger = logging.getLogger(__name__) def get_token() -> Optional[str]: """ Get token if user is logged in. Note: in most cases, you should use [`huggingface_hub.utils.build_hf_headers`] instead. This method is only useful if you want to retrieve the token for other purposes than sending an HTTP request. Token is retrieved in priority from the `HF_TOKEN` environment variable. Otherwise, we read the token file located in the Hugging Face home folder. Returns None if user is not logged in. To log in, use [`login`] or `huggingface-cli login`. Returns: `str` or `None`: The token, `None` if it doesn't exist. """ return _get_token_from_google_colab() or _get_token_from_environment() or _get_token_from_file() def _get_token_from_google_colab() -> Optional[str]: """Get token from Google Colab secrets vault using `google.colab.userdata.get(...)`. Token is read from the vault only once per session and then stored in a global variable to avoid re-requesting access to the vault. """ # If it's not a Google Colab or it's Colab Enterprise, fallback to environment variable or token file authentication if not is_google_colab() or is_colab_enterprise(): return None # `google.colab.userdata` is not thread-safe # This can lead to a deadlock if multiple threads try to access it at the same time # (typically when using `snapshot_download`) # => use a lock # See https://github.com/huggingface/huggingface_hub/issues/1952 for more details. with _GOOGLE_COLAB_SECRET_LOCK: global _GOOGLE_COLAB_SECRET global _IS_GOOGLE_COLAB_CHECKED if _IS_GOOGLE_COLAB_CHECKED: # request access only once return _GOOGLE_COLAB_SECRET try: from google.colab import userdata # type: ignore from google.colab.errors import Error as ColabError # type: ignore except ImportError: return None try: token = userdata.get("HF_TOKEN") _GOOGLE_COLAB_SECRET = _clean_token(token) except userdata.NotebookAccessError: # Means the user has a secret call `HF_TOKEN` and got a popup "please grand access to HF_TOKEN" and refused it # => warn user but ignore error => do not re-request access to user warnings.warn( "\nAccess to the secret `HF_TOKEN` has not been granted on this notebook." "\nYou will not be requested again." "\nPlease restart the session if you want to be prompted again." ) _GOOGLE_COLAB_SECRET = None except userdata.SecretNotFoundError: # Means the user did not define a `HF_TOKEN` secret => warn warnings.warn( "\nThe secret `HF_TOKEN` does not exist in your Colab secrets." "\nTo authenticate with the Hugging Face Hub, create a token in your settings tab " "(https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session." "\nYou will be able to reuse this secret in all of your notebooks." "\nPlease note that authentication is recommended but still optional to access public models or datasets." 
) _GOOGLE_COLAB_SECRET = None except ColabError as e: # Something happen but we don't know what => recommend to open a GitHub issue warnings.warn( f"\nError while fetching `HF_TOKEN` secret value from your vault: '{str(e)}'." "\nYou are not authenticated with the Hugging Face Hub in this notebook." "\nIf the error persists, please let us know by opening an issue on GitHub " "(https://github.com/huggingface/huggingface_hub/issues/new)." ) _GOOGLE_COLAB_SECRET = None _IS_GOOGLE_COLAB_CHECKED = True return _GOOGLE_COLAB_SECRET def _get_token_from_environment() -> Optional[str]: # `HF_TOKEN` has priority (keep `HUGGING_FACE_HUB_TOKEN` for backward compatibility) return _clean_token(os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")) def _get_token_from_file() -> Optional[str]: try: return _clean_token(Path(constants.HF_TOKEN_PATH).read_text()) except FileNotFoundError: return None def get_stored_tokens() -> Dict[str, str]: """ Returns the parsed INI file containing the access tokens. The file is located at `HF_STORED_TOKENS_PATH`, defaulting to `~/.cache/huggingface/stored_tokens`. If the file does not exist, an empty dictionary is returned. Returns: `Dict[str, str]` Key is the token name and value is the token. """ tokens_path = Path(constants.HF_STORED_TOKENS_PATH) if not tokens_path.exists(): stored_tokens = {} config = configparser.ConfigParser() try: config.read(tokens_path) stored_tokens = {token_name: config.get(token_name, "hf_token") for token_name in config.sections()} except configparser.Error as e: logger.error(f"Error parsing stored tokens file: {e}") stored_tokens = {} return stored_tokens def _save_stored_tokens(stored_tokens: Dict[str, str]) -> None: """ Saves the given configuration to the stored tokens file. Args: stored_tokens (`Dict[str, str]`): The stored tokens to save. Key is the token name and value is the token. """ stored_tokens_path = Path(constants.HF_STORED_TOKENS_PATH) # Write the stored tokens into an INI file config = configparser.ConfigParser() for token_name in sorted(stored_tokens.keys()): config.add_section(token_name) config.set(token_name, "hf_token", stored_tokens[token_name]) stored_tokens_path.parent.mkdir(parents=True, exist_ok=True) with stored_tokens_path.open("w") as config_file: config.write(config_file) def _get_token_by_name(token_name: str) -> Optional[str]: """ Get the token by name. Args: token_name (`str`): The name of the token to get. Returns: `str` or `None`: The token, `None` if it doesn't exist. """ stored_tokens = get_stored_tokens() if token_name not in stored_tokens: return None return _clean_token(stored_tokens[token_name]) def _save_token(token: str, token_name: str) -> None: """ Save the given token. If the stored tokens file does not exist, it will be created. Args: token (`str`): The token to save. token_name (`str`): The name of the token. """ tokens_path = Path(constants.HF_STORED_TOKENS_PATH) stored_tokens = get_stored_tokens() stored_tokens[token_name] = token _save_stored_tokens(stored_tokens) logger.info(f"The token `{token_name}` has been saved to {tokens_path}") def _clean_token(token: Optional[str]) -> Optional[str]: """Clean token by removing trailing and leading spaces and newlines. If token is an empty string, return None. 
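Example (illustrative; the token value is a placeholder):

```py
>>> _clean_token("  hf_xxx  ")
'hf_xxx'
>>> _clean_token("   ") is None  # blank values are normalized to None
True
```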
""" if token is None: return None return token.replace("\r", "").replace("\n", "").strip() or None huggingface_hub-0.31.1/src/huggingface_hub/utils/_cache_assets.py000066400000000000000000000131401500667546600251110ustar00rootroot00000000000000# coding=utf-8 # Copyright 2019-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from pathlib import Path from typing import Union from ..constants import HF_ASSETS_CACHE def cached_assets_path( library_name: str, namespace: str = "default", subfolder: str = "default", *, assets_dir: Union[str, Path, None] = None, ): """Return a folder path to cache arbitrary files. `huggingface_hub` provides a canonical folder path to store assets. This is the recommended way to integrate cache in a downstream library as it will benefit from the builtins tools to scan and delete the cache properly. The distinction is made between files cached from the Hub and assets. Files from the Hub are cached in a git-aware manner and entirely managed by `huggingface_hub`. See [related documentation](https://huggingface.co/docs/huggingface_hub/how-to-cache). All other files that a downstream library caches are considered to be "assets" (files downloaded from external sources, extracted from a .tar archive, preprocessed for training,...). Once the folder path is generated, it is guaranteed to exist and to be a directory. The path is based on 3 levels of depth: the library name, a namespace and a subfolder. Those 3 levels grants flexibility while allowing `huggingface_hub` to expect folders when scanning/deleting parts of the assets cache. Within a library, it is expected that all namespaces share the same subset of subfolder names but this is not a mandatory rule. The downstream library has then full control on which file structure to adopt within its cache. Namespace and subfolder are optional (would default to a `"default/"` subfolder) but library name is mandatory as we want every downstream library to manage its own cache. Expected tree: ```text assets/ └── datasets/ │ ├── SQuAD/ │ │ ├── downloaded/ │ │ ├── extracted/ │ │ └── processed/ │ ├── Helsinki-NLP--tatoeba_mt/ │ ├── downloaded/ │ ├── extracted/ │ └── processed/ └── transformers/ ├── default/ │ ├── something/ ├── bert-base-cased/ │ ├── default/ │ └── training/ hub/ └── models--julien-c--EsperBERTo-small/ ├── blobs/ │ ├── (...) │ ├── (...) ├── refs/ │ └── (...) └── [ 128] snapshots/ ├── 2439f60ef33a0d46d85da5001d52aeda5b00ce9f/ │ ├── (...) └── bbc77c8132af1cc5cf678da3f1ddf2de43606d48/ └── (...) ``` Args: library_name (`str`): Name of the library that will manage the cache folder. Example: `"dataset"`. namespace (`str`, *optional*, defaults to "default"): Namespace to which the data belongs. Example: `"SQuAD"`. subfolder (`str`, *optional*, defaults to "default"): Subfolder in which the data will be stored. Example: `extracted`. assets_dir (`str`, `Path`, *optional*): Path to the folder where assets are cached. This must not be the same folder where Hub files are cached. 
Defaults to `HF_HOME / "assets"` if not provided. Can also be set with `HF_ASSETS_CACHE` environment variable. Returns: Path to the cache folder (`Path`). Example: ```py >>> from huggingface_hub import cached_assets_path >>> cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="download") PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/SQuAD/download') >>> cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="extracted") PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/SQuAD/extracted') >>> cached_assets_path(library_name="datasets", namespace="Helsinki-NLP/tatoeba_mt") PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/Helsinki-NLP--tatoeba_mt/default') >>> cached_assets_path(library_name="datasets", assets_dir="/tmp/tmp123456") PosixPath('/tmp/tmp123456/datasets/default/default') ``` """ # Resolve assets_dir if assets_dir is None: assets_dir = HF_ASSETS_CACHE assets_dir = Path(assets_dir).expanduser().resolve() # Avoid names that could create path issues for part in (" ", "/", "\\"): library_name = library_name.replace(part, "--") namespace = namespace.replace(part, "--") subfolder = subfolder.replace(part, "--") # Path to subfolder is created path = assets_dir / library_name / namespace / subfolder try: path.mkdir(exist_ok=True, parents=True) except (FileExistsError, NotADirectoryError): raise ValueError(f"Corrupted assets folder: cannot create directory because of an existing file ({path}).") # Return return path huggingface_hub-0.31.1/src/huggingface_hub/utils/_cache_manager.py000066400000000000000000001033431500667546600252260ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains utilities to manage the HF cache directory.""" import os import shutil import time from collections import defaultdict from dataclasses import dataclass from pathlib import Path from typing import Dict, FrozenSet, List, Literal, Optional, Set, Union from huggingface_hub.errors import CacheNotFound, CorruptedCacheException from ..commands._cli_utils import tabulate from ..constants import HF_HUB_CACHE from . import logging logger = logging.get_logger(__name__) REPO_TYPE_T = Literal["model", "dataset", "space"] # List of OS-created helper files that need to be ignored FILES_TO_IGNORE = [".DS_Store"] @dataclass(frozen=True) class CachedFileInfo: """Frozen data structure holding information about a single cached file. Args: file_name (`str`): Name of the file. Example: `config.json`. file_path (`Path`): Path of the file in the `snapshots` directory. The file path is a symlink referring to a blob in the `blobs` folder. blob_path (`Path`): Path of the blob file. This is equivalent to `file_path.resolve()`. size_on_disk (`int`): Size of the blob file in bytes. blob_last_accessed (`float`): Timestamp of the last time the blob file has been accessed (from any revision). 
blob_last_modified (`float`): Timestamp of the last time the blob file has been modified/created. `blob_last_accessed` and `blob_last_modified` reliability can depend on the OS you are using. See [python documentation](https://docs.python.org/3/library/os.html#os.stat_result) for more details. """ file_name: str file_path: Path blob_path: Path size_on_disk: int blob_last_accessed: float blob_last_modified: float @property def blob_last_accessed_str(self) -> str: """ (property) Timestamp of the last time the blob file has been accessed (from any revision), returned as a human-readable string. Example: "2 weeks ago". """ return _format_timesince(self.blob_last_accessed) @property def blob_last_modified_str(self) -> str: """ (property) Timestamp of the last time the blob file has been modified, returned as a human-readable string. Example: "2 weeks ago". """ return _format_timesince(self.blob_last_modified) @property def size_on_disk_str(self) -> str: """ (property) Size of the blob file as a human-readable string. Example: "42.2K". """ return _format_size(self.size_on_disk) @dataclass(frozen=True) class CachedRevisionInfo: """Frozen data structure holding information about a revision. A revision corresponds to a folder in the `snapshots` folder and is populated with the exact tree structure as the repo on the Hub but contains only symlinks. A revision can be either referenced by 1 or more `refs` or be "detached" (no refs). Args: commit_hash (`str`): Hash of the revision (unique). Example: `"9338f7b671827df886678df2bdd7cc7b4f36dffd"`. snapshot_path (`Path`): Path to the revision directory in the `snapshots` folder. It contains the exact tree structure as the repo on the Hub. files (`FrozenSet[CachedFileInfo]`): Set of [`~CachedFileInfo`] describing all files contained in the snapshot. refs (`FrozenSet[str]`): Set of `refs` pointing to this revision. If the revision has no `refs`, it is considered detached. Example: `{"main", "2.4.0"}` or `{"refs/pr/1"}`. size_on_disk (`int`): Sum of the blob file sizes that are symlink-ed by the revision. last_modified (`float`): Timestamp of the last time the revision has been created/modified. `last_accessed` cannot be determined correctly on a single revision as blob files are shared across revisions. `size_on_disk` is not necessarily the sum of all file sizes because of possible duplicated files. Besides, only blobs are taken into account, not the (negligible) size of folders and symlinks. """ commit_hash: str snapshot_path: Path size_on_disk: int files: FrozenSet[CachedFileInfo] refs: FrozenSet[str] last_modified: float @property def last_modified_str(self) -> str: """ (property) Timestamp of the last time the revision has been modified, returned as a human-readable string. Example: "2 weeks ago". """ return _format_timesince(self.last_modified) @property def size_on_disk_str(self) -> str: """ (property) Sum of the blob file sizes as a human-readable string. Example: "42.2K". """ return _format_size(self.size_on_disk) @property def nb_files(self) -> int: """ (property) Total number of files in the revision. """ return len(self.files) @dataclass(frozen=True) class CachedRepoInfo: """Frozen data structure holding information about a cached repository. Args: repo_id (`str`): Repo id of the repo on the Hub. Example: `"google/fleurs"`. repo_type (`Literal["dataset", "model", "space"]`): Type of the cached repo. repo_path (`Path`): Local path to the cached repo. size_on_disk (`int`): Sum of the blob file sizes in the cached repo.
nb_files (`int`): Total number of blob files in the cached repo. revisions (`FrozenSet[CachedRevisionInfo]`): Set of [`~CachedRevisionInfo`] describing all revisions cached in the repo. last_accessed (`float`): Timestamp of the last time a blob file of the repo has been accessed. last_modified (`float`): Timestamp of the last time a blob file of the repo has been modified/created. `size_on_disk` is not necessarily the sum of all revisions sizes because of duplicated files. Besides, only blobs are taken into account, not the (negligible) size of folders and symlinks. `last_accessed` and `last_modified` reliability can depend on the OS you are using. See [python documentation](https://docs.python.org/3/library/os.html#os.stat_result) for more details. """ repo_id: str repo_type: REPO_TYPE_T repo_path: Path size_on_disk: int nb_files: int revisions: FrozenSet[CachedRevisionInfo] last_accessed: float last_modified: float @property def last_accessed_str(self) -> str: """ (property) Last time a blob file of the repo has been accessed, returned as a human-readable string. Example: "2 weeks ago". """ return _format_timesince(self.last_accessed) @property def last_modified_str(self) -> str: """ (property) Last time a blob file of the repo has been modified, returned as a human-readable string. Example: "2 weeks ago". """ return _format_timesince(self.last_modified) @property def size_on_disk_str(self) -> str: """ (property) Sum of the blob file sizes as a human-readable string. Example: "42.2K". """ return _format_size(self.size_on_disk) @property def refs(self) -> Dict[str, CachedRevisionInfo]: """ (property) Mapping between `refs` and revision data structures. """ return {ref: revision for revision in self.revisions for ref in revision.refs} @dataclass(frozen=True) class DeleteCacheStrategy: """Frozen data structure holding the strategy to delete cached revisions. This object is not meant to be instantiated programmatically but to be returned by [`~utils.HFCacheInfo.delete_revisions`]. See documentation for usage example. Args: expected_freed_size (`float`): Expected freed size once strategy is executed. blobs (`FrozenSet[Path]`): Set of blob file paths to be deleted. refs (`FrozenSet[Path]`): Set of reference file paths to be deleted. repos (`FrozenSet[Path]`): Set of entire repo paths to be deleted. snapshots (`FrozenSet[Path]`): Set of snapshots to be deleted (directory of symlinks). """ expected_freed_size: int blobs: FrozenSet[Path] refs: FrozenSet[Path] repos: FrozenSet[Path] snapshots: FrozenSet[Path] @property def expected_freed_size_str(self) -> str: """ (property) Expected size that will be freed as a human-readable string. Example: "42.2K". """ return _format_size(self.expected_freed_size) def execute(self) -> None: """Execute the defined strategy. If this method is interrupted, the cache might get corrupted. Deletion order is implemented so that references and symlinks are deleted before the actual blob files. This method is irreversible. If executed, cached files are erased and must be downloaded again. """ # Deletion order matters. Blobs are deleted in last so that the user can't end # up in a state where a `ref`` refers to a missing snapshot or a snapshot # symlink refers to a deleted blob. 
# Delete entire repos for path in self.repos: _try_delete_path(path, path_type="repo") # Delete snapshot directories for path in self.snapshots: _try_delete_path(path, path_type="snapshot") # Delete refs files for path in self.refs: _try_delete_path(path, path_type="ref") # Delete blob files for path in self.blobs: _try_delete_path(path, path_type="blob") logger.info(f"Cache deletion done. Saved {self.expected_freed_size_str}.") @dataclass(frozen=True) class HFCacheInfo: """Frozen data structure holding information about the entire cache-system. This data structure is returned by [`scan_cache_dir`] and is immutable. Args: size_on_disk (`int`): Sum of all valid repo sizes in the cache-system. repos (`FrozenSet[CachedRepoInfo]`): Set of [`~CachedRepoInfo`] describing all valid cached repos found on the cache-system while scanning. warnings (`List[CorruptedCacheException]`): List of [`~CorruptedCacheException`] that occurred while scanning the cache. Those exceptions are captured so that the scan can continue. Corrupted repos are skipped from the scan. Here `size_on_disk` is equal to the sum of all repo sizes (only blobs). However if some cached repos are corrupted, their sizes are not taken into account. """ size_on_disk: int repos: FrozenSet[CachedRepoInfo] warnings: List[CorruptedCacheException] @property def size_on_disk_str(self) -> str: """ (property) Sum of all valid repo sizes in the cache-system as a human-readable string. Example: "42.2K". """ return _format_size(self.size_on_disk) def delete_revisions(self, *revisions: str) -> DeleteCacheStrategy: """Prepare the strategy to delete one or more revisions cached locally. Input revisions can be any revision hash. If a revision hash is not found in the local cache, a warning is thrown but no error is raised. Revisions can be from different cached repos since hashes are unique across repos, Examples: ```py >>> from huggingface_hub import scan_cache_dir >>> cache_info = scan_cache_dir() >>> delete_strategy = cache_info.delete_revisions( ... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa" ... ) >>> print(f"Will free {delete_strategy.expected_freed_size_str}.") Will free 7.9K. >>> delete_strategy.execute() Cache deletion done. Saved 7.9K. ``` ```py >>> from huggingface_hub import scan_cache_dir >>> scan_cache_dir().delete_revisions( ... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa", ... "e2983b237dccf3ab4937c97fa717319a9ca1a96d", ... "6c0e6080953db56375760c0471a8c5f2929baf11", ... ).execute() Cache deletion done. Saved 8.6G. ``` `delete_revisions` returns a [`~utils.DeleteCacheStrategy`] object that needs to be executed. The [`~utils.DeleteCacheStrategy`] is not meant to be modified but allows having a dry run before actually executing the deletion. 
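For instance, to wipe every cached revision of a single repo, the commit hashes can be collected from the scan result itself (a sketch; `"t5-small"` is only an example id and must actually exist in your cache):

```py
>>> from huggingface_hub import scan_cache_dir
>>> cache_info = scan_cache_dir()
>>> repo = next(r for r in cache_info.repos if r.repo_id == "t5-small")
>>> strategy = cache_info.delete_revisions(*[rev.commit_hash for rev in repo.revisions])
>>> print(f"Will free {strategy.expected_freed_size_str}.")
>>> strategy.execute()
```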
""" hashes_to_delete: Set[str] = set(revisions) repos_with_revisions: Dict[CachedRepoInfo, Set[CachedRevisionInfo]] = defaultdict(set) for repo in self.repos: for revision in repo.revisions: if revision.commit_hash in hashes_to_delete: repos_with_revisions[repo].add(revision) hashes_to_delete.remove(revision.commit_hash) if len(hashes_to_delete) > 0: logger.warning(f"Revision(s) not found - cannot delete them: {', '.join(hashes_to_delete)}") delete_strategy_blobs: Set[Path] = set() delete_strategy_refs: Set[Path] = set() delete_strategy_repos: Set[Path] = set() delete_strategy_snapshots: Set[Path] = set() delete_strategy_expected_freed_size = 0 for affected_repo, revisions_to_delete in repos_with_revisions.items(): other_revisions = affected_repo.revisions - revisions_to_delete # If no other revisions, it means all revisions are deleted # -> delete the entire cached repo if len(other_revisions) == 0: delete_strategy_repos.add(affected_repo.repo_path) delete_strategy_expected_freed_size += affected_repo.size_on_disk continue # Some revisions of the repo will be deleted but not all. We need to filter # which blob files will not be linked anymore. for revision_to_delete in revisions_to_delete: # Snapshot dir delete_strategy_snapshots.add(revision_to_delete.snapshot_path) # Refs dir for ref in revision_to_delete.refs: delete_strategy_refs.add(affected_repo.repo_path / "refs" / ref) # Blobs dir for file in revision_to_delete.files: if file.blob_path not in delete_strategy_blobs: is_file_alone = True for revision in other_revisions: for rev_file in revision.files: if file.blob_path == rev_file.blob_path: is_file_alone = False break if not is_file_alone: break # Blob file not referenced by remaining revisions -> delete if is_file_alone: delete_strategy_blobs.add(file.blob_path) delete_strategy_expected_freed_size += file.size_on_disk # Return the strategy instead of executing it. return DeleteCacheStrategy( blobs=frozenset(delete_strategy_blobs), refs=frozenset(delete_strategy_refs), repos=frozenset(delete_strategy_repos), snapshots=frozenset(delete_strategy_snapshots), expected_freed_size=delete_strategy_expected_freed_size, ) def export_as_table(self, *, verbosity: int = 0) -> str: """Generate a table from the [`HFCacheInfo`] object. Pass `verbosity=0` to get a table with a single row per repo, with columns "repo_id", "repo_type", "size_on_disk", "nb_files", "last_accessed", "last_modified", "refs", "local_path". Pass `verbosity=1` to get a table with a row per repo and revision (thus multiple rows can appear for a single repo), with columns "repo_id", "repo_type", "revision", "size_on_disk", "nb_files", "last_modified", "refs", "local_path". Example: ```py >>> from huggingface_hub.utils import scan_cache_dir >>> hf_cache_info = scan_cache_dir() HFCacheInfo(...) 
>>> print(hf_cache_info.export_as_table()) REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH --------------------------------------------------- --------- ------------ -------- ------------- ------------- ---- -------------------------------------------------------------------------------------------------- roberta-base model 2.7M 5 1 day ago 1 week ago main ~/.cache/huggingface/hub/models--roberta-base suno/bark model 8.8K 1 1 week ago 1 week ago main ~/.cache/huggingface/hub/models--suno--bark t5-base model 893.8M 4 4 days ago 7 months ago main ~/.cache/huggingface/hub/models--t5-base t5-large model 3.0G 4 5 weeks ago 5 months ago main ~/.cache/huggingface/hub/models--t5-large >>> print(hf_cache_info.export_as_table(verbosity=1)) REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH --------------------------------------------------- --------- ---------------------------------------- ------------ -------- ------------- ---- ----------------------------------------------------------------------------------------------------------------------------------------------------- roberta-base model e2da8e2f811d1448a5b465c236feacd80ffbac7b 2.7M 5 1 week ago main ~/.cache/huggingface/hub/models--roberta-base/snapshots/e2da8e2f811d1448a5b465c236feacd80ffbac7b suno/bark model 70a8a7d34168586dc5d028fa9666aceade177992 8.8K 1 1 week ago main ~/.cache/huggingface/hub/models--suno--bark/snapshots/70a8a7d34168586dc5d028fa9666aceade177992 t5-base model a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1 893.8M 4 7 months ago main ~/.cache/huggingface/hub/models--t5-base/snapshots/a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1 t5-large model 150ebc2c4b72291e770f58e6057481c8d2ed331a 3.0G 4 5 months ago main ~/.cache/huggingface/hub/models--t5-large/snapshots/150ebc2c4b72291e770f58e6057481c8d2ed331a ``` Args: verbosity (`int`, *optional*): The verbosity level. Defaults to 0. Returns: `str`: The table as a string. """ if verbosity == 0: return tabulate( rows=[ [ repo.repo_id, repo.repo_type, "{:>12}".format(repo.size_on_disk_str), repo.nb_files, repo.last_accessed_str, repo.last_modified_str, ", ".join(sorted(repo.refs)), str(repo.repo_path), ] for repo in sorted(self.repos, key=lambda repo: repo.repo_path) ], headers=[ "REPO ID", "REPO TYPE", "SIZE ON DISK", "NB FILES", "LAST_ACCESSED", "LAST_MODIFIED", "REFS", "LOCAL PATH", ], ) else: return tabulate( rows=[ [ repo.repo_id, repo.repo_type, revision.commit_hash, "{:>12}".format(revision.size_on_disk_str), revision.nb_files, revision.last_modified_str, ", ".join(sorted(revision.refs)), str(revision.snapshot_path), ] for repo in sorted(self.repos, key=lambda repo: repo.repo_path) for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash) ], headers=[ "REPO ID", "REPO TYPE", "REVISION", "SIZE ON DISK", "NB FILES", "LAST_MODIFIED", "REFS", "LOCAL PATH", ], ) def scan_cache_dir(cache_dir: Optional[Union[str, Path]] = None) -> HFCacheInfo: """Scan the entire HF cache-system and return a [`~HFCacheInfo`] structure. Use `scan_cache_dir` in order to programmatically scan your cache-system. The cache will be scanned repo by repo. If a repo is corrupted, a [`~CorruptedCacheException`] will be thrown internally but captured and returned in the [`~HFCacheInfo`] structure. Only valid repos get a proper report. 
```py >>> from huggingface_hub import scan_cache_dir >>> hf_cache_info = scan_cache_dir() HFCacheInfo( size_on_disk=3398085269, repos=frozenset({ CachedRepoInfo( repo_id='t5-small', repo_type='model', repo_path=PosixPath(...), size_on_disk=970726914, nb_files=11, revisions=frozenset({ CachedRevisionInfo( commit_hash='d78aea13fa7ecd06c29e3e46195d6341255065d5', size_on_disk=970726339, snapshot_path=PosixPath(...), files=frozenset({ CachedFileInfo( file_name='config.json', size_on_disk=1197 file_path=PosixPath(...), blob_path=PosixPath(...), ), CachedFileInfo(...), ... }), ), CachedRevisionInfo(...), ... }), ), CachedRepoInfo(...), ... }), warnings=[ CorruptedCacheException("Snapshots dir doesn't exist in cached repo: ..."), CorruptedCacheException(...), ... ], ) ``` You can also print a detailed report directly from the `huggingface-cli` using: ```text > huggingface-cli scan-cache REPO ID REPO TYPE SIZE ON DISK NB FILES REFS LOCAL PATH --------------------------- --------- ------------ -------- ------------------- ------------------------------------------------------------------------- glue dataset 116.3K 15 1.17.0, main, 2.4.0 /Users/lucain/.cache/huggingface/hub/datasets--glue google/fleurs dataset 64.9M 6 main, refs/pr/1 /Users/lucain/.cache/huggingface/hub/datasets--google--fleurs Jean-Baptiste/camembert-ner model 441.0M 7 main /Users/lucain/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner bert-base-cased model 1.9G 13 main /Users/lucain/.cache/huggingface/hub/models--bert-base-cased t5-base model 10.1K 3 main /Users/lucain/.cache/huggingface/hub/models--t5-base t5-small model 970.7M 11 refs/pr/1, main /Users/lucain/.cache/huggingface/hub/models--t5-small Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G. Got 1 warning(s) while scanning. Use -vvv to print details. ``` Args: cache_dir (`str` or `Path`, `optional`): Cache directory to cache. Defaults to the default HF cache directory. Raises: `CacheNotFound` If the cache directory does not exist. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If the cache directory is a file, instead of a directory. Returns: a [`~HFCacheInfo`] object. """ if cache_dir is None: cache_dir = HF_HUB_CACHE cache_dir = Path(cache_dir).expanduser().resolve() if not cache_dir.exists(): raise CacheNotFound( f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable.", cache_dir=cache_dir, ) if cache_dir.is_file(): raise ValueError( f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable." ) repos: Set[CachedRepoInfo] = set() warnings: List[CorruptedCacheException] = [] for repo_path in cache_dir.iterdir(): if repo_path.name == ".locks": # skip './.locks/' folder continue try: repos.add(_scan_cached_repo(repo_path)) except CorruptedCacheException as e: warnings.append(e) return HFCacheInfo( repos=frozenset(repos), size_on_disk=sum(repo.size_on_disk for repo in repos), warnings=warnings, ) def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: """Scan a single cache repo and return information about it. Any unexpected behavior will raise a [`~CorruptedCacheException`]. 
""" if not repo_path.is_dir(): raise CorruptedCacheException(f"Repo path is not a directory: {repo_path}") if "--" not in repo_path.name: raise CorruptedCacheException(f"Repo path is not a valid HuggingFace cache directory: {repo_path}") repo_type, repo_id = repo_path.name.split("--", maxsplit=1) repo_type = repo_type[:-1] # "models" -> "model" repo_id = repo_id.replace("--", "/") # google/fleurs -> "google/fleurs" if repo_type not in {"dataset", "model", "space"}: raise CorruptedCacheException( f"Repo type must be `dataset`, `model` or `space`, found `{repo_type}` ({repo_path})." ) blob_stats: Dict[Path, os.stat_result] = {} # Key is blob_path, value is blob stats snapshots_path = repo_path / "snapshots" refs_path = repo_path / "refs" if not snapshots_path.exists() or not snapshots_path.is_dir(): raise CorruptedCacheException(f"Snapshots dir doesn't exist in cached repo: {snapshots_path}") # Scan over `refs` directory # key is revision hash, value is set of refs refs_by_hash: Dict[str, Set[str]] = defaultdict(set) if refs_path.exists(): # Example of `refs` directory # ── refs # ├── main # └── refs # └── pr # └── 1 if refs_path.is_file(): raise CorruptedCacheException(f"Refs directory cannot be a file: {refs_path}") for ref_path in refs_path.glob("**/*"): # glob("**/*") iterates over all files and directories -> skip directories if ref_path.is_dir() or ref_path.name in FILES_TO_IGNORE: continue ref_name = str(ref_path.relative_to(refs_path)) with ref_path.open() as f: commit_hash = f.read() refs_by_hash[commit_hash].add(ref_name) # Scan snapshots directory cached_revisions: Set[CachedRevisionInfo] = set() for revision_path in snapshots_path.iterdir(): # Ignore OS-created helper files if revision_path.name in FILES_TO_IGNORE: continue if revision_path.is_file(): raise CorruptedCacheException(f"Snapshots folder corrupted. Found a file: {revision_path}") cached_files = set() for file_path in revision_path.glob("**/*"): # glob("**/*") iterates over all files and directories -> skip directories if file_path.is_dir(): continue blob_path = Path(file_path).resolve() if not blob_path.exists(): raise CorruptedCacheException(f"Blob missing (broken symlink): {blob_path}") if blob_path not in blob_stats: blob_stats[blob_path] = blob_path.stat() cached_files.add( CachedFileInfo( file_name=file_path.name, file_path=file_path, size_on_disk=blob_stats[blob_path].st_size, blob_path=blob_path, blob_last_accessed=blob_stats[blob_path].st_atime, blob_last_modified=blob_stats[blob_path].st_mtime, ) ) # Last modified is either the last modified blob file or the revision folder # itself if it is empty if len(cached_files) > 0: revision_last_modified = max(blob_stats[file.blob_path].st_mtime for file in cached_files) else: revision_last_modified = revision_path.stat().st_mtime cached_revisions.add( CachedRevisionInfo( commit_hash=revision_path.name, files=frozenset(cached_files), refs=frozenset(refs_by_hash.pop(revision_path.name, set())), size_on_disk=sum( blob_stats[blob_path].st_size for blob_path in set(file.blob_path for file in cached_files) ), snapshot_path=revision_path, last_modified=revision_last_modified, ) ) # Check that all refs referred to an existing revision if len(refs_by_hash) > 0: raise CorruptedCacheException( f"Reference(s) refer to missing commit hashes: {dict(refs_by_hash)} ({repo_path})." ) # Last modified is either the last modified blob file or the repo folder itself if # no blob files has been found. Same for last accessed. 
if len(blob_stats) > 0: repo_last_accessed = max(stat.st_atime for stat in blob_stats.values()) repo_last_modified = max(stat.st_mtime for stat in blob_stats.values()) else: repo_stats = repo_path.stat() repo_last_accessed = repo_stats.st_atime repo_last_modified = repo_stats.st_mtime # Build and return frozen structure return CachedRepoInfo( nb_files=len(blob_stats), repo_id=repo_id, repo_path=repo_path, repo_type=repo_type, # type: ignore revisions=frozenset(cached_revisions), size_on_disk=sum(stat.st_size for stat in blob_stats.values()), last_accessed=repo_last_accessed, last_modified=repo_last_modified, ) def _format_size(num: int) -> str: """Format size in bytes into a human-readable string. Taken from https://stackoverflow.com/a/1094933 """ num_f = float(num) for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]: if abs(num_f) < 1000.0: return f"{num_f:3.1f}{unit}" num_f /= 1000.0 return f"{num_f:.1f}Y" _TIMESINCE_CHUNKS = ( # Label, divider, max value ("second", 1, 60), ("minute", 60, 60), ("hour", 60 * 60, 24), ("day", 60 * 60 * 24, 6), ("week", 60 * 60 * 24 * 7, 6), ("month", 60 * 60 * 24 * 30, 11), ("year", 60 * 60 * 24 * 365, None), ) def _format_timesince(ts: float) -> str: """Format timestamp in seconds into a human-readable string, relative to now. Vaguely inspired by Django's `timesince` formatter. """ delta = time.time() - ts if delta < 20: return "a few seconds ago" for label, divider, max_value in _TIMESINCE_CHUNKS: # noqa: B007 value = round(delta / divider) if max_value is not None and value <= max_value: break return f"{value} {label}{'s' if value > 1 else ''} ago" def _try_delete_path(path: Path, path_type: str) -> None: """Try to delete a local file or folder. If the path does not exists, error is logged as a warning and then ignored. Args: path (`Path`) Path to delete. Can be a file or a folder. path_type (`str`) What path are we deleting ? Only for logging purposes. Example: "snapshot". """ logger.info(f"Delete {path_type}: {path}") try: if path.is_file(): os.remove(path) else: shutil.rmtree(path) except FileNotFoundError: logger.warning(f"Couldn't delete {path_type}: file not found ({path})", exc_info=True) except PermissionError: logger.warning(f"Couldn't delete {path_type}: permission denied ({path})", exc_info=True) huggingface_hub-0.31.1/src/huggingface_hub/utils/_chunk_utils.py000066400000000000000000000041221500667546600250140ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains a utility to iterate by chunks over an iterator.""" import itertools from typing import Iterable, TypeVar T = TypeVar("T") def chunk_iterable(iterable: Iterable[T], chunk_size: int) -> Iterable[Iterable[T]]: """Iterates over an iterator chunk by chunk. Taken from https://stackoverflow.com/a/8998040. See also https://github.com/huggingface/huggingface_hub/pull/920#discussion_r938793088. Args: iterable (`Iterable`): The iterable on which we want to iterate. chunk_size (`int`): Size of the chunks. 
Must be a strictly positive integer (e.g. >0). Example: ```python >>> from huggingface_hub.utils import chunk_iterable >>> for items in chunk_iterable(range(17), chunk_size=8): ... print(items) # [0, 1, 2, 3, 4, 5, 6, 7] # [8, 9, 10, 11, 12, 13, 14, 15] # [16] # smaller last chunk ``` Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If `chunk_size` <= 0. The last chunk can be smaller than `chunk_size`. """ if not isinstance(chunk_size, int) or chunk_size <= 0: raise ValueError("`chunk_size` must be a strictly positive integer (>0).") iterator = iter(iterable) while True: try: next_item = next(iterator) except StopIteration: return yield itertools.chain((next_item,), itertools.islice(iterator, chunk_size - 1)) huggingface_hub-0.31.1/src/huggingface_hub/utils/_datetime.py000066400000000000000000000053221500667546600242630ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains utilities to handle datetimes in Huggingface Hub.""" from datetime import datetime, timezone def parse_datetime(date_string: str) -> datetime: """ Parses a date_string returned from the server to a datetime object. This parser is a weak parser in the sense that it handles only a single format of date_string. It is expected that the server format will never change. The implementation depends only on the standard lib to avoid an external dependency (python-dateutil). See full discussion about this decision on PR: https://github.com/huggingface/huggingface_hub/pull/999. Example: ```py >>> parse_datetime('2022-08-19T07:19:38.123Z') datetime.datetime(2022, 8, 19, 7, 19, 38, 123000, tzinfo=timezone.utc) ``` Args: date_string (`str`): A string representing a datetime returned by the Hub server. String is expected to follow '%Y-%m-%dT%H:%M:%S.%fZ' pattern. Returns: A python datetime object. Raises: :class:`ValueError`: If `date_string` cannot be parsed. """ try: # Normalize the string to always have 6 digits of fractional seconds if date_string.endswith("Z"): # Case 1: No decimal point (e.g., "2024-11-16T00:27:02Z") if "." not in date_string: # No fractional seconds - insert .000000 date_string = date_string[:-1] + ".000000Z" # Case 2: Has decimal point (e.g., "2022-08-19T07:19:38.123456789Z") else: # Get the fractional and base parts base, fraction = date_string[:-1].split(".") # fraction[:6] takes first 6 digits and :0<6 pads with zeros if less than 6 digits date_string = f"{base}.{fraction[:6]:0<6}Z" return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc) except ValueError as e: raise ValueError( f"Cannot parse '{date_string}' as a datetime. Date string is expected to" " follow '%Y-%m-%dT%H:%M:%S.%fZ' pattern."
) from e huggingface_hub-0.31.1/src/huggingface_hub/utils/_deprecation.py000066400000000000000000000114101500667546600247570ustar00rootroot00000000000000import warnings from functools import wraps from inspect import Parameter, signature from typing import Iterable, Optional def _deprecate_positional_args(*, version: str): """Decorator for methods that issues warnings for positional arguments. Using the keyword-only argument syntax in pep 3102, arguments after the * will issue a warning when passed as a positional argument. Args: version (`str`): The version when positional arguments will result in error. """ def _inner_deprecate_positional_args(f): sig = signature(f) kwonly_args = [] all_args = [] for name, param in sig.parameters.items(): if param.kind == Parameter.POSITIONAL_OR_KEYWORD: all_args.append(name) elif param.kind == Parameter.KEYWORD_ONLY: kwonly_args.append(name) @wraps(f) def inner_f(*args, **kwargs): extra_args = len(args) - len(all_args) if extra_args <= 0: return f(*args, **kwargs) # extra_args > 0 args_msg = [ f"{name}='{arg}'" if isinstance(arg, str) else f"{name}={arg}" for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:]) ] args_msg = ", ".join(args_msg) warnings.warn( f"Deprecated positional argument(s) used in '{f.__name__}': pass" f" {args_msg} as keyword args. From version {version} passing these" " as positional arguments will result in an error,", FutureWarning, ) kwargs.update(zip(sig.parameters, args)) return f(**kwargs) return inner_f return _inner_deprecate_positional_args def _deprecate_arguments( *, version: str, deprecated_args: Iterable[str], custom_message: Optional[str] = None, ): """Decorator to issue warnings when using deprecated arguments. TODO: could be useful to be able to set a custom error message. Args: version (`str`): The version when deprecated arguments will result in error. deprecated_args (`List[str]`): List of the arguments to be deprecated. custom_message (`str`, *optional*): Warning message that is raised. If not passed, a default warning message will be created. """ def _inner_deprecate_positional_args(f): sig = signature(f) @wraps(f) def inner_f(*args, **kwargs): # Check for used deprecated arguments used_deprecated_args = [] for _, parameter in zip(args, sig.parameters.values()): if parameter.name in deprecated_args: used_deprecated_args.append(parameter.name) for kwarg_name, kwarg_value in kwargs.items(): if ( # If argument is deprecated but still used kwarg_name in deprecated_args # And then the value is not the default value and kwarg_value != sig.parameters[kwarg_name].default ): used_deprecated_args.append(kwarg_name) # Warn and proceed if len(used_deprecated_args) > 0: message = ( f"Deprecated argument(s) used in '{f.__name__}':" f" {', '.join(used_deprecated_args)}. Will not be supported from" f" version '{version}'." ) if custom_message is not None: message += "\n\n" + custom_message warnings.warn(message, FutureWarning) return f(*args, **kwargs) return inner_f return _inner_deprecate_positional_args def _deprecate_method(*, version: str, message: Optional[str] = None): """Decorator to issue warnings when using a deprecated method. Args: version (`str`): The version when deprecated arguments will result in error. message (`str`, *optional*): Warning message that is raised. If not passed, a default warning message will be created. 
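    Example (illustrative sketch; `old_function` and the message are hypothetical):

    ```py
    >>> @_deprecate_method(version="1.0", message="Use `new_function` instead.")
    ... def old_function():
    ...     pass

    >>> old_function()  # emits a FutureWarning pointing to `new_function` and version '1.0'
    ```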
""" def _inner_deprecate_method(f): name = f.__name__ if name == "__init__": name = f.__qualname__.split(".")[0] # class name instead of method name @wraps(f) def inner_f(*args, **kwargs): warning_message = ( f"'{name}' (from '{f.__module__}') is deprecated and will be removed from version '{version}'." ) if message is not None: warning_message += " " + message warnings.warn(warning_message, FutureWarning) return f(*args, **kwargs) return inner_f return _inner_deprecate_method huggingface_hub-0.31.1/src/huggingface_hub/utils/_experimental.py000066400000000000000000000045331500667546600251670ustar00rootroot00000000000000# coding=utf-8 # Copyright 2023-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains utilities to flag a feature as "experimental" in Huggingface Hub.""" import warnings from functools import wraps from typing import Callable from .. import constants def experimental(fn: Callable) -> Callable: """Decorator to flag a feature as experimental. An experimental feature trigger a warning when used as it might be subject to breaking changes in the future. Warnings can be disabled by setting the environment variable `HF_EXPERIMENTAL_WARNING` to `0`. Args: fn (`Callable`): The function to flag as experimental. Returns: `Callable`: The decorated function. Example: ```python >>> from huggingface_hub.utils import experimental >>> @experimental ... def my_function(): ... print("Hello world!") >>> my_function() UserWarning: 'my_function' is experimental and might be subject to breaking changes in the future. You can disable this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment variable. Hello world! ``` """ # For classes, put the "experimental" around the "__new__" method => __new__ will be removed in warning message name = fn.__qualname__[: -len(".__new__")] if fn.__qualname__.endswith(".__new__") else fn.__qualname__ @wraps(fn) def _inner_fn(*args, **kwargs): if not constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING: warnings.warn( f"'{name}' is experimental and might be subject to breaking changes in the future." " You can disable this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment" " variable.", UserWarning, ) return fn(*args, **kwargs) return _inner_fn huggingface_hub-0.31.1/src/huggingface_hub/utils/_fixes.py000066400000000000000000000105251500667546600236060ustar00rootroot00000000000000# JSONDecodeError was introduced in requests=2.27 released in 2022. 
# This allows us to support older requests for users # More information: https://github.com/psf/requests/pull/5856 try: from requests import JSONDecodeError # type: ignore # noqa: F401 except ImportError: try: from simplejson import JSONDecodeError # type: ignore # noqa: F401 except ImportError: from json import JSONDecodeError # type: ignore # noqa: F401 import contextlib import os import shutil import stat import tempfile import time from functools import partial from pathlib import Path from typing import Callable, Generator, Optional, Union import yaml from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout from .. import constants from . import logging logger = logging.get_logger(__name__) # Wrap `yaml.dump` to set `allow_unicode=True` by default. # # Example: # ```py # >>> yaml.dump({"emoji": "👀", "some unicode": "日本か"}) # 'emoji: "\\U0001F440"\nsome unicode: "\\u65E5\\u672C\\u304B"\n' # # >>> yaml_dump({"emoji": "👀", "some unicode": "日本か"}) # 'emoji: "👀"\nsome unicode: "日本か"\n' # ``` yaml_dump: Callable[..., str] = partial(yaml.dump, stream=None, allow_unicode=True) # type: ignore @contextlib.contextmanager def SoftTemporaryDirectory( suffix: Optional[str] = None, prefix: Optional[str] = None, dir: Optional[Union[Path, str]] = None, **kwargs, ) -> Generator[Path, None, None]: """ Context manager to create a temporary directory and safely delete it. If tmp directory cannot be deleted normally, we set the WRITE permission and retry. If cleanup still fails, we give up but don't raise an exception. This is equivalent to `tempfile.TemporaryDirectory(..., ignore_cleanup_errors=True)` introduced in Python 3.10. See https://www.scivision.dev/python-tempfile-permission-error-windows/. """ tmpdir = tempfile.TemporaryDirectory(prefix=prefix, suffix=suffix, dir=dir, **kwargs) yield Path(tmpdir.name).resolve() try: # First once with normal cleanup shutil.rmtree(tmpdir.name) except Exception: # If failed, try to set write permission and retry try: shutil.rmtree(tmpdir.name, onerror=_set_write_permission_and_retry) except Exception: pass # And finally, cleanup the tmpdir. # If it fails again, give up but do not throw error try: tmpdir.cleanup() except Exception: pass def _set_write_permission_and_retry(func, path, excinfo): os.chmod(path, stat.S_IWRITE) func(path) @contextlib.contextmanager def WeakFileLock( lock_file: Union[str, Path], *, timeout: Optional[float] = None ) -> Generator[BaseFileLock, None, None]: """A filelock with some custom logic. This filelock is weaker than the default filelock in that: 1. It won't raise an exception if release fails. 2. It will default to a SoftFileLock if the filesystem does not support flock. An INFO log message is emitted every 10 seconds if the lock is not acquired immediately. If a timeout is provided, a `filelock.Timeout` exception is raised if the lock is not acquired within the timeout. """ log_interval = constants.FILELOCK_LOG_EVERY_SECONDS lock = FileLock(lock_file, timeout=log_interval) start_time = time.time() while True: elapsed_time = time.time() - start_time if timeout is not None and elapsed_time >= timeout: raise Timeout(str(lock_file)) try: lock.acquire(timeout=min(log_interval, timeout - elapsed_time) if timeout else log_interval) except Timeout: logger.info( f"Still waiting to acquire lock on {lock_file} (elapsed: {time.time() - start_time:.1f} seconds)" ) except NotImplementedError as e: if "use SoftFileLock instead" in str(e): logger.warning( "FileSystem does not appear to support flock. 
Falling back to SoftFileLock for %s", lock_file ) lock = SoftFileLock(lock_file, timeout=log_interval) continue else: break try: yield lock finally: try: lock.release() except OSError: try: Path(lock_file).unlink() except OSError: pass huggingface_hub-0.31.1/src/huggingface_hub/utils/_git_credential.py000066400000000000000000000107641500667546600254520ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains utilities to manage Git credentials.""" import re import subprocess from typing import List, Optional from ..constants import ENDPOINT from ._subprocess import run_interactive_subprocess, run_subprocess GIT_CREDENTIAL_REGEX = re.compile( r""" ^\s* # start of line credential\.helper # credential.helper value \s*=\s* # separator (\w+) # the helper name (group 1) (\s|$) # whitespace or end of line """, flags=re.MULTILINE | re.IGNORECASE | re.VERBOSE, ) def list_credential_helpers(folder: Optional[str] = None) -> List[str]: """Return the list of git credential helpers configured. See https://git-scm.com/docs/gitcredentials. Credentials are saved in all configured helpers (store, cache, macOS keychain,...). Calls "`git credential approve`" internally. See https://git-scm.com/docs/git-credential. Args: folder (`str`, *optional*): The folder in which to check the configured helpers. """ try: output = run_subprocess("git config --list", folder=folder).stdout parsed = _parse_credential_output(output) return parsed except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) def set_git_credential(token: str, username: str = "hf_user", folder: Optional[str] = None) -> None: """Save a username/token pair in git credential for HF Hub registry. Credentials are saved in all configured helpers (store, cache, macOS keychain,...). Calls "`git credential approve`" internally. See https://git-scm.com/docs/git-credential. Args: username (`str`, defaults to `"hf_user"`): A git username. Defaults to `"hf_user"`, the default user used in the Hub. token (`str`, defaults to `"hf_user"`): A git password. In practice, the User Access Token for the Hub. See https://huggingface.co/settings/tokens. folder (`str`, *optional*): The folder in which to check the configured helpers. """ with run_interactive_subprocess("git credential approve", folder=folder) as ( stdin, _, ): stdin.write(f"url={ENDPOINT}\nusername={username.lower()}\npassword={token}\n\n") stdin.flush() def unset_git_credential(username: str = "hf_user", folder: Optional[str] = None) -> None: """Erase credentials from git credential for HF Hub registry. Credentials are erased from the configured helpers (store, cache, macOS keychain,...), if any. If `username` is not provided, any credential configured for HF Hub endpoint is erased. Calls "`git credential erase`" internally. See https://git-scm.com/docs/git-credential. Args: username (`str`, defaults to `"hf_user"`): A git username. Defaults to `"hf_user"`, the default user used in the Hub. 
folder (`str`, *optional*): The folder in which to check the configured helpers. """ with run_interactive_subprocess("git credential reject", folder=folder) as ( stdin, _, ): standard_input = f"url={ENDPOINT}\n" if username is not None: standard_input += f"username={username.lower()}\n" standard_input += "\n" stdin.write(standard_input) stdin.flush() def _parse_credential_output(output: str) -> List[str]: """Parse the output of `git credential fill` to extract the password. Args: output (`str`): The output of `git credential fill`. """ # NOTE: If user has set an helper for a custom URL, it will not we caught here. # Example: `credential.https://huggingface.co.helper=store` # See: https://github.com/huggingface/huggingface_hub/pull/1138#discussion_r1013324508 return sorted( # Sort for nice printing set( # Might have some duplicates match[0] for match in GIT_CREDENTIAL_REGEX.findall(output) ) ) huggingface_hub-0.31.1/src/huggingface_hub/utils/_headers.py000066400000000000000000000212541500667546600241040ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains utilities to handle headers to send in calls to Huggingface Hub.""" from typing import Dict, Optional, Union from huggingface_hub.errors import LocalTokenNotFoundError from .. import constants from ._auth import get_token from ._deprecation import _deprecate_arguments from ._runtime import ( get_fastai_version, get_fastcore_version, get_hf_hub_version, get_python_version, get_tf_version, get_torch_version, is_fastai_available, is_fastcore_available, is_tf_available, is_torch_available, ) from ._validators import validate_hf_hub_args @_deprecate_arguments( version="1.0", deprecated_args="is_write_action", custom_message="This argument is ignored and we let the server handle the permission error instead (if any).", ) @validate_hf_hub_args def build_hf_headers( *, token: Optional[Union[bool, str]] = None, library_name: Optional[str] = None, library_version: Optional[str] = None, user_agent: Union[Dict, str, None] = None, headers: Optional[Dict[str, str]] = None, is_write_action: bool = False, ) -> Dict[str, str]: """ Build headers dictionary to send in a HF Hub call. By default, authorization token is always provided either from argument (explicit use) or retrieved from the cache (implicit use). To explicitly avoid sending the token to the Hub, set `token=False` or set the `HF_HUB_DISABLE_IMPLICIT_TOKEN` environment variable. In case of an API call that requires write access, an error is thrown if token is `None` or token is an organization token (starting with `"api_org***"`). In addition to the auth header, a user-agent is added to provide information about the installed packages (versions of python, huggingface_hub, torch, tensorflow, fastai and fastcore). 
Args: token (`str`, `bool`, *optional*): The token to be sent in authorization header for the Hub call: - if a string, it is used as the Hugging Face token - if `True`, the token is read from the machine (cache or env variable) - if `False`, authorization header is not set - if `None`, the token is read from the machine only except if `HF_HUB_DISABLE_IMPLICIT_TOKEN` env variable is set. library_name (`str`, *optional*): The name of the library that is making the HTTP request. Will be added to the user-agent header. library_version (`str`, *optional*): The version of the library that is making the HTTP request. Will be added to the user-agent header. user_agent (`str`, `dict`, *optional*): The user agent info in the form of a dictionary or a single string. It will be completed with information about the installed packages. headers (`dict`, *optional*): Additional headers to include in the request. Those headers take precedence over the ones generated by this function. is_write_action (`bool`): Ignored and deprecated argument. Returns: A `Dict` of headers to pass in your API call. Example: ```py >>> build_hf_headers(token="hf_***") # explicit token {"authorization": "Bearer hf_***", "user-agent": ""} >>> build_hf_headers(token=True) # explicitly use cached token {"authorization": "Bearer hf_***",...} >>> build_hf_headers(token=False) # explicitly don't use cached token {"user-agent": ...} >>> build_hf_headers() # implicit use of the cached token {"authorization": "Bearer hf_***",...} # HF_HUB_DISABLE_IMPLICIT_TOKEN=True # to set as env variable >>> build_hf_headers() # token is not sent {"user-agent": ...} >>> build_hf_headers(library_name="transformers", library_version="1.2.3") {"authorization": ..., "user-agent": "transformers/1.2.3; hf_hub/0.10.2; python/3.10.4; tensorflow/1.55"} ``` Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If organization token is passed and "write" access is required. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If "write" access is required but token is not passed and not saved locally. [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) If `token=True` but token is not saved locally. """ # Get auth token to send token_to_send = get_token_to_send(token) # Combine headers hf_headers = { "user-agent": _http_user_agent( library_name=library_name, library_version=library_version, user_agent=user_agent, ) } if token_to_send is not None: hf_headers["authorization"] = f"Bearer {token_to_send}" if headers is not None: hf_headers.update(headers) return hf_headers def get_token_to_send(token: Optional[Union[bool, str]]) -> Optional[str]: """Select the token to send from either `token` or the cache.""" # Case token is explicitly provided if isinstance(token, str): return token # Case token is explicitly forbidden if token is False: return None # Token is not provided: we get it from local cache cached_token = get_token() # Case token is explicitly required if token is True: if cached_token is None: raise LocalTokenNotFoundError( "Token is required (`token=True`), but no token found. You" " need to provide a token or be logged in to Hugging Face with" " `huggingface-cli login` or `huggingface_hub.login`. See" " https://huggingface.co/settings/tokens." 
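                # Only reachable when the caller explicitly passed `token=True`
                # (e.g. `build_hf_headers(token=True)`) on a machine with no cached token.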
) return cached_token # Case implicit use of the token is forbidden by env variable if constants.HF_HUB_DISABLE_IMPLICIT_TOKEN: return None # Otherwise: we use the cached token as the user has not explicitly forbidden it return cached_token def _http_user_agent( *, library_name: Optional[str] = None, library_version: Optional[str] = None, user_agent: Union[Dict, str, None] = None, ) -> str: """Format a user-agent string containing information about the installed packages. Args: library_name (`str`, *optional*): The name of the library that is making the HTTP request. library_version (`str`, *optional*): The version of the library that is making the HTTP request. user_agent (`str`, `dict`, *optional*): The user agent info in the form of a dictionary or a single string. Returns: The formatted user-agent string. """ if library_name is not None: ua = f"{library_name}/{library_version}" else: ua = "unknown/None" ua += f"; hf_hub/{get_hf_hub_version()}" ua += f"; python/{get_python_version()}" if not constants.HF_HUB_DISABLE_TELEMETRY: if is_torch_available(): ua += f"; torch/{get_torch_version()}" if is_tf_available(): ua += f"; tensorflow/{get_tf_version()}" if is_fastai_available(): ua += f"; fastai/{get_fastai_version()}" if is_fastcore_available(): ua += f"; fastcore/{get_fastcore_version()}" if isinstance(user_agent, dict): ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) elif isinstance(user_agent, str): ua += "; " + user_agent # Retrieve user-agent origin headers from environment variable origin = constants.HF_HUB_USER_AGENT_ORIGIN if origin is not None: ua += "; origin/" + origin return _deduplicate_user_agent(ua) def _deduplicate_user_agent(user_agent: str) -> str: """Deduplicate redundant information in the generated user-agent.""" # Split around ";" > Strip whitespaces > Store as dict keys (ensure unicity) > format back as string # Order is implicitly preserved by dictionary structure (see https://stackoverflow.com/a/53657523). return "; ".join({key.strip(): None for key in user_agent.split(";")}.keys()) huggingface_hub-0.31.1/src/huggingface_hub/utils/_hf_folder.py000066400000000000000000000046671500667546600244320ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contain helper class to retrieve/store token from/to local cache.""" from pathlib import Path from typing import Optional from .. import constants from ._auth import get_token class HfFolder: # TODO: deprecate when adapted in transformers/datasets/gradio # @_deprecate_method(version="1.0", message="Use `huggingface_hub.login` instead.") @classmethod def save_token(cls, token: str) -> None: """ Save token, creating folder as needed. Token is saved in the huggingface home folder. You can configure it by setting the `HF_HOME` environment variable. 
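        Example (minimal sketch; `"hf_xxx"` is a placeholder value, and the read-back
        assumes no `HF_TOKEN` environment variable overrides it):

        ```py
        >>> from huggingface_hub import HfFolder
        >>> HfFolder.save_token("hf_xxx")
        >>> HfFolder.get_token()
        'hf_xxx'
        ```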
Args: token (`str`): The token to save to the [`HfFolder`] """ path_token = Path(constants.HF_TOKEN_PATH) path_token.parent.mkdir(parents=True, exist_ok=True) path_token.write_text(token) # TODO: deprecate when adapted in transformers/datasets/gradio # @_deprecate_method(version="1.0", message="Use `huggingface_hub.get_token` instead.") @classmethod def get_token(cls) -> Optional[str]: """ Get token or None if not existent. This method is deprecated in favor of [`huggingface_hub.get_token`] but is kept for backward compatibility. Its behavior is the same as [`huggingface_hub.get_token`]. Returns: `str` or `None`: The token, `None` if it doesn't exist. """ return get_token() # TODO: deprecate when adapted in transformers/datasets/gradio # @_deprecate_method(version="1.0", message="Use `huggingface_hub.logout` instead.") @classmethod def delete_token(cls) -> None: """ Deletes the token from storage. Does not fail if token does not exist. """ try: Path(constants.HF_TOKEN_PATH).unlink() except FileNotFoundError: pass huggingface_hub-0.31.1/src/huggingface_hub/utils/_http.py000066400000000000000000000616731500667546600234610ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains utilities to handle HTTP requests in Huggingface Hub.""" import io import os import re import threading import time import uuid from functools import lru_cache from http import HTTPStatus from shlex import quote from typing import Any, Callable, List, Optional, Tuple, Type, Union import requests from requests import HTTPError, Response from requests.adapters import HTTPAdapter from requests.models import PreparedRequest from huggingface_hub.errors import OfflineModeIsEnabled from .. import constants from ..errors import ( BadRequestError, DisabledRepoError, EntryNotFoundError, GatedRepoError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError, ) from . import logging from ._fixes import JSONDecodeError from ._lfs import SliceFileObj from ._typing import HTTP_METHOD_T logger = logging.get_logger(__name__) # Both headers are used by the Hub to debug failed requests. # `X_AMZN_TRACE_ID` is better as it also works to debug on Cloudfront and ALB. # If `X_AMZN_TRACE_ID` is set, the Hub will use it as well. X_AMZN_TRACE_ID = "X-Amzn-Trace-Id" X_REQUEST_ID = "x-request-id" REPO_API_REGEX = re.compile( r""" # staging or production endpoint ^https://[^/]+ ( # on /api/repo_type/repo_id /api/(models|datasets|spaces)/(.+) | # or /repo_id/resolve/revision/... 
/(.+)/resolve/(.+) ) """, flags=re.VERBOSE, ) class UniqueRequestIdAdapter(HTTPAdapter): X_AMZN_TRACE_ID = "X-Amzn-Trace-Id" def add_headers(self, request, **kwargs): super().add_headers(request, **kwargs) # Add random request ID => easier for server-side debug if X_AMZN_TRACE_ID not in request.headers: request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4()) # Add debug log has_token = len(str(request.headers.get("authorization", ""))) > 0 logger.debug( f"Request {request.headers[X_AMZN_TRACE_ID]}: {request.method} {request.url} (authenticated: {has_token})" ) def send(self, request: PreparedRequest, *args, **kwargs) -> Response: """Catch any RequestException to append request id to the error message for debugging.""" if constants.HF_DEBUG: logger.debug(f"Send: {_curlify(request)}") try: return super().send(request, *args, **kwargs) except requests.RequestException as e: request_id = request.headers.get(X_AMZN_TRACE_ID) if request_id is not None: # Taken from https://stackoverflow.com/a/58270258 e.args = (*e.args, f"(Request ID: {request_id})") raise class OfflineAdapter(HTTPAdapter): def send(self, request: PreparedRequest, *args, **kwargs) -> Response: raise OfflineModeIsEnabled( f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable." ) def _default_backend_factory() -> requests.Session: session = requests.Session() if constants.HF_HUB_OFFLINE: session.mount("http://", OfflineAdapter()) session.mount("https://", OfflineAdapter()) else: session.mount("http://", UniqueRequestIdAdapter()) session.mount("https://", UniqueRequestIdAdapter()) return session BACKEND_FACTORY_T = Callable[[], requests.Session] _GLOBAL_BACKEND_FACTORY: BACKEND_FACTORY_T = _default_backend_factory def configure_http_backend(backend_factory: BACKEND_FACTORY_T = _default_backend_factory) -> None: """ Configure the HTTP backend by providing a `backend_factory`. Any HTTP calls made by `huggingface_hub` will use a Session object instantiated by this factory. This can be useful if you are running your scripts in a specific environment requiring custom configuration (e.g. custom proxy or certifications). Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe, `huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory` set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned. See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`. Example: ```py import requests from huggingface_hub import configure_http_backend, get_session # Create a factory function that returns a Session with configured proxies def backend_factory() -> requests.Session: session = requests.Session() session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"} return session # Set it as the default session factory configure_http_backend(backend_factory=backend_factory) # In practice, this is mostly done internally in `huggingface_hub` session = get_session() ``` """ global _GLOBAL_BACKEND_FACTORY _GLOBAL_BACKEND_FACTORY = backend_factory reset_sessions() def get_session() -> requests.Session: """ Get a `requests.Session` object, using the session factory from the user. Use [`get_session`] to get a configured Session. 
Since `requests.Session` is not guaranteed to be thread-safe, `huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory` set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned. See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`. Example: ```py import requests from huggingface_hub import configure_http_backend, get_session # Create a factory function that returns a Session with configured proxies def backend_factory() -> requests.Session: session = requests.Session() session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"} return session # Set it as the default session factory configure_http_backend(backend_factory=backend_factory) # In practice, this is mostly done internally in `huggingface_hub` session = get_session() ``` """ return _get_session_from_cache(process_id=os.getpid(), thread_id=threading.get_ident()) def reset_sessions() -> None: """Reset the cache of sessions. Mostly used internally when sessions are reconfigured or an SSLError is raised. See [`configure_http_backend`] for more details. """ _get_session_from_cache.cache_clear() @lru_cache def _get_session_from_cache(process_id: int, thread_id: int) -> requests.Session: """ Create a new session per thread using global factory. Using LRU cache (maxsize 128) to avoid memory leaks when using thousands of threads. Cache is cleared when `configure_http_backend` is called. """ return _GLOBAL_BACKEND_FACTORY() def http_backoff( method: HTTP_METHOD_T, url: str, *, max_retries: int = 5, base_wait_time: float = 1, max_wait_time: float = 8, retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ( requests.Timeout, requests.ConnectionError, ), retry_on_status_codes: Union[int, Tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, **kwargs, ) -> Response: """Wrapper around requests to retry calls on an endpoint, with exponential backoff. Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...) and/or on specific status codes (ex: service unavailable). If the call failed more than `max_retries`, the exception is thrown or `raise_for_status` is called on the response object. Re-implement mechanisms from the `backoff` library to avoid adding an external dependencies to `hugging_face_hub`. See https://github.com/litl/backoff. Args: method (`Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]`): HTTP method to perform. url (`str`): The URL of the resource to fetch. max_retries (`int`, *optional*, defaults to `5`): Maximum number of retries, defaults to 5 (no retries). base_wait_time (`float`, *optional*, defaults to `1`): Duration (in seconds) to wait before retrying the first time. Wait time between retries then grows exponentially, capped by `max_wait_time`. max_wait_time (`float`, *optional*, defaults to `8`): Maximum duration (in seconds) to wait before retrying. retry_on_exceptions (`Type[Exception]` or `Tuple[Type[Exception]]`, *optional*): Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types. By default, retry on `requests.Timeout` and `requests.ConnectionError`. retry_on_status_codes (`int` or `Tuple[int]`, *optional*, defaults to `503`): Define on which status codes the request must be retried. By default, only HTTP 503 Service Unavailable is retried. 
**kwargs (`dict`, *optional*): kwargs to pass to `requests.request`. Example: ``` >>> from huggingface_hub.utils import http_backoff # Same usage as "requests.request". >>> response = http_backoff("GET", "https://www.google.com") >>> response.raise_for_status() # If you expect a Gateway Timeout from time to time >>> http_backoff("PUT", upload_url, data=data, retry_on_status_codes=504) >>> response.raise_for_status() ``` When using `requests` it is possible to stream data by passing an iterator to the `data` argument. On http backoff this is a problem as the iterator is not reset after a failed call. This issue is mitigated for file objects or any IO streams by saving the initial position of the cursor (with `data.tell()`) and resetting the cursor between each call (with `data.seek()`). For arbitrary iterators, http backoff will fail. If this is a hard constraint for you, please let us know by opening an issue on [Github](https://github.com/huggingface/huggingface_hub). """ if isinstance(retry_on_exceptions, type): # Tuple from single exception type retry_on_exceptions = (retry_on_exceptions,) if isinstance(retry_on_status_codes, int): # Tuple from single status code retry_on_status_codes = (retry_on_status_codes,) nb_tries = 0 sleep_time = base_wait_time # If `data` is used and is a file object (or any IO), it will be consumed on the # first HTTP request. We need to save the initial position so that the full content # of the file is re-sent on http backoff. See warning tip in docstring. io_obj_initial_pos = None if "data" in kwargs and isinstance(kwargs["data"], (io.IOBase, SliceFileObj)): io_obj_initial_pos = kwargs["data"].tell() session = get_session() while True: nb_tries += 1 try: # If `data` is used and is a file object (or any IO), set back cursor to # initial position. if io_obj_initial_pos is not None: kwargs["data"].seek(io_obj_initial_pos) # Perform request and return if status_code is not in the retry list. response = session.request(method=method, url=url, **kwargs) if response.status_code not in retry_on_status_codes: return response # Wrong status code returned (HTTP 503 for instance) logger.warning(f"HTTP Error {response.status_code} thrown while requesting {method} {url}") if nb_tries > max_retries: response.raise_for_status() # Will raise uncaught exception # We return response to avoid infinite loop in the corner case where the # user ask for retry on a status code that doesn't raise_for_status. return response except retry_on_exceptions as err: logger.warning(f"'{err}' thrown while requesting {method} {url}") if isinstance(err, requests.ConnectionError): reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects if nb_tries > max_retries: raise err # Sleep for X seconds logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].") time.sleep(sleep_time) # Update sleep time for next retry sleep_time = min(max_wait_time, sleep_time * 2) # Exponential backoff def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str: """Replace the default endpoint in a URL by a custom one. This is useful when using a proxy and the Hugging Face Hub returns a URL with the default endpoint. 
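    Example (illustrative; `hf-mirror.example.com` is a made-up proxy endpoint):

    ```py
    >>> fix_hf_endpoint_in_url(
    ...     "https://huggingface.co/gpt2/resolve/main/config.json",
    ...     endpoint="https://hf-mirror.example.com",
    ... )
    'https://hf-mirror.example.com/gpt2/resolve/main/config.json'
    ```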
""" endpoint = endpoint.rstrip("/") if endpoint else constants.ENDPOINT # check if a proxy has been set => if yes, update the returned URL to use the proxy if endpoint not in (constants._HF_DEFAULT_ENDPOINT, constants._HF_DEFAULT_STAGING_ENDPOINT): url = url.replace(constants._HF_DEFAULT_ENDPOINT, endpoint) url = url.replace(constants._HF_DEFAULT_STAGING_ENDPOINT, endpoint) return url def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) -> None: """ Internal version of `response.raise_for_status()` that will refine a potential HTTPError. Raised exception will be an instance of `HfHubHTTPError`. This helper is meant to be the unique method to raise_for_status when making a call to the Hugging Face Hub. Example: ```py import requests from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError response = get_session().post(...) try: hf_raise_for_status(response) except HfHubHTTPError as e: print(str(e)) # formatted message e.request_id, e.server_message # details returned by server # Complete the error message with additional information once it's raised e.append_to_message("\n`create_commit` expects the repository to exist.") raise ``` Args: response (`Response`): Response from the server. endpoint_name (`str`, *optional*): Name of the endpoint that has been called. If provided, the error message will be more complete. Raises when the request has failed: - [`~utils.RepositoryNotFoundError`] If the repository to download from cannot be found. This may be because it doesn't exist, because `repo_type` is not set correctly, or because the repo is `private` and you do not have access. - [`~utils.GatedRepoError`] If the repository exists but is gated and the user is not on the authorized list. - [`~utils.RevisionNotFoundError`] If the repository exists but the revision couldn't be find. - [`~utils.EntryNotFoundError`] If the repository exists but the entry (e.g. the requested file) couldn't be find. - [`~utils.BadRequestError`] If request failed with a HTTP 400 BadRequest error. - [`~utils.HfHubHTTPError`] If request failed for a reason not listed above. """ try: response.raise_for_status() except HTTPError as e: error_code = response.headers.get("X-Error-Code") error_message = response.headers.get("X-Error-Message") if error_code == "RevisionNotFound": message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}." raise _format(RevisionNotFoundError, message, response) from e elif error_code == "EntryNotFound": message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}." raise _format(EntryNotFoundError, message, response) from e elif error_code == "GatedRepo": message = ( f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}." ) raise _format(GatedRepoError, message, response) from e elif error_message == "Access to this resource is disabled.": message = ( f"{response.status_code} Client Error." + "\n\n" + f"Cannot access repository for url {response.url}." + "\n" + "Access to this resource is disabled." 
) raise _format(DisabledRepoError, message, response) from e elif error_code == "RepoNotFound" or ( response.status_code == 401 and error_message != "Invalid credentials in Authorization header" and response.request is not None and response.request.url is not None and REPO_API_REGEX.search(response.request.url) is not None ): # 401 is misleading as it is returned for: # - private and gated repos if user is not authenticated # - missing repos # => for now, we process them as `RepoNotFound` anyway. # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9 message = ( f"{response.status_code} Client Error." + "\n\n" + f"Repository Not Found for url: {response.url}." + "\nPlease make sure you specified the correct `repo_id` and" " `repo_type`.\nIf you are trying to access a private or gated repo," " make sure you are authenticated. For more details, see" " https://huggingface.co/docs/huggingface_hub/authentication" ) raise _format(RepositoryNotFoundError, message, response) from e elif response.status_code == 400: message = ( f"\n\nBad request for {endpoint_name} endpoint:" if endpoint_name is not None else "\n\nBad request:" ) raise _format(BadRequestError, message, response) from e elif response.status_code == 403: message = ( f"\n\n{response.status_code} Forbidden: {error_message}." + f"\nCannot access content at: {response.url}." + "\nMake sure your token has the correct permissions." ) raise _format(HfHubHTTPError, message, response) from e elif response.status_code == 416: range_header = response.request.headers.get("Range") message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}." raise _format(HfHubHTTPError, message, response) from e # Convert `HTTPError` into a `HfHubHTTPError` to display request information # as well (request id and/or server error message) raise _format(HfHubHTTPError, str(e), response) from e def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Response) -> HfHubHTTPError: server_errors = [] # Retrieve server error from header from_headers = response.headers.get("X-Error-Message") if from_headers is not None: server_errors.append(from_headers) # Retrieve server error from body try: # Case errors are returned in a JSON format data = response.json() error = data.get("error") if error is not None: if isinstance(error, list): # Case {'error': ['my error 1', 'my error 2']} server_errors.extend(error) else: # Case {'error': 'my error'} server_errors.append(error) errors = data.get("errors") if errors is not None: # Case {'errors': [{'message': 'my error 1'}, {'message': 'my error 2'}]} for error in errors: if "message" in error: server_errors.append(error["message"]) except JSONDecodeError: # If content is not JSON and not HTML, append the text content_type = response.headers.get("Content-Type", "") if response.text and "html" not in content_type.lower(): server_errors.append(response.text) # Strip all server messages server_errors = [str(line).strip() for line in server_errors if str(line).strip()] # Deduplicate server messages (keep order) # taken from https://stackoverflow.com/a/17016257 server_errors = list(dict.fromkeys(server_errors)) # Format server error server_message = "\n".join(server_errors) # Add server error to custom message final_error_message = custom_message if server_message and server_message.lower() not in custom_message.lower(): if "\n\n" in custom_message: final_error_message += "\n" + server_message else: final_error_message += "\n\n" + server_message # Add 
Request ID request_id = str(response.headers.get(X_REQUEST_ID, "")) if request_id: request_id_message = f" (Request ID: {request_id})" else: # Fallback to X-Amzn-Trace-Id request_id = str(response.headers.get(X_AMZN_TRACE_ID, "")) if request_id: request_id_message = f" (Amzn Trace ID: {request_id})" if request_id and request_id.lower() not in final_error_message.lower(): if "\n" in final_error_message: newline_index = final_error_message.index("\n") final_error_message = ( final_error_message[:newline_index] + request_id_message + final_error_message[newline_index:] ) else: final_error_message += request_id_message # Return return error_type(final_error_message.strip(), response=response, server_message=server_message or None) def _curlify(request: requests.PreparedRequest) -> str: """Convert a `requests.PreparedRequest` into a curl command (str). Used for debug purposes only. Implementation vendored from https://github.com/ofw/curlify/blob/master/curlify.py. MIT License Copyright (c) 2016 Egor. """ parts: List[Tuple[Any, Any]] = [ ("curl", None), ("-X", request.method), ] for k, v in sorted(request.headers.items()): if k.lower() == "authorization": v = "" # Hide authorization header, no matter its value (can be Bearer, Key, etc.) parts += [("-H", "{0}: {1}".format(k, v))] if request.body: body = request.body if isinstance(body, bytes): body = body.decode("utf-8", errors="ignore") elif hasattr(body, "read"): body = "" # Don't try to read it to avoid consuming the stream if len(body) > 1000: body = body[:1000] + " ... [truncated]" parts += [("-d", body.replace("\n", ""))] parts += [(None, request.url)] flat_parts = [] for k, v in parts: if k: flat_parts.append(quote(k)) if v: flat_parts.append(quote(v)) return " ".join(flat_parts) # Regex to parse HTTP Range header RANGE_REGEX = re.compile(r"^\s*bytes\s*=\s*(\d*)\s*-\s*(\d*)\s*$", re.IGNORECASE) def _adjust_range_header(original_range: Optional[str], resume_size: int) -> Optional[str]: """ Adjust HTTP Range header to account for resume position. """ if not original_range: return f"bytes={resume_size}-" if "," in original_range: raise ValueError(f"Multiple ranges detected - {original_range!r}, not supported yet.") match = RANGE_REGEX.match(original_range) if not match: raise RuntimeError(f"Invalid range format - {original_range!r}.") start, end = match.groups() if not start: if not end: raise RuntimeError(f"Invalid range format - {original_range!r}.") new_suffix = int(end) - resume_size new_range = f"bytes=-{new_suffix}" if new_suffix <= 0: raise RuntimeError(f"Empty new range - {new_range!r}.") return new_range start = int(start) new_start = start + resume_size if end: end = int(end) new_range = f"bytes={new_start}-{end}" if new_start > end: raise RuntimeError(f"Empty new range - {new_range!r}.") return new_range return f"bytes={new_start}-" huggingface_hub-0.31.1/src/huggingface_hub/utils/_lfs.py000066400000000000000000000075651500667546600232660ustar00rootroot00000000000000# coding=utf-8 # Copyright 2019-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Git LFS related utilities""" import io import os from contextlib import AbstractContextManager from typing import BinaryIO class SliceFileObj(AbstractContextManager): """ Utility context manager to read a *slice* of a seekable file-like object as a seekable, file-like object. This is NOT thread safe Inspired by stackoverflow.com/a/29838711/593036 Credits to @julien-c Args: fileobj (`BinaryIO`): A file-like object to slice. MUST implement `tell()` and `seek()` (and `read()` of course). `fileobj` will be reset to its original position when exiting the context manager. seek_from (`int`): The start of the slice (offset from position 0 in bytes). read_limit (`int`): The maximum number of bytes to read from the slice. Attributes: previous_position (`int`): The previous position Examples: Reading 200 bytes with an offset of 128 bytes from a file (ie bytes 128 to 327): ```python >>> with open("path/to/file", "rb") as file: ... with SliceFileObj(file, seek_from=128, read_limit=200) as fslice: ... fslice.read(...) ``` Reading a file in chunks of 512 bytes ```python >>> import os >>> chunk_size = 512 >>> file_size = os.getsize("path/to/file") >>> with open("path/to/file", "rb") as file: ... for chunk_idx in range(ceil(file_size / chunk_size)): ... with SliceFileObj(file, seek_from=chunk_idx * chunk_size, read_limit=chunk_size) as fslice: ... chunk = fslice.read(...) ``` """ def __init__(self, fileobj: BinaryIO, seek_from: int, read_limit: int): self.fileobj = fileobj self.seek_from = seek_from self.read_limit = read_limit def __enter__(self): self._previous_position = self.fileobj.tell() end_of_stream = self.fileobj.seek(0, os.SEEK_END) self._len = min(self.read_limit, end_of_stream - self.seek_from) # ^^ The actual number of bytes that can be read from the slice self.fileobj.seek(self.seek_from, io.SEEK_SET) return self def __exit__(self, exc_type, exc_value, traceback): self.fileobj.seek(self._previous_position, io.SEEK_SET) def read(self, n: int = -1): pos = self.tell() if pos >= self._len: return b"" remaining_amount = self._len - pos data = self.fileobj.read(remaining_amount if n < 0 else min(n, remaining_amount)) return data def tell(self) -> int: return self.fileobj.tell() - self.seek_from def seek(self, offset: int, whence: int = os.SEEK_SET) -> int: start = self.seek_from end = start + self._len if whence in (os.SEEK_SET, os.SEEK_END): offset = start + offset if whence == os.SEEK_SET else end + offset offset = max(start, min(offset, end)) whence = os.SEEK_SET elif whence == os.SEEK_CUR: cur_pos = self.fileobj.tell() offset = max(start - cur_pos, min(offset, end - cur_pos)) else: raise ValueError(f"whence value {whence} is not supported") return self.fileobj.seek(offset, whence) - self.seek_from def __iter__(self): yield self.read(n=4 * 1024 * 1024) huggingface_hub-0.31.1/src/huggingface_hub/utils/_pagination.py000066400000000000000000000035621500667546600246240ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Contains utilities to handle pagination on Huggingface Hub.""" from typing import Dict, Iterable, Optional import requests from . import get_session, hf_raise_for_status, http_backoff, logging logger = logging.get_logger(__name__) def paginate(path: str, params: Dict, headers: Dict) -> Iterable: """Fetch a list of models/datasets/spaces and paginate through results. This is using the same "Link" header format as GitHub. See: - https://requests.readthedocs.io/en/latest/api/#requests.Response.links - https://docs.github.com/en/rest/guides/traversing-with-pagination#link-header """ session = get_session() r = session.get(path, params=params, headers=headers) hf_raise_for_status(r) yield from r.json() # Follow pages # Next link already contains query params next_page = _get_next_page(r) while next_page is not None: logger.debug(f"Pagination detected. Requesting next page: {next_page}") r = http_backoff("GET", next_page, max_retries=20, retry_on_status_codes=429, headers=headers) hf_raise_for_status(r) yield from r.json() next_page = _get_next_page(r) def _get_next_page(response: requests.Response) -> Optional[str]: return response.links.get("next", {}).get("url") huggingface_hub-0.31.1/src/huggingface_hub/utils/_paths.py000066400000000000000000000116621500667546600236120ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains utilities to handle paths in Huggingface Hub.""" from fnmatch import fnmatch from pathlib import Path from typing import Callable, Generator, Iterable, List, Optional, TypeVar, Union T = TypeVar("T") # Always ignore `.git` and `.cache/huggingface` folders in commits DEFAULT_IGNORE_PATTERNS = [ ".git", ".git/*", "*/.git", "**/.git/**", ".cache/huggingface", ".cache/huggingface/*", "*/.cache/huggingface", "**/.cache/huggingface/**", ] # Forbidden to commit these folders FORBIDDEN_FOLDERS = [".git", ".cache"] def filter_repo_objects( items: Iterable[T], *, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, key: Optional[Callable[[T], str]] = None, ) -> Generator[T, None, None]: """Filter repo objects based on an allowlist and a denylist. Input must be a list of paths (`str` or `Path`) or a list of arbitrary objects. In the later case, `key` must be provided and specifies a function of one argument that is used to extract a path from each element in iterable. Patterns are Unix shell-style wildcards which are NOT regular expressions. See https://docs.python.org/3/library/fnmatch.html for more details. Args: items (`Iterable`): List of items to filter. allow_patterns (`str` or `List[str]`, *optional*): Patterns constituting the allowlist. If provided, item paths must match at least one pattern from the allowlist. ignore_patterns (`str` or `List[str]`, *optional*): Patterns constituting the denylist. 
If provided, item paths must not match any patterns from the denylist. key (`Callable[[T], str]`, *optional*): Single-argument function to extract a path from each item. If not provided, the `items` must already be `str` or `Path`. Returns: Filtered list of objects, as a generator. Raises: :class:`ValueError`: If `key` is not provided and items are not `str` or `Path`. Example usage with paths: ```python >>> # Filter only PDFs that are not hidden. >>> list(filter_repo_objects( ... ["aaa.PDF", "bbb.jpg", ".ccc.pdf", ".ddd.png"], ... allow_patterns=["*.pdf"], ... ignore_patterns=[".*"], ... )) ["aaa.pdf"] ``` Example usage with objects: ```python >>> list(filter_repo_objects( ... [ ... CommitOperationAdd(path_or_fileobj="/tmp/aaa.pdf", path_in_repo="aaa.pdf") ... CommitOperationAdd(path_or_fileobj="/tmp/bbb.jpg", path_in_repo="bbb.jpg") ... CommitOperationAdd(path_or_fileobj="/tmp/.ccc.pdf", path_in_repo=".ccc.pdf") ... CommitOperationAdd(path_or_fileobj="/tmp/.ddd.png", path_in_repo=".ddd.png") ... ], ... allow_patterns=["*.pdf"], ... ignore_patterns=[".*"], ... key=lambda x: x.repo_in_path ... )) [CommitOperationAdd(path_or_fileobj="/tmp/aaa.pdf", path_in_repo="aaa.pdf")] ``` """ if isinstance(allow_patterns, str): allow_patterns = [allow_patterns] if isinstance(ignore_patterns, str): ignore_patterns = [ignore_patterns] if allow_patterns is not None: allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns] if ignore_patterns is not None: ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns] if key is None: def _identity(item: T) -> str: if isinstance(item, str): return item if isinstance(item, Path): return str(item) raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.") key = _identity # Items must be `str` or `Path`, otherwise raise ValueError for item in items: path = key(item) # Skip if there's an allowlist and path doesn't match any if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns): continue # Skip if there's a denylist and path matches any if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns): continue yield item def _add_wildcard_to_directories(pattern: str) -> str: if pattern[-1] == "/": return pattern + "*" return pattern huggingface_hub-0.31.1/src/huggingface_hub/utils/_runtime.py000066400000000000000000000265401500667546600241570ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Check presence of installed packages at runtime.""" import importlib.metadata import os import platform import sys import warnings from typing import Any, Dict from .. 
import __version__, constants _PY_VERSION: str = sys.version.split()[0].rstrip("+") _package_versions = {} _CANDIDATES = { "aiohttp": {"aiohttp"}, "fastai": {"fastai"}, "fastapi": {"fastapi"}, "fastcore": {"fastcore"}, "gradio": {"gradio"}, "graphviz": {"graphviz"}, "hf_transfer": {"hf_transfer"}, "hf_xet": {"hf_xet"}, "jinja": {"Jinja2"}, "keras": {"keras"}, "numpy": {"numpy"}, "pillow": {"Pillow"}, "pydantic": {"pydantic"}, "pydot": {"pydot"}, "safetensors": {"safetensors"}, "tensorboard": {"tensorboardX"}, "tensorflow": ( "tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos", ), "torch": {"torch"}, } # Check once at runtime for candidate_name, package_names in _CANDIDATES.items(): _package_versions[candidate_name] = "N/A" for name in package_names: try: _package_versions[candidate_name] = importlib.metadata.version(name) break except importlib.metadata.PackageNotFoundError: pass def _get_version(package_name: str) -> str: return _package_versions.get(package_name, "N/A") def is_package_available(package_name: str) -> bool: return _get_version(package_name) != "N/A" # Python def get_python_version() -> str: return _PY_VERSION # Huggingface Hub def get_hf_hub_version() -> str: return __version__ # aiohttp def is_aiohttp_available() -> bool: return is_package_available("aiohttp") def get_aiohttp_version() -> str: return _get_version("aiohttp") # FastAI def is_fastai_available() -> bool: return is_package_available("fastai") def get_fastai_version() -> str: return _get_version("fastai") # FastAPI def is_fastapi_available() -> bool: return is_package_available("fastapi") def get_fastapi_version() -> str: return _get_version("fastapi") # Fastcore def is_fastcore_available() -> bool: return is_package_available("fastcore") def get_fastcore_version() -> str: return _get_version("fastcore") # FastAI def is_gradio_available() -> bool: return is_package_available("gradio") def get_gradio_version() -> str: return _get_version("gradio") # Graphviz def is_graphviz_available() -> bool: return is_package_available("graphviz") def get_graphviz_version() -> str: return _get_version("graphviz") # hf_transfer def is_hf_transfer_available() -> bool: return is_package_available("hf_transfer") def get_hf_transfer_version() -> str: return _get_version("hf_transfer") # xet def is_xet_available() -> bool: # since hf_xet is automatically used if available, allow explicit disabling via environment variable if constants._is_true(os.environ.get("HF_HUB_DISABLE_XET")): # type: ignore return False return is_package_available("hf_xet") def get_xet_version() -> str: return _get_version("hf_xet") # keras def is_keras_available() -> bool: return is_package_available("keras") def get_keras_version() -> str: return _get_version("keras") # Numpy def is_numpy_available() -> bool: return is_package_available("numpy") def get_numpy_version() -> str: return _get_version("numpy") # Jinja def is_jinja_available() -> bool: return is_package_available("jinja") def get_jinja_version() -> str: return _get_version("jinja") # Pillow def is_pillow_available() -> bool: return is_package_available("pillow") def get_pillow_version() -> str: return _get_version("pillow") # Pydantic def is_pydantic_available() -> bool: if not is_package_available("pydantic"): return False # For Pydantic, we add an extra check to test whether it is correctly installed or not. 
If both pydantic 2.x and # typing_extensions<=4.5.0 are installed, then pydantic will fail at import time. This should not happen when # it is installed with `pip install huggingface_hub[inference]` but it can happen when it is installed manually # by the user in an environment that we don't control. # # Usually we won't need to do this kind of check on optional dependencies. However, pydantic is a special case # as it is automatically imported when doing `from huggingface_hub import ...` even if the user doesn't use it. # # See https://github.com/huggingface/huggingface_hub/pull/1829 for more details. try: from pydantic import validator # noqa: F401 except ImportError: # Example: "ImportError: cannot import name 'TypeAliasType' from 'typing_extensions'" warnings.warn( "Pydantic is installed but cannot be imported. Please check your installation. `huggingface_hub` will " "default to not using Pydantic. Error message: '{e}'" ) return False return True def get_pydantic_version() -> str: return _get_version("pydantic") # Pydot def is_pydot_available() -> bool: return is_package_available("pydot") def get_pydot_version() -> str: return _get_version("pydot") # Tensorboard def is_tensorboard_available() -> bool: return is_package_available("tensorboard") def get_tensorboard_version() -> str: return _get_version("tensorboard") # Tensorflow def is_tf_available() -> bool: return is_package_available("tensorflow") def get_tf_version() -> str: return _get_version("tensorflow") # Torch def is_torch_available() -> bool: return is_package_available("torch") def get_torch_version() -> str: return _get_version("torch") # Safetensors def is_safetensors_available() -> bool: return is_package_available("safetensors") # Shell-related helpers try: # Set to `True` if script is running in a Google Colab notebook. # If running in Google Colab, git credential store is set globally which makes the # warning disappear. See https://github.com/huggingface/huggingface_hub/issues/1043 # # Taken from https://stackoverflow.com/a/63519730. _is_google_colab = "google.colab" in str(get_ipython()) # type: ignore # noqa: F821 except NameError: _is_google_colab = False def is_notebook() -> bool: """Return `True` if code is executed in a notebook (Jupyter, Colab, QTconsole). Taken from https://stackoverflow.com/a/39662359. Adapted to make it work with Google colab as well. """ try: shell_class = get_ipython().__class__ # type: ignore # noqa: F821 for parent_class in shell_class.__mro__: # e.g. "is subclass of" if parent_class.__name__ == "ZMQInteractiveShell": return True # Jupyter notebook, Google colab or qtconsole return False except NameError: return False # Probably standard Python interpreter def is_google_colab() -> bool: """Return `True` if code is executed in a Google colab. Taken from https://stackoverflow.com/a/63519730. """ return _is_google_colab def is_colab_enterprise() -> bool: """Return `True` if code is executed in a Google Colab Enterprise environment.""" return os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE" def dump_environment_info() -> Dict[str, Any]: """Dump information about the machine to help debugging issues. 
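The report is printed to stdout (ready to paste into a GitHub issue) and also returned as a dictionary. A minimal sketch of a typical call follows; the printed values are illustrative only and depend on your environment:

```python
>>> from huggingface_hub.utils import dump_environment_info
>>> info = dump_environment_info()

Copy-and-paste the text below in your GitHub issue.

- huggingface_hub version: 0.31.1
- Platform: Linux-5.15.0-x86_64-with-glibc2.31
- Python version: 3.10.12
...
>>> info["Has saved token ?"]
False
```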
Similar helper exist in: - `datasets` (https://github.com/huggingface/datasets/blob/main/src/datasets/commands/env.py) - `diffusers` (https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/env.py) - `transformers` (https://github.com/huggingface/transformers/blob/main/src/transformers/commands/env.py) """ from huggingface_hub import get_token, whoami from huggingface_hub.utils import list_credential_helpers token = get_token() # Generic machine info info: Dict[str, Any] = { "huggingface_hub version": get_hf_hub_version(), "Platform": platform.platform(), "Python version": get_python_version(), } # Interpreter info try: shell_class = get_ipython().__class__ # type: ignore # noqa: F821 info["Running in iPython ?"] = "Yes" info["iPython shell"] = shell_class.__name__ except NameError: info["Running in iPython ?"] = "No" info["Running in notebook ?"] = "Yes" if is_notebook() else "No" info["Running in Google Colab ?"] = "Yes" if is_google_colab() else "No" info["Running in Google Colab Enterprise ?"] = "Yes" if is_colab_enterprise() else "No" # Login info info["Token path ?"] = constants.HF_TOKEN_PATH info["Has saved token ?"] = token is not None if token is not None: try: info["Who am I ?"] = whoami()["name"] except Exception: pass try: info["Configured git credential helpers"] = ", ".join(list_credential_helpers()) except Exception: pass # Installed dependencies info["FastAI"] = get_fastai_version() info["Tensorflow"] = get_tf_version() info["Torch"] = get_torch_version() info["Jinja2"] = get_jinja_version() info["Graphviz"] = get_graphviz_version() info["keras"] = get_keras_version() info["Pydot"] = get_pydot_version() info["Pillow"] = get_pillow_version() info["hf_transfer"] = get_hf_transfer_version() info["gradio"] = get_gradio_version() info["tensorboard"] = get_tensorboard_version() info["numpy"] = get_numpy_version() info["pydantic"] = get_pydantic_version() info["aiohttp"] = get_aiohttp_version() info["hf_xet"] = get_xet_version() # Environment variables info["ENDPOINT"] = constants.ENDPOINT info["HF_HUB_CACHE"] = constants.HF_HUB_CACHE info["HF_ASSETS_CACHE"] = constants.HF_ASSETS_CACHE info["HF_TOKEN_PATH"] = constants.HF_TOKEN_PATH info["HF_STORED_TOKENS_PATH"] = constants.HF_STORED_TOKENS_PATH info["HF_HUB_OFFLINE"] = constants.HF_HUB_OFFLINE info["HF_HUB_DISABLE_TELEMETRY"] = constants.HF_HUB_DISABLE_TELEMETRY info["HF_HUB_DISABLE_PROGRESS_BARS"] = constants.HF_HUB_DISABLE_PROGRESS_BARS info["HF_HUB_DISABLE_SYMLINKS_WARNING"] = constants.HF_HUB_DISABLE_SYMLINKS_WARNING info["HF_HUB_DISABLE_EXPERIMENTAL_WARNING"] = constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING info["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = constants.HF_HUB_DISABLE_IMPLICIT_TOKEN info["HF_HUB_ENABLE_HF_TRANSFER"] = constants.HF_HUB_ENABLE_HF_TRANSFER info["HF_HUB_ETAG_TIMEOUT"] = constants.HF_HUB_ETAG_TIMEOUT info["HF_HUB_DOWNLOAD_TIMEOUT"] = constants.HF_HUB_DOWNLOAD_TIMEOUT print("\nCopy-and-paste the text below in your GitHub issue.\n") print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n") return info huggingface_hub-0.31.1/src/huggingface_hub/utils/_safetensors.py000066400000000000000000000105521500667546600250240ustar00rootroot00000000000000import functools import operator from collections import defaultdict from dataclasses import dataclass, field from typing import Dict, List, Literal, Optional, Tuple FILENAME_T = str TENSOR_NAME_T = str DTYPE_T = Literal["F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL"] @dataclass class TensorInfo: """Information about 
a tensor. For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. Attributes: dtype (`str`): The data type of the tensor ("F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL"). shape (`List[int]`): The shape of the tensor. data_offsets (`Tuple[int, int]`): The offsets of the data in the file as a tuple `[BEGIN, END]`. parameter_count (`int`): The number of parameters in the tensor. """ dtype: DTYPE_T shape: List[int] data_offsets: Tuple[int, int] parameter_count: int = field(init=False) def __post_init__(self) -> None: # Taken from https://stackoverflow.com/a/13840436 try: self.parameter_count = functools.reduce(operator.mul, self.shape) except TypeError: self.parameter_count = 1 # scalar value has no shape @dataclass class SafetensorsFileMetadata: """Metadata for a Safetensors file hosted on the Hub. This class is returned by [`parse_safetensors_file_metadata`]. For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. Attributes: metadata (`Dict`): The metadata contained in the file. tensors (`Dict[str, TensorInfo]`): A map of all tensors. Keys are tensor names and values are information about the corresponding tensor, as a [`TensorInfo`] object. parameter_count (`Dict[str, int]`): A map of the number of parameters per data type. Keys are data types and values are the number of parameters of that data type. """ metadata: Dict[str, str] tensors: Dict[TENSOR_NAME_T, TensorInfo] parameter_count: Dict[DTYPE_T, int] = field(init=False) def __post_init__(self) -> None: parameter_count: Dict[DTYPE_T, int] = defaultdict(int) for tensor in self.tensors.values(): parameter_count[tensor.dtype] += tensor.parameter_count self.parameter_count = dict(parameter_count) @dataclass class SafetensorsRepoMetadata: """Metadata for a Safetensors repo. A repo is considered to be a Safetensors repo if it contains either a 'model.safetensors' weight file (non-shared model) or a 'model.safetensors.index.json' index file (sharded model) at its root. This class is returned by [`get_safetensors_metadata`]. For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. Attributes: metadata (`Dict`, *optional*): The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for sharded models. sharded (`bool`): Whether the repo contains a sharded model or not. weight_map (`Dict[str, str]`): A map of all weights. Keys are tensor names and values are filenames of the files containing the tensors. files_metadata (`Dict[str, SafetensorsFileMetadata]`): A map of all files metadata. Keys are filenames and values are the metadata of the corresponding file, as a [`SafetensorsFileMetadata`] object. parameter_count (`Dict[str, int]`): A map of the number of parameters per data type. Keys are data types and values are the number of parameters of that data type. 
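Example (a minimal sketch; the repo id and the printed counts are illustrative only):

```python
>>> from huggingface_hub import get_safetensors_metadata

>>> metadata = get_safetensors_metadata("bigscience/bloomz-560m")
>>> metadata.sharded
False
>>> metadata.parameter_count  # parameters per dtype, aggregated over the whole repo
{'F16': 559214592}
```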
""" metadata: Optional[Dict] sharded: bool weight_map: Dict[TENSOR_NAME_T, FILENAME_T] # tensor name -> filename files_metadata: Dict[FILENAME_T, SafetensorsFileMetadata] # filename -> metadata parameter_count: Dict[DTYPE_T, int] = field(init=False) def __post_init__(self) -> None: parameter_count: Dict[DTYPE_T, int] = defaultdict(int) for file_metadata in self.files_metadata.values(): for dtype, nb_parameters_ in file_metadata.parameter_count.items(): parameter_count[dtype] += nb_parameters_ self.parameter_count = dict(parameter_count) huggingface_hub-0.31.1/src/huggingface_hub/utils/_subprocess.py000066400000000000000000000110211500667546600246500ustar00rootroot00000000000000# coding=utf-8 # Copyright 2021 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License """Contains utilities to easily handle subprocesses in `huggingface_hub`.""" import os import subprocess import sys from contextlib import contextmanager from io import StringIO from pathlib import Path from typing import IO, Generator, List, Optional, Tuple, Union from .logging import get_logger logger = get_logger(__name__) @contextmanager def capture_output() -> Generator[StringIO, None, None]: """Capture output that is printed to terminal. Taken from https://stackoverflow.com/a/34738440 Example: ```py >>> with capture_output() as output: ... print("hello world") >>> assert output.getvalue() == "hello world\n" ``` """ output = StringIO() previous_output = sys.stdout sys.stdout = output try: yield output finally: sys.stdout = previous_output def run_subprocess( command: Union[str, List[str]], folder: Optional[Union[str, Path]] = None, check=True, **kwargs, ) -> subprocess.CompletedProcess: """ Method to run subprocesses. Calling this will capture the `stderr` and `stdout`, please call `subprocess.run` manually in case you would like for them not to be captured. Args: command (`str` or `List[str]`): The command to execute as a string or list of strings. folder (`str`, *optional*): The folder in which to run the command. Defaults to current working directory (from `os.getcwd()`). check (`bool`, *optional*, defaults to `True`): Setting `check` to `True` will raise a `subprocess.CalledProcessError` when the subprocess has a non-zero exit code. kwargs (`Dict[str]`): Keyword arguments to be passed to the `subprocess.run` underlying command. Returns: `subprocess.CompletedProcess`: The completed process. """ if isinstance(command, str): command = command.split() if isinstance(folder, Path): folder = str(folder) return subprocess.run( command, stderr=subprocess.PIPE, stdout=subprocess.PIPE, check=check, encoding="utf-8", errors="replace", # if not utf-8, replace char by � cwd=folder or os.getcwd(), **kwargs, ) @contextmanager def run_interactive_subprocess( command: Union[str, List[str]], folder: Optional[Union[str, Path]] = None, **kwargs, ) -> Generator[Tuple[IO[str], IO[str]], None, None]: """Run a subprocess in an interactive mode in a context manager. 
Args: command (`str` or `List[str]`): The command to execute as a string or list of strings. folder (`str`, *optional*): The folder in which to run the command. Defaults to current working directory (from `os.getcwd()`). kwargs (`Dict[str]`): Keyword arguments to be passed to the `subprocess.run` underlying command. Returns: `Tuple[IO[str], IO[str]]`: A tuple with `stdin` and `stdout` to interact with the process (input and output are utf-8 encoded). Example: ```python with _interactive_subprocess("git credential-store get") as (stdin, stdout): # Write to stdin stdin.write("url=hf.co\nusername=obama\n".encode("utf-8")) stdin.flush() # Read from stdout output = stdout.read().decode("utf-8") ``` """ if isinstance(command, str): command = command.split() with subprocess.Popen( command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8", errors="replace", # if not utf-8, replace char by � cwd=folder or os.getcwd(), **kwargs, ) as process: assert process.stdin is not None, "subprocess is opened as subprocess.PIPE" assert process.stdout is not None, "subprocess is opened as subprocess.PIPE" yield process.stdin, process.stdout huggingface_hub-0.31.1/src/huggingface_hub/utils/_telemetry.py000066400000000000000000000114321500667546600245000ustar00rootroot00000000000000from queue import Queue from threading import Lock, Thread from typing import Dict, Optional, Union from urllib.parse import quote from .. import constants, logging from . import build_hf_headers, get_session, hf_raise_for_status logger = logging.get_logger(__name__) # Telemetry is sent by a separate thread to avoid blocking the main thread. # A daemon thread is started once and consume tasks from the _TELEMETRY_QUEUE. # If the thread stops for some reason -shouldn't happen-, we restart a new one. _TELEMETRY_THREAD: Optional[Thread] = None _TELEMETRY_THREAD_LOCK = Lock() # Lock to avoid starting multiple threads in parallel _TELEMETRY_QUEUE: Queue = Queue() def send_telemetry( topic: str, *, library_name: Optional[str] = None, library_version: Optional[str] = None, user_agent: Union[Dict, str, None] = None, ) -> None: """ Sends telemetry that helps tracking usage of different HF libraries. This usage data helps us debug issues and prioritize new features. However, we understand that not everyone wants to share additional information, and we respect your privacy. You can disable telemetry collection by setting the `HF_HUB_DISABLE_TELEMETRY=1` as environment variable. Telemetry is also disabled in offline mode (i.e. when setting `HF_HUB_OFFLINE=1`). Telemetry collection is run in a separate thread to minimize impact for the user. Args: topic (`str`): Name of the topic that is monitored. The topic is directly used to build the URL. If you want to monitor subtopics, just use "/" separation. Examples: "gradio", "transformers/examples",... library_name (`str`, *optional*): The name of the library that is making the HTTP request. Will be added to the user-agent header. library_version (`str`, *optional*): The version of the library that is making the HTTP request. Will be added to the user-agent header. user_agent (`str`, `dict`, *optional*): The user agent info in the form of a dictionary or a single string. It will be completed with information about the installed packages. 
Example: ```py >>> from huggingface_hub.utils import send_telemetry # Send telemetry without library information >>> send_telemetry("ping") # Send telemetry to subtopic with library information >>> send_telemetry("gradio/local_link", library_name="gradio", library_version="3.22.1") # Send telemetry with additional data >>> send_telemetry( ... topic="examples", ... library_name="transformers", ... library_version="4.26.0", ... user_agent={"pipeline": "text_classification", "framework": "flax"}, ... ) ``` """ if constants.HF_HUB_OFFLINE or constants.HF_HUB_DISABLE_TELEMETRY: return _start_telemetry_thread() # starts thread only if doesn't exist yet _TELEMETRY_QUEUE.put( {"topic": topic, "library_name": library_name, "library_version": library_version, "user_agent": user_agent} ) def _start_telemetry_thread(): """Start a daemon thread to consume tasks from the telemetry queue. If the thread is interrupted, start a new one. """ with _TELEMETRY_THREAD_LOCK: # avoid to start multiple threads if called concurrently global _TELEMETRY_THREAD if _TELEMETRY_THREAD is None or not _TELEMETRY_THREAD.is_alive(): _TELEMETRY_THREAD = Thread(target=_telemetry_worker, daemon=True) _TELEMETRY_THREAD.start() def _telemetry_worker(): """Wait for a task and consume it.""" while True: kwargs = _TELEMETRY_QUEUE.get() _send_telemetry_in_thread(**kwargs) _TELEMETRY_QUEUE.task_done() def _send_telemetry_in_thread( topic: str, *, library_name: Optional[str] = None, library_version: Optional[str] = None, user_agent: Union[Dict, str, None] = None, ) -> None: """Contains the actual data sending data to the Hub. This function is called directly in gradio's analytics because it is not possible to send telemetry from a daemon thread. See here: https://github.com/gradio-app/gradio/pull/8180 Please do not rename or remove this function. """ path = "/".join(quote(part) for part in topic.split("/") if len(part) > 0) try: r = get_session().head( f"{constants.ENDPOINT}/api/telemetry/{path}", headers=build_hf_headers( token=False, # no need to send a token for telemetry library_name=library_name, library_version=library_version, user_agent=user_agent, ), ) hf_raise_for_status(r) except Exception as e: # We don't want to error in case of connection errors of any kind. logger.debug(f"Error while sending telemetry: {e}") huggingface_hub-0.31.1/src/huggingface_hub/utils/_typing.py000066400000000000000000000055271500667546600240100ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Handle typing imports based on system compatibility.""" import sys from typing import Any, Callable, List, Literal, Type, TypeVar, Union, get_args, get_origin UNION_TYPES: List[Any] = [Union] if sys.version_info >= (3, 10): from types import UnionType UNION_TYPES += [UnionType] HTTP_METHOD_T = Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"] # type hint meaning "function signature not changed by decorator" CallableT = TypeVar("CallableT", bound=Callable) _JSON_SERIALIZABLE_TYPES = (int, float, str, bool, type(None)) def is_jsonable(obj: Any) -> bool: """Check if an object is JSON serializable. This is a weak check, as it does not check for the actual JSON serialization, but only for the types of the object. It works correctly for basic use cases but do not guarantee an exhaustive check. Object is considered to be recursively json serializable if: - it is an instance of int, float, str, bool, or NoneType - it is a list or tuple and all its items are json serializable - it is a dict and all its keys are strings and all its values are json serializable """ try: if isinstance(obj, _JSON_SERIALIZABLE_TYPES): return True if isinstance(obj, (list, tuple)): return all(is_jsonable(item) for item in obj) if isinstance(obj, dict): return all(isinstance(key, _JSON_SERIALIZABLE_TYPES) and is_jsonable(value) for key, value in obj.items()) if hasattr(obj, "__json__"): return True return False except RecursionError: return False def is_simple_optional_type(type_: Type) -> bool: """Check if a type is optional, i.e. Optional[Type] or Union[Type, None] or Type | None, where Type is a non-composite type.""" if get_origin(type_) in UNION_TYPES: union_args = get_args(type_) if len(union_args) == 2 and type(None) in union_args: return True return False def unwrap_simple_optional_type(optional_type: Type) -> Type: """Unwraps a simple optional type, i.e. returns Type from Optional[Type].""" for arg in get_args(optional_type): if arg is not type(None): return arg raise ValueError(f"'{optional_type}' is not an optional type") huggingface_hub-0.31.1/src/huggingface_hub/utils/_validators.py000066400000000000000000000217641500667546600246470ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains utilities to validate argument values in `huggingface_hub`.""" import inspect import re import warnings from functools import wraps from itertools import chain from typing import Any, Dict from huggingface_hub.errors import HFValidationError from ._typing import CallableT REPO_ID_REGEX = re.compile( r""" ^ (\b[\w\-.]+\b/)? # optional namespace (username or organization) \b # starts with a word boundary [\w\-.]{1,96} # repo_name: alphanumeric + . _ - \b # ends with a word boundary $ """, flags=re.VERBOSE, ) def validate_hf_hub_args(fn: CallableT) -> CallableT: """Validate values received as argument for any public method of `huggingface_hub`. 
The goal of this decorator is to harmonize validation of arguments reused everywhere. By default, all defined validators are tested. Validators: - [`~utils.validate_repo_id`]: `repo_id` must be `"repo_name"` or `"namespace/repo_name"`. Namespace is a username or an organization. - [`~utils.smoothly_deprecate_use_auth_token`]: Use `token` instead of `use_auth_token` (only if `use_auth_token` is not expected by the decorated function - in practice, always the case in `huggingface_hub`). Example: ```py >>> from huggingface_hub.utils import validate_hf_hub_args >>> @validate_hf_hub_args ... def my_cool_method(repo_id: str): ... print(repo_id) >>> my_cool_method(repo_id="valid_repo_id") valid_repo_id >>> my_cool_method("other..repo..id") huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'. >>> my_cool_method(repo_id="other..repo..id") huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'. >>> @validate_hf_hub_args ... def my_cool_auth_method(token: str): ... print(token) >>> my_cool_auth_method(token="a token") "a token" >>> my_cool_auth_method(use_auth_token="a use_auth_token") "a use_auth_token" >>> my_cool_auth_method(token="a token", use_auth_token="a use_auth_token") UserWarning: Both `token` and `use_auth_token` are passed (...) "a token" ``` Raises: [`~utils.HFValidationError`]: If an input is not valid. """ # TODO: add an argument to opt-out validation for specific argument? signature = inspect.signature(fn) # Should the validator switch `use_auth_token` values to `token`? In practice, always # True in `huggingface_hub`. Might not be the case in a downstream library. check_use_auth_token = "use_auth_token" not in signature.parameters and "token" in signature.parameters @wraps(fn) def _inner_fn(*args, **kwargs): has_token = False for arg_name, arg_value in chain( zip(signature.parameters, args), # Args values kwargs.items(), # Kwargs values ): if arg_name in ["repo_id", "from_id", "to_id"]: validate_repo_id(arg_value) elif arg_name == "token" and arg_value is not None: has_token = True if check_use_auth_token: kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs) return fn(*args, **kwargs) return _inner_fn # type: ignore def validate_repo_id(repo_id: str) -> None: """Validate `repo_id` is valid. This is not meant to replace the proper validation made on the Hub but rather to avoid local inconsistencies whenever possible (example: passing `repo_type` in the `repo_id` is forbidden). Rules: - Between 1 and 96 characters. - Either "repo_name" or "namespace/repo_name" - [a-zA-Z0-9] or "-", "_", "." - "--" and ".." are forbidden Valid: `"foo"`, `"foo/bar"`, `"123"`, `"Foo-BAR_foo.bar123"` Not valid: `"datasets/foo/bar"`, `".repo_id"`, `"foo--bar"`, `"foo.git"` Example: ```py >>> from huggingface_hub.utils import validate_repo_id >>> validate_repo_id(repo_id="valid_repo_id") >>> validate_repo_id(repo_id="other..repo..id") huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'. ``` Discussed in https://github.com/huggingface/huggingface_hub/issues/1008. 
In moon-landing (internal repository): - https://github.com/huggingface/moon-landing/blob/main/server/lib/Names.ts#L27 - https://github.com/huggingface/moon-landing/blob/main/server/views/components/NewRepoForm/NewRepoForm.svelte#L138 """ if not isinstance(repo_id, str): # Typically, a Path is not a repo_id raise HFValidationError(f"Repo id must be a string, not {type(repo_id)}: '{repo_id}'.") if repo_id.count("/") > 1: raise HFValidationError( "Repo id must be in the form 'repo_name' or 'namespace/repo_name':" f" '{repo_id}'. Use `repo_type` argument if needed." ) if not REPO_ID_REGEX.match(repo_id): raise HFValidationError( "Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are" " forbidden, '-' and '.' cannot start or end the name, max length is 96:" f" '{repo_id}'." ) if "--" in repo_id or ".." in repo_id: raise HFValidationError(f"Cannot have -- or .. in repo_id: '{repo_id}'.") if repo_id.endswith(".git"): raise HFValidationError(f"Repo_id cannot end by '.git': '{repo_id}'.") def smoothly_deprecate_use_auth_token(fn_name: str, has_token: bool, kwargs: Dict[str, Any]) -> Dict[str, Any]: """Smoothly deprecate `use_auth_token` in the `huggingface_hub` codebase. The long-term goal is to remove any mention of `use_auth_token` in the codebase in favor of a unique and less verbose `token` argument. This will be done a few steps: 0. Step 0: methods that require a read-access to the Hub use the `use_auth_token` argument (`str`, `bool` or `None`). Methods requiring write-access have a `token` argument (`str`, `None`). This implicit rule exists to be able to not send the token when not necessary (`use_auth_token=False`) even if logged in. 1. Step 1: we want to harmonize everything and use `token` everywhere (supporting `token=False` for read-only methods). In order not to break existing code, if `use_auth_token` is passed to a function, the `use_auth_token` value is passed as `token` instead, without any warning. a. Corner case: if both `use_auth_token` and `token` values are passed, a warning is thrown and the `use_auth_token` value is ignored. 2. Step 2: Once it is release, we should push downstream libraries to switch from `use_auth_token` to `token` as much as possible, but without throwing a warning (e.g. manually create issues on the corresponding repos). 3. Step 3: After a transitional period (6 months e.g. until April 2023?), we update `huggingface_hub` to throw a warning on `use_auth_token`. Hopefully, very few users will be impacted as it would have already been fixed. In addition, unit tests in `huggingface_hub` must be adapted to expect warnings to be thrown (but still use `use_auth_token` as before). 4. Step 4: After a normal deprecation cycle (3 releases ?), remove this validator. `use_auth_token` will definitely not be supported. In addition, we update unit tests in `huggingface_hub` to use `token` everywhere. This has been discussed in: - https://github.com/huggingface/huggingface_hub/issues/1094. - https://github.com/huggingface/huggingface_hub/pull/928 - (related) https://github.com/huggingface/huggingface_hub/pull/1064 """ new_kwargs = kwargs.copy() # do not mutate input ! use_auth_token = new_kwargs.pop("use_auth_token", None) # remove from kwargs if use_auth_token is not None: if has_token: warnings.warn( "Both `token` and `use_auth_token` are passed to" f" `{fn_name}` with non-None values. `token` is now the" " preferred argument to pass a User Access Token." " `use_auth_token` value will be ignored." 
) else: # `token` argument is not passed and a non-None value is passed in # `use_auth_token` => use `use_auth_token` value as `token` kwarg. new_kwargs["token"] = use_auth_token return new_kwargs huggingface_hub-0.31.1/src/huggingface_hub/utils/_xet.py000066400000000000000000000155541500667546600232770ustar00rootroot00000000000000from dataclasses import dataclass from enum import Enum from typing import Dict, Optional import requests from .. import constants from . import get_session, hf_raise_for_status, validate_hf_hub_args class XetTokenType(str, Enum): READ = "read" WRITE = "write" @dataclass(frozen=True) class XetFileData: file_hash: str refresh_route: str @dataclass(frozen=True) class XetConnectionInfo: access_token: str expiration_unix_epoch: int endpoint: str def parse_xet_file_data_from_response(response: requests.Response) -> Optional[XetFileData]: """ Parse XET file metadata from an HTTP response. This function extracts XET file metadata from the HTTP headers or HTTP links of a given response object. If the required metadata is not found, it returns `None`. Args: response (`requests.Response`): The HTTP response object containing headers dict and links dict to extract the XET metadata from. Returns: `Optional[XetFileData]`: An instance of `XetFileData` containing the file hash and refresh route if the metadata is found. Returns `None` if the required metadata is missing. """ if response is None: return None try: file_hash = response.headers[constants.HUGGINGFACE_HEADER_X_XET_HASH] if constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY in response.links: refresh_route = response.links[constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY]["url"] else: refresh_route = response.headers[constants.HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE] except KeyError: return None return XetFileData( file_hash=file_hash, refresh_route=refresh_route, ) def parse_xet_connection_info_from_headers(headers: Dict[str, str]) -> Optional[XetConnectionInfo]: """ Parse XET connection info from the HTTP headers or return None if not found. Args: headers (`Dict`): HTTP headers to extract the XET metadata from. Returns: `XetConnectionInfo` or `None`: The information needed to connect to the XET storage service. Returns `None` if the headers do not contain the XET connection info. """ try: endpoint = headers[constants.HUGGINGFACE_HEADER_X_XET_ENDPOINT] access_token = headers[constants.HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN] expiration_unix_epoch = int(headers[constants.HUGGINGFACE_HEADER_X_XET_EXPIRATION]) except (KeyError, ValueError, TypeError): return None return XetConnectionInfo( endpoint=endpoint, access_token=access_token, expiration_unix_epoch=expiration_unix_epoch, ) @validate_hf_hub_args def refresh_xet_connection_info( *, file_data: XetFileData, headers: Dict[str, str], ) -> XetConnectionInfo: """ Utilizes the information in the parsed metadata to request the Hub xet connection information. This includes the access token, expiration, and XET service URL. Args: file_data: (`XetFileData`): The file data needed to refresh the xet connection information. headers (`Dict[str, str]`): Headers to use for the request, including authorization headers and user agent. Returns: `XetConnectionInfo`: The connection information needed to make the request to the xet storage service. Raises: [`~utils.HfHubHTTPError`] If the Hub API returned an error. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If the Hub API response is improperly formatted. 
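Example (a minimal sketch; the hash, refresh route and header values are hypothetical placeholders, not real credentials):

```python
>>> from huggingface_hub.utils._xet import XetFileData, refresh_xet_connection_info

>>> file_data = XetFileData(
...     file_hash="0123abcd...",  # hypothetical xet file hash
...     refresh_route="https://huggingface.co/api/models/user/repo/xet-read-token/main",  # hypothetical route
... )
>>> info = refresh_xet_connection_info(file_data=file_data, headers={"user-agent": "my-app/0.1"})
>>> info.access_token, info.expiration_unix_epoch  # short-lived credentials for the xet storage service
```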
""" if file_data.refresh_route is None: raise ValueError("The provided xet metadata does not contain a refresh endpoint.") return _fetch_xet_connection_info_with_url(file_data.refresh_route, headers) @validate_hf_hub_args def fetch_xet_connection_info_from_repo_info( *, token_type: XetTokenType, repo_id: str, repo_type: str, revision: Optional[str] = None, headers: Dict[str, str], endpoint: Optional[str] = None, params: Optional[Dict[str, str]] = None, ) -> XetConnectionInfo: """ Uses the repo info to request a xet access token from Hub. Args: token_type (`XetTokenType`): Type of the token to request: `"read"` or `"write"`. repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. repo_type (`str`): Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. revision (`str`, `optional`): The revision of the repo to get the token for. headers (`Dict[str, str]`): Headers to use for the request, including authorization headers and user agent. endpoint (`str`, `optional`): The endpoint to use for the request. Defaults to the Hub endpoint. params (`Dict[str, str]`, `optional`): Additional parameters to pass with the request. Returns: `XetConnectionInfo`: The connection information needed to make the request to the xet storage service. Raises: [`~utils.HfHubHTTPError`] If the Hub API returned an error. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If the Hub API response is improperly formatted. """ endpoint = endpoint if endpoint is not None else constants.ENDPOINT url = f"{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type.value}-token/{revision}" return _fetch_xet_connection_info_with_url(url, headers, params) @validate_hf_hub_args def _fetch_xet_connection_info_with_url( url: str, headers: Dict[str, str], params: Optional[Dict[str, str]] = None, ) -> XetConnectionInfo: """ Requests the xet connection info from the supplied URL. This includes the access token, expiration time, and endpoint to use for the xet storage service. Args: url: (`str`): The access token endpoint URL. headers (`Dict[str, str]`): Headers to use for the request, including authorization headers and user agent. params (`Dict[str, str]`, `optional`): Additional parameters to pass with the request. Returns: `XetConnectionInfo`: The connection information needed to make the request to the xet storage service. Raises: [`~utils.HfHubHTTPError`] If the Hub API returned an error. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If the Hub API response is improperly formatted. """ resp = get_session().get(headers=headers, url=url, params=params) hf_raise_for_status(resp) metadata = parse_xet_connection_info_from_headers(resp.headers) # type: ignore if metadata is None: raise ValueError("Xet headers have not been correctly set by the server.") return metadata huggingface_hub-0.31.1/src/huggingface_hub/utils/endpoint_helpers.py000066400000000000000000000044761500667546600257030ustar00rootroot00000000000000# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """ Helpful utility functions and classes in relation to exploring API endpoints with the aim for a user-friendly interface. """ import math import re from typing import TYPE_CHECKING from ..repocard_data import ModelCardData if TYPE_CHECKING: from ..hf_api import ModelInfo def _is_emission_within_threshold(model_info: "ModelInfo", minimum_threshold: float, maximum_threshold: float) -> bool: """Checks if a model's emission is within a given threshold. Args: model_info (`ModelInfo`): A model info object containing the model's emission information. minimum_threshold (`float`): A minimum carbon threshold to filter by, such as 1. maximum_threshold (`float`): A maximum carbon threshold to filter by, such as 10. Returns: `bool`: Whether the model's emission is within the given threshold. """ if minimum_threshold is None and maximum_threshold is None: raise ValueError("Both `minimum_threshold` and `maximum_threshold` cannot both be `None`") if minimum_threshold is None: minimum_threshold = -1 if maximum_threshold is None: maximum_threshold = math.inf card_data = getattr(model_info, "card_data", None) if card_data is None or not isinstance(card_data, (dict, ModelCardData)): return False # Get CO2 emission metadata emission = card_data.get("co2_eq_emissions", None) if isinstance(emission, dict): emission = emission["emissions"] if not emission: return False # Filter out if value is missing or out of range matched = re.search(r"\d+\.\d+|\d+", str(emission)) if matched is None: return False emission_value = float(matched.group(0)) return minimum_threshold <= emission_value <= maximum_threshold huggingface_hub-0.31.1/src/huggingface_hub/utils/insecure_hashlib.py000066400000000000000000000020421500667546600256330ustar00rootroot00000000000000# Taken from https://github.com/mlflow/mlflow/pull/10119 # # DO NOT use this function for security purposes (e.g., password hashing). # # In Python >= 3.9, insecure hashing algorithms such as MD5 fail in FIPS-compliant # environments unless `usedforsecurity=False` is explicitly passed. # # References: # - https://github.com/mlflow/mlflow/issues/9905 # - https://github.com/mlflow/mlflow/pull/10119 # - https://docs.python.org/3/library/hashlib.html # - https://github.com/huggingface/transformers/pull/27038 # # Usage: # ```python # # Use # from huggingface_hub.utils.insecure_hashlib import sha256 # # instead of # from hashlib import sha256 # # # Use # from huggingface_hub.utils import insecure_hashlib # # instead of # import hashlib # ``` import functools import hashlib import sys _kwargs = {"usedforsecurity": False} if sys.version_info >= (3, 9) else {} md5 = functools.partial(hashlib.md5, **_kwargs) sha1 = functools.partial(hashlib.sha1, **_kwargs) sha256 = functools.partial(hashlib.sha256, **_kwargs) huggingface_hub-0.31.1/src/huggingface_hub/utils/logging.py000066400000000000000000000114551500667546600237620ustar00rootroot00000000000000# coding=utf-8 # Copyright 2020 Optuna, Hugging Face # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Logging utilities.""" import logging import os from logging import ( CRITICAL, # NOQA DEBUG, # NOQA ERROR, # NOQA FATAL, # NOQA INFO, # NOQA NOTSET, # NOQA WARN, # NOQA WARNING, # NOQA ) from typing import Optional from .. import constants log_levels = { "debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING, "error": logging.ERROR, "critical": logging.CRITICAL, } _default_log_level = logging.WARNING def _get_library_name() -> str: return __name__.split(".")[0] def _get_library_root_logger() -> logging.Logger: return logging.getLogger(_get_library_name()) def _get_default_logging_level(): """ If `HF_HUB_VERBOSITY` env var is set to one of the valid choices return that as the new default level. If it is not - fall back to `_default_log_level` """ env_level_str = os.getenv("HF_HUB_VERBOSITY", None) if env_level_str: if env_level_str in log_levels: return log_levels[env_level_str] else: logging.getLogger().warning( f"Unknown option HF_HUB_VERBOSITY={env_level_str}, has to be one of: {', '.join(log_levels.keys())}" ) return _default_log_level def _configure_library_root_logger() -> None: library_root_logger = _get_library_root_logger() library_root_logger.addHandler(logging.StreamHandler()) library_root_logger.setLevel(_get_default_logging_level()) def _reset_library_root_logger() -> None: library_root_logger = _get_library_root_logger() library_root_logger.setLevel(logging.NOTSET) def get_logger(name: Optional[str] = None) -> logging.Logger: """ Returns a logger with the specified name. This function is not supposed to be directly accessed by library users. Args: name (`str`, *optional*): The name of the logger to get, usually the filename Example: ```python >>> from huggingface_hub import get_logger >>> logger = get_logger(__file__) >>> logger.set_verbosity_info() ``` """ if name is None: name = _get_library_name() return logging.getLogger(name) def get_verbosity() -> int: """Return the current level for the HuggingFace Hub's root logger. Returns: Logging level, e.g., `huggingface_hub.logging.DEBUG` and `huggingface_hub.logging.INFO`. HuggingFace Hub has following logging levels: - `huggingface_hub.logging.CRITICAL`, `huggingface_hub.logging.FATAL` - `huggingface_hub.logging.ERROR` - `huggingface_hub.logging.WARNING`, `huggingface_hub.logging.WARN` - `huggingface_hub.logging.INFO` - `huggingface_hub.logging.DEBUG` """ return _get_library_root_logger().getEffectiveLevel() def set_verbosity(verbosity: int) -> None: """ Sets the level for the HuggingFace Hub's root logger. Args: verbosity (`int`): Logging level, e.g., `huggingface_hub.logging.DEBUG` and `huggingface_hub.logging.INFO`. """ _get_library_root_logger().setLevel(verbosity) def set_verbosity_info(): """ Sets the verbosity to `logging.INFO`. """ return set_verbosity(INFO) def set_verbosity_warning(): """ Sets the verbosity to `logging.WARNING`. """ return set_verbosity(WARNING) def set_verbosity_debug(): """ Sets the verbosity to `logging.DEBUG`. """ return set_verbosity(DEBUG) def set_verbosity_error(): """ Sets the verbosity to `logging.ERROR`. """ return set_verbosity(ERROR) def disable_propagation() -> None: """ Disable propagation of the library log outputs. Note that log propagation is disabled by default. """ _get_library_root_logger().propagate = False def enable_propagation() -> None: """ Enable propagation of the library log outputs. 
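For instance, to have `huggingface_hub` records handled by your own root-logger configuration (a sketch; the handler setup shown is an assumption about your application, not part of this library):

```python
import logging

from huggingface_hub import logging as hf_logging

logging.basicConfig(level=logging.INFO)  # application-level logging setup
hf_logging.enable_propagation()  # huggingface_hub records now bubble up to the root logger
```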
Please disable the HuggingFace Hub's default handler to prevent double logging if the root logger has been configured. """ _get_library_root_logger().propagate = True _configure_library_root_logger() if constants.HF_DEBUG: # If `HF_DEBUG` environment variable is set, set the verbosity of `huggingface_hub` logger to `DEBUG`. set_verbosity_debug() huggingface_hub-0.31.1/src/huggingface_hub/utils/sha.py000066400000000000000000000041261500667546600231040ustar00rootroot00000000000000"""Utilities to efficiently compute the SHA 256 hash of a bunch of bytes.""" from typing import BinaryIO, Optional from .insecure_hashlib import sha1, sha256 def sha_fileobj(fileobj: BinaryIO, chunk_size: Optional[int] = None) -> bytes: """ Computes the sha256 hash of the given file object, by chunks of size `chunk_size`. Args: fileobj (file-like object): The File object to compute sha256 for, typically obtained with `open(path, "rb")` chunk_size (`int`, *optional*): The number of bytes to read from `fileobj` at once, defaults to 1MB. Returns: `bytes`: `fileobj`'s sha256 hash as bytes """ chunk_size = chunk_size if chunk_size is not None else 1024 * 1024 sha = sha256() while True: chunk = fileobj.read(chunk_size) sha.update(chunk) if not chunk: break return sha.digest() def git_hash(data: bytes) -> str: """ Computes the git-sha1 hash of the given bytes, using the same algorithm as git. This is equivalent to running `git hash-object`. See https://git-scm.com/docs/git-hash-object for more details. Note: this method is valid for regular files. For LFS files, the proper git hash is supposed to be computed on the pointer file content, not the actual file content. However, for simplicity, we directly compare the sha256 of the LFS file content when we want to compare LFS files. Args: data (`bytes`): The data to compute the git-hash for. Returns: `str`: the git-hash of `data` as an hexadecimal string. Example: ```python >>> from huggingface_hub.utils.sha import git_hash >>> git_hash(b"Hello, World!") 'b45ef6fec89518d314f546fd6c3025367b721684' ``` """ # Taken from https://gist.github.com/msabramo/763200 # Note: no need to optimize by reading the file in chunks as we're not supposed to hash huge files (5MB maximum). sha = sha1() sha.update(b"blob ") sha.update(str(len(data)).encode()) sha.update(b"\0") sha.update(data) return sha.hexdigest() huggingface_hub-0.31.1/src/huggingface_hub/utils/tqdm.py000066400000000000000000000246571500667546600233110ustar00rootroot00000000000000# coding=utf-8 # Copyright 2021 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License """Utility helpers to handle progress bars in `huggingface_hub`. Example: 1. Use `huggingface_hub.utils.tqdm` as you would use `tqdm.tqdm` or `tqdm.auto.tqdm`. 2. To disable progress bars, either use `disable_progress_bars()` helper or set the environment variable `HF_HUB_DISABLE_PROGRESS_BARS` to 1. 3. To re-enable progress bars, use `enable_progress_bars()`. 4. To check whether progress bars are disabled, use `are_progress_bars_disabled()`. 
NOTE: Environment variable `HF_HUB_DISABLE_PROGRESS_BARS` has the priority. Example: ```py >>> from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars, tqdm # Disable progress bars globally >>> disable_progress_bars() # Use as normal `tqdm` >>> for _ in tqdm(range(5)): ... pass # Still not showing progress bars, as `disable=False` is overwritten to `True`. >>> for _ in tqdm(range(5), disable=False): ... pass >>> are_progress_bars_disabled() True # Re-enable progress bars globally >>> enable_progress_bars() # Progress bar will be shown ! >>> for _ in tqdm(range(5)): ... pass 100%|███████████████████████████████████████| 5/5 [00:00<00:00, 117817.53it/s] ``` Group-based control: ```python # Disable progress bars for a specific group >>> disable_progress_bars("peft.foo") # Check state of different groups >>> assert not are_progress_bars_disabled("peft")) >>> assert not are_progress_bars_disabled("peft.something") >>> assert are_progress_bars_disabled("peft.foo")) >>> assert are_progress_bars_disabled("peft.foo.bar")) # Enable progress bars for a subgroup >>> enable_progress_bars("peft.foo.bar") # Check if enabling a subgroup affects the parent group >>> assert are_progress_bars_disabled("peft.foo")) >>> assert not are_progress_bars_disabled("peft.foo.bar")) # No progress bar for `name="peft.foo"` >>> for _ in tqdm(range(5), name="peft.foo"): ... pass # Progress bar will be shown for `name="peft.foo.bar"` >>> for _ in tqdm(range(5), name="peft.foo.bar"): ... pass 100%|███████████████████████████████████████| 5/5 [00:00<00:00, 117817.53it/s] ``` """ import io import logging import os import warnings from contextlib import contextmanager, nullcontext from pathlib import Path from typing import ContextManager, Dict, Iterator, Optional, Union from tqdm.auto import tqdm as old_tqdm from ..constants import HF_HUB_DISABLE_PROGRESS_BARS # The `HF_HUB_DISABLE_PROGRESS_BARS` environment variable can be True, False, or not set (None), # allowing for control over progress bar visibility. When set, this variable takes precedence # over programmatic settings, dictating whether progress bars should be shown or hidden globally. # Essentially, the environment variable's setting overrides any code-based configurations. # # If `HF_HUB_DISABLE_PROGRESS_BARS` is not defined (None), it implies that users can manage # progress bar visibility through code. By default, progress bars are turned on. progress_bar_states: Dict[str, bool] = {} def disable_progress_bars(name: Optional[str] = None) -> None: """ Disable progress bars either globally or for a specified group. This function updates the state of progress bars based on a group name. If no group name is provided, all progress bars are disabled. The operation respects the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable's setting. Args: name (`str`, *optional*): The name of the group for which to disable the progress bars. If None, progress bars are disabled globally. Raises: Warning: If the environment variable precludes changes. """ if HF_HUB_DISABLE_PROGRESS_BARS is False: warnings.warn( "Cannot disable progress bars: environment variable `HF_HUB_DISABLE_PROGRESS_BARS=0` is set and has priority." 
) return if name is None: progress_bar_states.clear() progress_bar_states["_global"] = False else: keys_to_remove = [key for key in progress_bar_states if key.startswith(f"{name}.")] for key in keys_to_remove: del progress_bar_states[key] progress_bar_states[name] = False def enable_progress_bars(name: Optional[str] = None) -> None: """ Enable progress bars either globally or for a specified group. This function sets the progress bars to enabled for the specified group or globally if no group is specified. The operation is subject to the `HF_HUB_DISABLE_PROGRESS_BARS` environment setting. Args: name (`str`, *optional*): The name of the group for which to enable the progress bars. If None, progress bars are enabled globally. Raises: Warning: If the environment variable precludes changes. """ if HF_HUB_DISABLE_PROGRESS_BARS is True: warnings.warn( "Cannot enable progress bars: environment variable `HF_HUB_DISABLE_PROGRESS_BARS=1` is set and has priority." ) return if name is None: progress_bar_states.clear() progress_bar_states["_global"] = True else: keys_to_remove = [key for key in progress_bar_states if key.startswith(f"{name}.")] for key in keys_to_remove: del progress_bar_states[key] progress_bar_states[name] = True def are_progress_bars_disabled(name: Optional[str] = None) -> bool: """ Check if progress bars are disabled globally or for a specific group. This function returns whether progress bars are disabled for a given group or globally. It checks the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable first, then the programmatic settings. Args: name (`str`, *optional*): The group name to check; if None, checks the global setting. Returns: `bool`: True if progress bars are disabled, False otherwise. """ if HF_HUB_DISABLE_PROGRESS_BARS is True: return True if name is None: return not progress_bar_states.get("_global", True) while name: if name in progress_bar_states: return not progress_bar_states[name] name = ".".join(name.split(".")[:-1]) return not progress_bar_states.get("_global", True) def is_tqdm_disabled(log_level: int) -> Optional[bool]: """ Determine if tqdm progress bars should be disabled based on logging level and environment settings. see https://github.com/huggingface/huggingface_hub/pull/2000 and https://github.com/huggingface/huggingface_hub/pull/2698. """ if log_level == logging.NOTSET: return True if os.getenv("TQDM_POSITION") == "-1": return False return None class tqdm(old_tqdm): """ Class to override `disable` argument in case progress bars are globally disabled. Taken from https://github.com/tqdm/tqdm/issues/619#issuecomment-619639324. """ def __init__(self, *args, **kwargs): name = kwargs.pop("name", None) # do not pass `name` to `tqdm` if are_progress_bars_disabled(name): kwargs["disable"] = True super().__init__(*args, **kwargs) def __delattr__(self, attr: str) -> None: """Fix for https://github.com/huggingface/huggingface_hub/issues/1603""" try: super().__delattr__(attr) except AttributeError: if attr != "_lock": raise @contextmanager def tqdm_stream_file(path: Union[Path, str]) -> Iterator[io.BufferedReader]: """ Open a file as binary and wrap the `read` method to display a progress bar when it's streamed. First implemented in `transformers` in 2019 but removed when switched to git-lfs. Used in `huggingface_hub` to show progress bar when uploading an LFS file to the Hub. See github.com/huggingface/transformers/pull/2078#discussion_r354739608 for implementation details. 
    Note: the current implementation handles only files stored on disk, as that is the most common
    use case. It could be extended to stream any `BinaryIO` object, but we might have to debug some
    corner cases.

    Example:
    ```py
    >>> with tqdm_stream_file("config.json") as f:
    ...     requests.put(url, data=f)
    config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]
    ```
    """
    if isinstance(path, str):
        path = Path(path)

    with path.open("rb") as f:
        total_size = path.stat().st_size
        pbar = tqdm(
            unit="B",
            unit_scale=True,
            total=total_size,
            initial=0,
            desc=path.name,
        )

        f_read = f.read

        def _inner_read(size: Optional[int] = -1) -> bytes:
            data = f_read(size)
            pbar.update(len(data))
            return data

        f.read = _inner_read  # type: ignore

        yield f

        pbar.close()


def _get_progress_bar_context(
    *,
    desc: str,
    log_level: int,
    total: Optional[int] = None,
    initial: int = 0,
    unit: str = "B",
    unit_scale: bool = True,
    name: Optional[str] = None,
    _tqdm_bar: Optional[tqdm] = None,
) -> ContextManager[tqdm]:
    if _tqdm_bar is not None:
        return nullcontext(_tqdm_bar)
        # ^ `contextlib.nullcontext` mimics a context manager that does nothing.
        #   It lets both cases share the same code path, but note that a pre-existing
        #   bar passed this way is not closed when exiting the context manager.

    return tqdm(
        unit=unit,
        unit_scale=unit_scale,
        total=total,
        initial=initial,
        desc=desc,
        disable=is_tqdm_disabled(log_level=log_level),
        name=name,
    )
huggingface_hub-0.31.1/tests/000077500000000000000000000000001500667546600160625ustar00rootroot00000000000000huggingface_hub-0.31.1/tests/README.md000066400000000000000000000006621500667546600173450ustar00rootroot00000000000000# Running Tests

To run the test suite, please perform the following from the root directory of this repository:

1. `pip install -e .[testing]`
   This will install all the testing requirements.
2. `sudo apt-get update; sudo apt-get install git-lfs -y`
   We need git-lfs on our system to run some of the tests.
3. `pytest ./tests/`
   We need to set an environment variable to make sure the private API tests can run; see the example below.
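For example (a minimal sketch; `EXAMPLE_ENV_VAR` is only a placeholder, not the actual variable name expected by this test suite), the variable can be set inline when invoking `pytest`:

```sh
# EXAMPLE_ENV_VAR is a placeholder; replace it with the variable your setup requires.
EXAMPLE_ENV_VAR=1 pytest ./tests/
```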
huggingface_hub-0.31.1/tests/__init__.py000066400000000000000000000000001500667546600201610ustar00rootroot00000000000000huggingface_hub-0.31.1/tests/cassettes/000077500000000000000000000000001500667546600200605ustar00rootroot00000000000000huggingface_hub-0.31.1/tests/cassettes/InferenceApiTest.test_inference_missing_input.yaml000066400000000000000000000417371500667546600321140ustar00rootroot00000000000000interactions: - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate Connection: - keep-alive X-Amzn-Trace-Id: - ecfdb72e-5e5e-43ee-a695-54ac3c83fddd method: GET uri: https://huggingface.co/api/models/deepset/roberta-base-squad2 response: body: string: "{\"_id\":\"621ffdc136468d709f17a5fd\",\"id\":\"deepset/roberta-base-squad2\",\"private\":false,\"pipeline_tag\":\"question-answering\",\"library_name\":\"transformers\",\"tags\":[\"transformers\",\"pytorch\",\"tf\",\"jax\",\"rust\",\"safetensors\",\"roberta\",\"question-answering\",\"en\",\"dataset:squad_v2\",\"base_model:FacebookAI/roberta-base\",\"base_model:finetune:FacebookAI/roberta-base\",\"license:cc-by-4.0\",\"model-index\",\"endpoints_compatible\",\"region:us\"],\"downloads\":1801802,\"likes\":842,\"modelId\":\"deepset/roberta-base-squad2\",\"author\":\"deepset\",\"sha\":\"adc3b06f79f797d1c575d5479d6f5efe54a9e3b4\",\"lastModified\":\"2024-09-24T15:48:47.000Z\",\"gated\":false,\"inference\":\"warm\",\"disabled\":false,\"mask_token\":\"\",\"widgetData\":[{\"text\":\"Where do I live?\",\"context\":\"My name is Wolfgang and I live in Berlin\"},{\"text\":\"Where do I live?\",\"context\":\"My name is Sarah and I live in London\"},{\"text\":\"What's my name?\",\"context\":\"My name is Clara and I live in Berkeley.\"},{\"text\":\"Which name is also used to describe the Amazon rainforest in English?\",\"context\":\"The Amazon rainforest (Portuguese: Floresta Amaz\xF4nica or Amaz\xF4nia; Spanish: Selva Amaz\xF3nica, Amazon\xEDa or usually Amazonia; French: For\xEAt amazonienne; Dutch: Amazoneregenwoud), also known in English as Amazonia or the Amazon Jungle, is a moist broadleaf forest that covers most of the Amazon basin of South America. This basin encompasses 7,000,000 square kilometres (2,700,000 sq mi), of which 5,500,000 square kilometres (2,100,000 sq mi) are covered by the rainforest. This region includes territory belonging to nine nations. The majority of the forest is contained within Brazil, with 60% of the rainforest, followed by Peru with 13%, Colombia with 10%, and with minor amounts in Venezuela, Ecuador, Bolivia, Guyana, Suriname and French Guiana. States or departments in four nations contain \\\"Amazonas\\\" in their names. 
The Amazon represents over half of the planet's remaining rainforests, and comprises the largest and most biodiverse tract of tropical rainforest in the world, with an estimated 390 billion individual trees divided into 16,000 species.\"}],\"model-index\":[{\"name\":\"deepset/roberta-base-squad2\",\"results\":[{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squad_v2\",\"type\":\"squad_v2\",\"config\":\"squad_v2\",\"split\":\"validation\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":79.9309,\"name\":\"Exact Match\",\"verified\":true,\"verifyToken\":\"eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMDhhNjg5YzNiZGQ1YTIyYTAwZGUwOWEzZTRiYzdjM2QzYjA3ZTUxNDM1NjE1MTUyMjE1MGY1YzEzMjRjYzVjYiIsInZlcnNpb24iOjF9.EH5JJo8EEFwU7osPz3s7qanw_tigeCFhCXjSfyN0Y1nWVnSfulSxIk_DbAEI5iE80V4EKLyp5-mYFodWvL2KDA\"},{\"type\":\"f1\",\"value\":82.9501,\"name\":\"F1\",\"verified\":true,\"verifyToken\":\"eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMjk5ZDYwOGQyNjNkMWI0OTE4YzRmOTlkY2JjNjQ0YTZkNTMzMzNkYTA0MDFmNmI3NjA3NjNlMjhiMDQ2ZjJjNSIsInZlcnNpb24iOjF9.DDm0LNTkdLbGsue58bg1aH_s67KfbcmkvL-6ZiI2s8IoxhHJMSf29H_uV2YLyevwx900t-MwTVOW3qfFnMMEAQ\"},{\"type\":\"total\",\"value\":11869,\"name\":\"total\",\"verified\":true,\"verifyToken\":\"eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMGFkMmI2ODM0NmY5NGNkNmUxYWViOWYxZDNkY2EzYWFmOWI4N2VhYzY5MGEzMTVhOTU4Zjc4YWViOGNjOWJjMCIsInZlcnNpb24iOjF9.fexrU1icJK5_MiifBtZWkeUvpmFISqBLDXSQJ8E6UnrRof-7cU0s4tX_dIsauHWtUpIHMPZCf5dlMWQKXZuAAA\"}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squad\",\"type\":\"squad\",\"config\":\"plain_text\",\"split\":\"validation\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":85.289,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":91.841,\"name\":\"F1\",\"verified\":false}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"adversarial_qa\",\"type\":\"adversarial_qa\",\"config\":\"adversarialQA\",\"split\":\"validation\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":29.5,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":40.367,\"name\":\"F1\",\"verified\":false}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squad_adversarial\",\"type\":\"squad_adversarial\",\"config\":\"AddOneSent\",\"split\":\"validation\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":78.567,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":84.469,\"name\":\"F1\",\"verified\":false}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squadshifts amazon\",\"type\":\"squadshifts\",\"config\":\"amazon\",\"split\":\"test\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":69.924,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":83.284,\"name\":\"F1\",\"verified\":false}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squadshifts new_wiki\",\"type\":\"squadshifts\",\"config\":\"new_wiki\",\"split\":\"test\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":81.204,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":90.595,\"name\":\"F1\",\"verified\":false}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squadshifts 
nyt\",\"type\":\"squadshifts\",\"config\":\"nyt\",\"split\":\"test\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":82.931,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":90.756,\"name\":\"F1\",\"verified\":false}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squadshifts reddit\",\"type\":\"squadshifts\",\"config\":\"reddit\",\"split\":\"test\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":71.55,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":82.939,\"name\":\"F1\",\"verified\":false}]}]}],\"config\":{\"architectures\":[\"RobertaForQuestionAnswering\"],\"model_type\":\"roberta\",\"tokenizer_config\":{}},\"cardData\":{\"language\":\"en\",\"license\":\"cc-by-4.0\",\"datasets\":[\"squad_v2\"],\"model-index\":[{\"name\":\"deepset/roberta-base-squad2\",\"results\":[{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squad_v2\",\"type\":\"squad_v2\",\"config\":\"squad_v2\",\"split\":\"validation\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":79.9309,\"name\":\"Exact Match\",\"verified\":true,\"verifyToken\":\"eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMDhhNjg5YzNiZGQ1YTIyYTAwZGUwOWEzZTRiYzdjM2QzYjA3ZTUxNDM1NjE1MTUyMjE1MGY1YzEzMjRjYzVjYiIsInZlcnNpb24iOjF9.EH5JJo8EEFwU7osPz3s7qanw_tigeCFhCXjSfyN0Y1nWVnSfulSxIk_DbAEI5iE80V4EKLyp5-mYFodWvL2KDA\"},{\"type\":\"f1\",\"value\":82.9501,\"name\":\"F1\",\"verified\":true,\"verifyToken\":\"eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMjk5ZDYwOGQyNjNkMWI0OTE4YzRmOTlkY2JjNjQ0YTZkNTMzMzNkYTA0MDFmNmI3NjA3NjNlMjhiMDQ2ZjJjNSIsInZlcnNpb24iOjF9.DDm0LNTkdLbGsue58bg1aH_s67KfbcmkvL-6ZiI2s8IoxhHJMSf29H_uV2YLyevwx900t-MwTVOW3qfFnMMEAQ\"},{\"type\":\"total\",\"value\":11869,\"name\":\"total\",\"verified\":true,\"verifyToken\":\"eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMGFkMmI2ODM0NmY5NGNkNmUxYWViOWYxZDNkY2EzYWFmOWI4N2VhYzY5MGEzMTVhOTU4Zjc4YWViOGNjOWJjMCIsInZlcnNpb24iOjF9.fexrU1icJK5_MiifBtZWkeUvpmFISqBLDXSQJ8E6UnrRof-7cU0s4tX_dIsauHWtUpIHMPZCf5dlMWQKXZuAAA\"}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squad\",\"type\":\"squad\",\"config\":\"plain_text\",\"split\":\"validation\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":85.289,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":91.841,\"name\":\"F1\",\"verified\":false}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"adversarial_qa\",\"type\":\"adversarial_qa\",\"config\":\"adversarialQA\",\"split\":\"validation\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":29.5,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":40.367,\"name\":\"F1\",\"verified\":false}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squad_adversarial\",\"type\":\"squad_adversarial\",\"config\":\"AddOneSent\",\"split\":\"validation\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":78.567,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":84.469,\"name\":\"F1\",\"verified\":false}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squadshifts amazon\",\"type\":\"squadshifts\",\"config\":\"amazon\",\"split\":\"test\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":69.924,\"name\":\"Exact 
Match\",\"verified\":false},{\"type\":\"f1\",\"value\":83.284,\"name\":\"F1\",\"verified\":false}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squadshifts new_wiki\",\"type\":\"squadshifts\",\"config\":\"new_wiki\",\"split\":\"test\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":81.204,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":90.595,\"name\":\"F1\",\"verified\":false}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squadshifts nyt\",\"type\":\"squadshifts\",\"config\":\"nyt\",\"split\":\"test\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":82.931,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":90.756,\"name\":\"F1\",\"verified\":false}]},{\"task\":{\"type\":\"question-answering\",\"name\":\"Question Answering\"},\"dataset\":{\"name\":\"squadshifts reddit\",\"type\":\"squadshifts\",\"config\":\"reddit\",\"split\":\"test\"},\"metrics\":[{\"type\":\"exact_match\",\"value\":71.55,\"name\":\"Exact Match\",\"verified\":false},{\"type\":\"f1\",\"value\":82.939,\"name\":\"F1\",\"verified\":false}]}]}],\"base_model\":[\"FacebookAI/roberta-base\"]},\"transformersInfo\":{\"auto_model\":\"AutoModelForQuestionAnswering\",\"pipeline_tag\":\"question-answering\",\"processor\":\"AutoTokenizer\"},\"siblings\":[{\"rfilename\":\".gitattributes\"},{\"rfilename\":\"README.md\"},{\"rfilename\":\"config.json\"},{\"rfilename\":\"flax_model.msgpack\"},{\"rfilename\":\"merges.txt\"},{\"rfilename\":\"model.safetensors\"},{\"rfilename\":\"pytorch_model.bin\"},{\"rfilename\":\"rust_model.ot\"},{\"rfilename\":\"special_tokens_map.json\"},{\"rfilename\":\"tf_model.h5\"},{\"rfilename\":\"tokenizer_config.json\"},{\"rfilename\":\"vocab.json\"}],\"spaces\":[\"microsoft/HuggingGPT\",\"razakhan/text-summarizer\",\"anakin87/who-killed-laura-palmer\",\"AmazonScience/QA-NLU\",\"Hellisotherpeople/HF-SHAP\",\"taesiri/HuggingGPT-Lite\",\"Aeon-Avinash/GenAI_Document_QnA_with_Vision\",\"course-demos/question-answering-simple\",\"Eemansleepdeprived/Study_For_Me_AI\",\"manishjaiswal/05-SOTA-Question-Answer-From-TextFileContext-Demo\",\"nsethi610/ns-gradio-apps\",\"Wootang01/question_answer\",\"raphaelsty/games\",\"Abhilashvj/haystack_QA\",\"IsmayilMasimov36/question-answering-app\",\"jayesh95/Voice-QA\",\"awacke1/CarePlanQnAWithContext\",\"jorge-henao/ask2democracy\",\"awacke1/SOTA-Plan\",\"amsterdamNLP/attention-rollout\",\"AIZ2H/05-SOTA-Question-Answer-From-TextFileContext\",\"drift-ai/question-answer-text\",\"emmetmayer/Large-Context-Question-and-Answering\",\"leomaurodesenv/qasports-website\",\"nkatraga/7.22.CarePlanQnAWithContext\",\"unco3892/real_estate_ie\",\"HemanthSai7/IntelligentQuestionGenerator\",\"Timjo88/toy-board-game-QA\",\"awacke1/NLPContextQATransformersRobertaBaseSquad2\",\"camillevanhoffelen/langchain-HuggingGPT\",\"cyberspyde/chatbot-team4\",\"awacke1/CarePlanQnAWithContext2\",\"williambr/CarePlanSOTAQnA\",\"niksyad/CarePlanQnAWithContext\",\"sdande11/CarePlanQnAWithContext2\",\"cpnepo/Harry-Potter-Q-A\",\"edemgold/QA-App\",\"gulabpatel/Question-Answering_roberta\",\"Chatop/Lab10\",\"awacke1/ContextQuestionAnswerNLP\",\"BilalSardar/QuestionAndAnswer\",\"mishtert/tracer\",\"Sasidhar/information-extraction-demo\",\"Jonni/05-QandA-from-textfile\",\"tracinginsights/QuotesBot\",\"ccarr0807/HuggingGPT\",\"cshallah/qna-ancient-1\",\"theholycityweb/HuggingGPT\",\"hhalim/NLPContextQATransformersRobertaBaseSquad2\",\"abhilashb/NLP-Test\",\"awacke1/NLPDem
o1\",\"sanjayw/nlpDemo1\",\"allieannez/NLPContextQASquad2Demo\",\"Alfasign/HuggingGPT-Lite\",\"Kelvinhjk/QnA_chatbot_for_Swinburne_cs_course\",\"Th3BossC/TranscriptApi\",\"saurshaz/HuggingGPT\",\"Jaehan/Question-Answering-1\",\"roshithindia/ayureasybot\",\"MachineLearningReply/search_mlReply\",\"knotmesh/deepset-roberta-base-squad2\",\"AyselRahimli/Project2\",\"Charles95/gradio-tasks\",\"Nikhil0987/omm\",\"umair894/fastapi-document-qa_semantic\",\"swamisharan/pdf-gpt\",\"Manoj21k/Custom-QandA\",\"Rohankumar31/Prakruti_LLM\",\"mikepastor11/PennwickHoneybeeRobot\",\"abdala9512/dsrp-demo-example\",\"Jforeverss/finchat222\",\"aidinro/qqqqqqqqqqqqq\",\"wenchu79/test\",\"AkshaySharma770/meeting-minute-generator-and-question-and-answer-chatbot\",\"Walid-Ahmed/Q_A_with_document\",\"ff98/ctp-audio-image\",\"leonferreira/as05-leon-martins-pucminas\",\"ANASAKHTAR/Document_Question_And_Answer\",\"dakhos/ProjectDarkhan\",\"abhinavyadav11/RAG_Enhanced_Chatbot\",\"JarvisOnSolana/Jarvis\",\"Thouseef1234/chatbot\",\"ddriscoll/EurybiaMini\",\"jaydeepkum/CarePlanQnaWithContext\",\"ziyadbastaili/get_special_meeting\",\"charlesfrye/test-space-117\",\"Myrna/CarePlan\",\"santoshsindham/CarePlanQnAWithContext\",\"PrafulUHG/CarePlan\",\"osanseviero/all_nlp_demos\",\"awacke1/CarePlanSOTAQnA\",\"vnemala/CarePlanSOTAQnA\",\"SudarshanaR/CarePlanQnaWithContext\",\"Vasanthp/CarePlanSOTAQnA\",\"ocordes/CarePlanSOTAQnA\",\"vsaripella/CarePlanSOTAQnA\",\"mm2593/CarePlan\",\"MateusA/CarePlanSOTAQnA\",\"Desh/test1\",\"Preetesh/CarePlanQnAWithContext\"],\"createdAt\":\"2022-03-02T23:29:05.000Z\",\"safetensors\":{\"parameters\":{\"F32\":124056578,\"I64\":514},\"total\":124057092},\"usedStorage\":3943613347}" headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '12780' Content-Type: - application/json; charset=utf-8 Date: - Wed, 22 Jan 2025 18:24:32 GMT ETag: - W/"31ec-CFYGviWwMc4TdzaDXh40zZgdQ1A" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 9737f42d74643b8e3ceb7ecfa2015ed2.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 2jcYDsud_JpXtC9208uc2NR1L9xQ27wK1H_xN-FzC3XlliScNiMVDQ== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-679137e0-63d24d411ffaaa6a270bbfcf;ecfdb72e-5e5e-43ee-a695-54ac3c83fddd cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"options": {"wait_for_model": true, "use_gpu": false}, "inputs": {"question": "What''s my name?"}}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate Connection: - keep-alive Content-Length: - '98' Content-Type: - application/json X-Amzn-Trace-Id: - 9af17110-dd6c-43b3-976f-50b533762cd3 method: POST uri: https://api-inference.huggingface.co/pipeline/question-answering/deepset/roberta-base-squad2 response: body: string: '{"error":["Error in `inputs.context`: field required"]}' headers: Connection: - keep-alive Content-Type: - application/json Date: - Wed, 22 Jan 2025 18:24:33 GMT Transfer-Encoding: - chunked access-control-allow-credentials: - 'true' server: - uvicorn vary: - Origin, Access-Control-Request-Method, Access-Control-Request-Headers - origin, access-control-request-method, access-control-request-headers x-proxied-host: - internal.api-inference.huggingface.co 
x-proxied-path: - / x-request-id: - hAZ8Na x-sha: - adc3b06f79f797d1c575d5479d6f5efe54a9e3b4 status: code: 400 message: Bad Request version: 1 huggingface_hub-0.31.1/tests/cassettes/InferenceApiTest.test_inference_overriding_invalid_task.yaml000066400000000000000000000165521500667546600341210ustar00rootroot00000000000000interactions: - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate Connection: - keep-alive X-Amzn-Trace-Id: - e3510979-5d21-4bc5-b19b-88217f02822b method: GET uri: https://huggingface.co/api/models/bert-base-uncased response: body: string: Temporary Redirect. Redirecting to /api/models/google-bert/bert-base-uncased headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '76' Content-Type: - text/plain; charset=utf-8 Date: - Wed, 22 Jan 2025 18:24:33 GMT Location: - /api/models/google-bert/bert-base-uncased Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin, Accept Via: - 1.1 4587dd93b6f56d2b3f35f25ef2cabe70.cloudfront.net (CloudFront) X-Amz-Cf-Id: - QYfF2QIo7x2Nqc7z0DJulCrvKxsbjnoa0HJn44f7xPFCmn0J8KXZ5g== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-679137e1-10c621dc54b6dd8948e428ed;e3510979-5d21-4bc5-b19b-88217f02822b cross-origin-opener-policy: - same-origin status: code: 307 message: Temporary Redirect - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate Connection: - keep-alive X-Amzn-Trace-Id: - e3510979-5d21-4bc5-b19b-88217f02822b method: GET uri: https://huggingface.co/api/models/google-bert/bert-base-uncased response: body: string: '{"_id":"621ffdc036468d709f174338","id":"google-bert/bert-base-uncased","private":false,"pipeline_tag":"fill-mask","library_name":"transformers","tags":["transformers","pytorch","tf","jax","rust","coreml","onnx","safetensors","bert","fill-mask","exbert","en","dataset:bookcorpus","dataset:wikipedia","arxiv:1810.04805","license:apache-2.0","autotrain_compatible","endpoints_compatible","region:us"],"downloads":72376843,"likes":2065,"modelId":"google-bert/bert-base-uncased","author":"google-bert","sha":"86b5e0934494bd15c9632b12f734a8a67f723594","lastModified":"2024-02-19T11:06:12.000Z","gated":false,"inference":"warm","disabled":false,"mask_token":"[MASK]","widgetData":[{"text":"Paris is the [MASK] of France."},{"text":"The goal of life is 
[MASK]."}],"model-index":null,"config":{"architectures":["BertForMaskedLM"],"model_type":"bert","tokenizer_config":{}},"cardData":{"language":"en","tags":["exbert"],"license":"apache-2.0","datasets":["bookcorpus","wikipedia"]},"transformersInfo":{"auto_model":"AutoModelForMaskedLM","pipeline_tag":"fill-mask","processor":"AutoTokenizer"},"siblings":[{"rfilename":".gitattributes"},{"rfilename":"LICENSE"},{"rfilename":"README.md"},{"rfilename":"config.json"},{"rfilename":"coreml/fill-mask/float32_model.mlpackage/Data/com.apple.CoreML/model.mlmodel"},{"rfilename":"coreml/fill-mask/float32_model.mlpackage/Data/com.apple.CoreML/weights/weight.bin"},{"rfilename":"coreml/fill-mask/float32_model.mlpackage/Manifest.json"},{"rfilename":"flax_model.msgpack"},{"rfilename":"model.onnx"},{"rfilename":"model.safetensors"},{"rfilename":"pytorch_model.bin"},{"rfilename":"rust_model.ot"},{"rfilename":"tf_model.h5"},{"rfilename":"tokenizer.json"},{"rfilename":"tokenizer_config.json"},{"rfilename":"vocab.txt"}],"spaces":["mteb/leaderboard","microsoft/HuggingGPT","Vision-CAIR/minigpt4","lnyan/stablediffusion-infinity","multimodalart/latentdiffusion","Salesforce/BLIP","mrfakename/MeloTTS","shi-labs/Versatile-Diffusion","yizhangliu/Grounded-Segment-Anything","cvlab/zero123-live","xinyu1205/recognize-anything","AIGC-Audio/AudioGPT","hilamanor/audioEditing","Audio-AGI/AudioSep","jadechoghari/OpenMusic","m-ric/chunk_visualizer","DAMO-NLP-SG/Video-LLaMA","gligen/demo","declare-lab/mustango","Yiwen-ntu/MeshAnything","shgao/EditAnything","LiruiZhao/Diffree","exbert-project/exbert","Vision-CAIR/MiniGPT-v2","Yuliang/ECON","THUdyh/Oryx","IDEA-Research/Grounded-SAM","Awiny/Image2Paragraph","ShilongLiu/Grounding_DINO_demo","eswardivi/Podcastify","liuyuan-pal/SyncDreamer","haotiz/glip-zeroshot-demo","nateraw/lavila","sam-hq-team/sam-hq","abyildirim/inst-inpaint","TencentARC/BrushEdit","merve/Grounding_DINO_demo","Yiwen-ntu/MeshAnythingV2","Pinwheel/GLIP-BLIP-Object-Detection-VQA","Junfeng5/GLEE_demo","shi-labs/Matting-Anything","fffiloni/Video-Matting-Anything","linfanluntan/Grounded-SAM","magicr/BuboGPT","Nick088/Audio-SR","OpenGVLab/InternGPT","clip-italian/clip-italian-demo","hongfz16/3DTopia","Vision-CAIR/MiniGPT4-video","yenniejun/tokenizers-languages","mmlab-ntu/relate-anything-model","nikigoli/countgd","byeongjun-park/HarmonyView","keras-io/bert-semantic-similarity","MirageML/sjc","amphion/PicoAudio","NAACL2022/CLIP-Caption-Reward","society-ethics/model-card-regulatory-check","fffiloni/miniGPT4-Video-Zero","Gladiator/Text-Summarizer","fffiloni/vta-ldm","milyiyo/reimagine-it","AIGC-Audio/AudioLCM","ethanchern/Anole","ysharma/text-to-image-to-video","OpenGVLab/VideoChatGPT","avid-ml/bias-detection","LittleFrog/IntrinsicAnything","RitaParadaRamos/SmallCapDemo","llizhx/TinyGPT-V","acmc/whatsapp-chats-finetuning-formatter","kaushalya/medclip-roco","AIGC-Audio/Make_An_Audio","flax-community/koclip","sasha/BiasDetection","TencentARC/VLog","ynhe/AskAnything","Pusheen/LoCo","pseudolab/AI_Tutor_BERT","ZebangCheng/Emotion-LLaMA","sonalkum/GAMA","flax-community/clip-reply-demo","SeViLA/SeViLA","PSLD/PSLD","AnimaLab/bias-test-gpt-pairs","optimum/auto-benchmark","Volkopat/SegmentAnythingxGroundingDINO","thewhole/GaussianDreamer_Demo","CosmoAI/BhagwatGeeta","codelion/Grounding_DINO_demo","phyloforfun/VoucherVision","wenkai/FAPM_demo","flosstradamus/FluxMusicGUI","AILab-CVC/SEED-LLaMA","ALM/CALM","tornadoslims/instruct-pix2pix","MykolaL/evp","zdou0830/desco","attention-refocusing/Attention-refocusing","sasha/WinoBiasCheck"],"createdAt
":"2022-03-02T23:29:04.000Z","safetensors":{"parameters":{"F32":110106428},"total":110106428},"usedStorage":12904182200}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '4420' Content-Type: - application/json; charset=utf-8 Date: - Wed, 22 Jan 2025 18:24:33 GMT ETag: - W/"1144-0lLZ4rZ6fR3gbhjw06/nTxkknu0" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 4587dd93b6f56d2b3f35f25ef2cabe70.cloudfront.net (CloudFront) X-Amz-Cf-Id: - pdTj3oTK7I_em8rndxqLMzArrdGyjf0TWg4K55lOjqBIC88u9k0Weg== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-679137e1-082086a975745c011675ebbc;e3510979-5d21-4bc5-b19b-88217f02822b cross-origin-opener-policy: - same-origin status: code: 200 message: OK version: 1 huggingface_hub-0.31.1/tests/cassettes/InferenceApiTest.test_inference_overriding_task.yaml000066400000000000000000000530221500667546600324040ustar00rootroot00000000000000interactions: - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate Connection: - keep-alive X-Amzn-Trace-Id: - 48899fdf-d647-479d-8d12-ce1b1ff9ee6e method: GET uri: https://huggingface.co/api/models/sentence-transformers/paraphrase-albert-small-v2 response: body: string: '{"_id":"621ffdc136468d709f1802e8","id":"sentence-transformers/paraphrase-albert-small-v2","private":false,"pipeline_tag":"sentence-similarity","library_name":"sentence-transformers","tags":["sentence-transformers","pytorch","tf","rust","onnx","safetensors","openvino","albert","feature-extraction","sentence-similarity","transformers","dataset:flax-sentence-embeddings/stackexchange_xml","dataset:s2orc","dataset:ms_marco","dataset:wiki_atomic_edits","dataset:snli","dataset:multi_nli","dataset:embedding-data/altlex","dataset:embedding-data/simple-wiki","dataset:embedding-data/flickr30k-captions","dataset:embedding-data/coco_captions","dataset:embedding-data/sentence-compression","dataset:embedding-data/QQP","dataset:yahoo_answers_topics","arxiv:1908.10084","license:apache-2.0","autotrain_compatible","endpoints_compatible","region:us"],"downloads":285986,"likes":9,"modelId":"sentence-transformers/paraphrase-albert-small-v2","author":"sentence-transformers","sha":"39d5b65549dbfc88a4c56fc853a8b7242873d583","lastModified":"2024-11-05T18:20:00.000Z","gated":false,"inference":"cold","disabled":false,"mask_token":"[MASK]","widgetData":[{"source_sentence":"That is a happy person","sentences":["That is a happy dog","That is a very happy person","Today is a sunny 
day"]}],"model-index":null,"config":{"architectures":["AlbertModel"],"model_type":"albert","tokenizer_config":{"bos_token":"[CLS]","eos_token":"[SEP]","unk_token":"","sep_token":"[SEP]","pad_token":"","cls_token":"[CLS]","mask_token":{"content":"[MASK]","single_word":false,"lstrip":true,"rstrip":false,"normalized":true,"__type":"AddedToken"}}},"cardData":{"license":"apache-2.0","library_name":"sentence-transformers","tags":["sentence-transformers","feature-extraction","sentence-similarity","transformers"],"datasets":["flax-sentence-embeddings/stackexchange_xml","s2orc","ms_marco","wiki_atomic_edits","snli","multi_nli","embedding-data/altlex","embedding-data/simple-wiki","embedding-data/flickr30k-captions","embedding-data/coco_captions","embedding-data/sentence-compression","embedding-data/QQP","yahoo_answers_topics"],"pipeline_tag":"sentence-similarity"},"transformersInfo":{"auto_model":"AutoModel","pipeline_tag":"feature-extraction","processor":"AutoTokenizer"},"siblings":[{"rfilename":".gitattributes"},{"rfilename":"1_Pooling/config.json"},{"rfilename":"README.md"},{"rfilename":"config.json"},{"rfilename":"config_sentence_transformers.json"},{"rfilename":"model.safetensors"},{"rfilename":"modules.json"},{"rfilename":"onnx/model.onnx"},{"rfilename":"onnx/model_O1.onnx"},{"rfilename":"onnx/model_O2.onnx"},{"rfilename":"onnx/model_O3.onnx"},{"rfilename":"onnx/model_O4.onnx"},{"rfilename":"onnx/model_qint8_arm64.onnx"},{"rfilename":"onnx/model_qint8_avx512.onnx"},{"rfilename":"onnx/model_qint8_avx512_vnni.onnx"},{"rfilename":"onnx/model_quint8_avx2.onnx"},{"rfilename":"openvino/openvino_model.bin"},{"rfilename":"openvino/openvino_model.xml"},{"rfilename":"openvino/openvino_model_qint8_quantized.bin"},{"rfilename":"openvino/openvino_model_qint8_quantized.xml"},{"rfilename":"pytorch_model.bin"},{"rfilename":"rust_model.ot"},{"rfilename":"sentence_bert_config.json"},{"rfilename":"special_tokens_map.json"},{"rfilename":"spiece.model"},{"rfilename":"tf_model.h5"},{"rfilename":"tokenizer.json"},{"rfilename":"tokenizer_config.json"}],"spaces":["Gradio-Blocks/pubmed-abstract-retriever","pritamdeka/health-article-keyphrase-generator","pritamdeka/pubmed-abstract-retriever"],"createdAt":"2022-03-02T23:29:05.000Z","safetensors":{"parameters":{"I64":512,"F32":11683584},"total":11684096},"usedStorage":578652826}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '3617' Content-Type: - application/json; charset=utf-8 Date: - Wed, 22 Jan 2025 18:24:33 GMT ETag: - W/"e21-MXShfRGRyOneN66PubYqdWkGPSU" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 887aba73f027fe4e82f965d15238ed3e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 2NNdzWr0gEdSvzElBjXv2GvinessqJ6FucSiK0g1dEqEUAcdUPurgw== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-679137e1-51fa1cab438b9e665860af66;48899fdf-d647-479d-8d12-ce1b1ff9ee6e cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"options": {"wait_for_model": true, "use_gpu": false}, "inputs": "This is an example again"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate Connection: - keep-alive Content-Length: - '93' Content-Type: - application/json X-Amzn-Trace-Id: - 
ae0eac6b-0519-4d7f-9e77-b7e3e943f4c5 method: POST uri: https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/paraphrase-albert-small-v2 response: body: string: '[-0.341085821390152,0.16878847777843475,0.5349205732345581,-0.7717202305793762,0.3694867789745331,0.041570067405700684,-0.22994324564933777,-0.5638588666915894,-0.030669301748275757,-0.299565851688385,-0.8804835081100464,0.16595038771629333,-0.5313254594802856,0.18190892040729523,0.549403965473175,0.45036911964416504,0.45066139101982117,-0.5457428693771362,-0.6259514093399048,0.01874767616391182,0.38988572359085083,0.3517792522907257,-0.3381141126155853,-0.5400908589363098,0.2046586126089096,-0.2273494154214859,-0.13636690378189087,-0.5397760272026062,0.1338198035955429,0.25351065397262573,-0.39263030886650085,-0.6871981024742126,0.672836422920227,-0.03747480362653732,0.012533562257885933,-0.2833794951438904,-0.6461300849914551,-0.12365946918725967,-0.48956984281539917,0.18023069202899933,-0.22581827640533447,-0.3118274509906769,0.25834259390830994,0.6930604577064514,-0.7131444811820984,-0.21616850793361664,0.2454346865415573,-0.4833442270755768,0.07898270338773727,-0.030935684219002724,-0.1647939532995224,-0.26670706272125244,-0.00622419361025095,0.5999671220779419,-0.05106488987803459,-0.20105615258216858,0.18153749406337738,0.3733285367488861,0.15388838946819305,-0.7217402458190918,-0.09930419921875,-0.14671428501605988,0.8378180861473083,-0.1422334611415863,0.029067745432257652,0.1419973522424698,-0.6724494099617004,0.0512361042201519,-0.8505474925041199,0.16212500631809235,0.5462662577629089,0.27694663405418396,0.03470375016331673,0.05365396663546562,0.32530850172042847,0.10443305969238281,-0.4620659351348877,0.3613378703594208,0.08676077425479889,-0.4396112859249115,0.4775589406490326,0.18862219154834747,0.020406311377882957,0.8618583679199219,0.34356027841567993,0.5950371026992798,0.28181976079940796,0.22128866612911224,-0.19320161640644073,-0.8068356513977051,-0.6325915455818176,-0.20719096064567566,0.16645391285419464,-0.27316245436668396,-0.2558119297027588,0.4094904363155365,1.0869005918502808,-0.39742010831832886,-0.04294157773256302,0.66911381483078,0.05142701417207718,-0.32665207982063293,-0.18507659435272217,0.1801418513059616,-0.03082272782921791,0.866555392742157,0.015129602514207363,0.5369774103164673,0.035978395491838455,-0.596277117729187,-0.16099193692207336,0.21034523844718933,-0.31352949142456055,0.025525595992803574,0.15932805836200714,-0.24438704550266266,-0.3194598853588104,-0.22482778131961823,-0.2245563417673111,-0.6627371907234192,-0.005691751837730408,0.24560369551181793,0.17312154173851013,-0.26991763710975647,-0.12625451385974884,0.0537274107336998,-0.16675636172294617,0.26175934076309204,-0.34527286887168884,-0.2696242928504944,0.0382714606821537,0.9279422760009766,-0.1439134180545807,-0.17225460708141327,0.45578283071517944,0.2669826149940491,0.697851300239563,-0.31583133339881897,-0.05987294390797615,-0.16314975917339325,-0.2923843562602997,0.24195994436740875,0.7339079976081848,0.2875455915927887,0.38133737444877625,-0.06411625444889069,-0.5289252400398254,-0.12572170794010162,-0.41820162534713745,0.3145720064640045,-0.5967114567756653,0.2003154754638672,0.10518056899309158,0.1360904723405838,-0.18886713683605194,0.8033667802810669,0.4109314978122711,0.385867178440094,0.37393197417259216,-0.41822633147239685,0.5445504784584045,0.2698557674884796,-1.8897572755813599,-0.2654189169406891,0.12983624637126923,0.7082985043525696,-0.09554287046194077,0.7264841794967651,0.
5856768488883972,0.09566565603017807,-0.06346061825752258,0.07220730930566788,-0.22825609147548676,0.6896306276321411,0.23028457164764404,0.6359089016914368,0.45442646741867065,0.4459366798400879,-0.3403448760509491,-0.9826788306236267,0.14824046194553375,0.6716114282608032,-0.7234140634536743,-0.042630940675735474,0.118319071829319,0.02029203437268734,0.006172746419906616,0.17694832384586334,-0.20284245908260345,-0.36010926961898804,-0.04080190509557724,-0.4041801989078522,0.5723574757575989,0.09892728179693222,0.005016509909182787,0.0600854828953743,0.4601345360279083,0.7505092620849609,-0.18042218685150146,-0.5690661072731018,-0.23548705875873566,0.6462931632995605,0.05932934209704399,-0.595089316368103,-0.23779405653476715,0.1962585747241974,0.32613605260849,-0.02546888217329979,-0.39885401725769043,0.2758044898509979,0.7516045570373535,0.4053671061992645,0.8974498510360718,-0.2762814462184906,-0.4561653733253479,0.6845132112503052,0.2907380759716034,-0.5976382493972778,0.4360603392124176,0.015573863871395588,0.36492055654525757,-0.17371919751167297,-0.4527595341205597,-0.8105295896530151,-0.30739954113960266,0.1275881677865982,-0.1640186756849289,1.0794562101364136,0.7435057759284973,-0.6585963368415833,-0.21766488254070282,0.5734838247299194,0.12493384629487991,0.25958165526390076,0.060625579208135605,-0.28618744015693665,0.46601101756095886,0.598228931427002,0.43606945872306824,-0.2665109932422638,0.32704880833625793,0.6004233956336975,-0.052457503974437714,-0.37291011214256287,-0.43347328901290894,-0.2699574828147888,0.5017542243003845,0.4700179100036621,0.22589385509490967,0.3331119418144226,0.5643407702445984,-0.4756787121295929,0.06478320807218552,0.47116750478744507,0.13153815269470215,0.20873501896858215,0.14722657203674316,-0.24565395712852478,-0.5516343116760254,-0.9115244746208191,-0.21274243295192719,-0.1721031814813614,-0.4704669117927551,-0.07386789470911026,-0.3552326560020447,0.4350995421409607,-0.2942036986351013,-0.28887930512428284,0.03486905246973038,-0.3166646361351013,0.10742438584566116,-0.17663109302520752,-0.48503705859184265,-0.13814832270145416,-0.08231735229492188,-0.06921812146902084,-0.1515282541513443,-0.3927438259124756,-0.10666313022375107,-0.40085285902023315,-0.3298726975917816,0.53345787525177,0.019323186948895454,-0.04476296156644821,0.016309704631567,0.6542044878005981,0.31786397099494934,-0.5420541167259216,-0.27084144949913025,0.3143649697303772,-0.4421720504760742,-0.11013428866863251,-0.06388164311647415,0.10421539843082428,-0.08173049986362457,-0.18816842138767242,0.2160194367170334,0.2033705860376358,0.5406883955001831,0.29751259088516235,0.16873298585414886,0.0014406612608581781,-0.6311850547790527,0.0005692084087058902,-0.42688751220703125,0.06314754486083984,-0.15346358716487885,0.20693521201610565,-0.1735556423664093,1.027342438697815,0.44072577357292175,-0.08022470027208328,-0.7083168029785156,-0.25857123732566833,0.3348306119441986,0.10101362317800522,0.17043828964233398,-0.10061045736074448,-0.2677077651023865,-0.08575468510389328,-0.1335678994655609,-0.3336067497730255,0.4870910346508026,-0.1337885856628418,-0.11630354821681976,0.050022754818201065,0.35443347692489624,0.08481835573911667,-0.43006640672683716,0.7748094201087952,-0.22517773509025574,-1.1933925151824951,0.12336816638708115,0.2770504057407379,-0.738603949546814,0.08051047474145889,0.04614692181348801,0.549460768699646,0.5680838823318481,-0.7405627965927124,0.10769277065992355,0.17911015450954437,0.2530647814273834,-0.13467685878276825,0.19446393847465515,-0.22782649099826
813,-0.5038869380950928,0.48206526041030884,-0.46882668137550354,0.20342156291007996,-0.47249892354011536,0.5391000509262085,-0.07303949445486069,0.2430342584848404,-0.001972017576918006,-0.1673245131969452,0.060410283505916595,0.23200726509094238,0.23997804522514343,-0.2583170533180237,-0.9572028517723083,0.31147530674934387,-0.33262303471565247,-0.7998296022415161,-0.29539060592651367,-0.3838532865047455,-0.42259612679481506,-0.02299588732421398,-0.3338901698589325,0.10806461423635483,-0.4678882956504822,0.36553049087524414,0.20797577500343323,-0.4135935306549072,-0.3180953860282898,0.30124756693840027,0.8880809545516968,-0.11435394734144211,-0.05101349577307701,0.058348145335912704,-0.3907618224620819,0.12461119145154953,0.4072445034980774,-0.47625136375427246,-0.13791881501674652,-0.4741803705692291,-0.143830344080925,0.041737206280231476,-0.03461047634482384,0.01895022951066494,0.03245576471090317,0.28546658158302307,-0.13945673406124115,0.16599597036838531,-0.056719355285167694,-0.05164090543985367,0.0507022924721241,0.3546411693096161,-0.35955148935317993,0.6372845768928528,-0.9327383637428284,-0.12595877051353455,-0.10195805877447128,0.9906986355781555,0.0897512286901474,-0.040978629142045975,0.05636298656463623,-0.41859984397888184,0.18043269217014313,0.6323961019515991,-0.49062541127204895,0.22218415141105652,-0.37711790204048157,-0.3506550192832947,-0.4174796938896179,-0.08447526395320892,0.02368023432791233,-0.04415525868535042,0.202593594789505,-0.3836112320423126,0.37979814410209656,0.2973119616508484,0.03428228572010994,0.39918723702430725,-0.5313287377357483,-0.5142784714698792,0.9161316156387329,-0.8148826360702515,0.054495714604854584,-0.3410138785839081,-0.41300663352012634,-0.001639466150663793,0.12173599749803543,-0.44217580556869507,-0.6984124183654785,0.044708602130413055,0.7483633756637573,0.4097954332828522,0.6392425894737244,0.21882930397987366,-0.29012084007263184,-0.07355164736509323,0.28806376457214355,-0.8170408010482788,0.361832857131958,0.7832428216934204,0.168807715177536,-0.3501279056072235,-0.15446676313877106,-0.19089294970035553,-0.029323413968086243,0.3778229057788849,0.258955717086792,0.043066490441560745,-0.04226478934288025,0.21908776462078094,-0.17194804549217224,0.10165756195783615,-0.16221055388450623,0.018216874450445175,0.19914865493774414,-0.6726843118667603,0.09280387312173843,-0.5145979523658752,-0.06835389137268066,0.5380922555923462,0.1117040291428566,-0.25401750206947327,0.10142628103494644,0.4065215587615967,-0.2644296884536743,-0.3170887529850006,0.4023997187614441,0.3616196811199188,0.7512936592102051,0.7863553166389465,-0.5765145421028137,0.16855789721012115,1.5743578672409058,-0.2001160830259323,-0.6387949585914612,0.04473016783595085,-0.12076958268880844,-0.12282005697488785,-0.07770853489637375,0.11992402374744415,-0.1383485496044159,0.2918226420879364,0.10086144506931305,0.321867972612381,-0.053768906742334366,0.8262871503829956,0.09977172315120697,0.504828155040741,-0.35106131434440613,-0.33405494689941406,0.08995568007230759,-0.5260513424873352,-0.14189279079437256,0.41677728295326233,-0.007449401076883078,0.14026223123073578,0.28630560636520386,0.5080655813217163,-0.3142798840999603,-0.5547899007797241,0.07761329412460327,-0.1579141467809677,0.2332477867603302,-0.4157712161540985,-0.14580772817134857,-0.06452633440494537,-0.24453851580619812,-0.05624595284461975,0.11309627443552017,0.547134518623352,0.09909097105264664,-0.10484247654676437,-0.36512207984924316,0.7924814820289612,-0.1282169222831726,-0.10216403007507324,-0.030454
959720373154,-0.27468141913414,0.37200960516929626,-0.06678207218647003,-0.3525021970272064,-0.21927209198474884,-0.39576977491378784,-0.37979254126548767,-0.6828380227088928,0.08713977783918381,0.6307284235954285,0.38557252287864685,-0.09959852695465088,0.4161994755268097,0.10460691899061203,0.28119996190071106,0.02771858684718609,-0.8788250088691711,-0.8748244047164917,-0.1198461651802063,0.3177807033061981,-0.46434029936790466,-0.37645503878593445,0.44777411222457886,0.1052740290760994,0.1326437145471573,0.6207156777381897,-0.4646264910697937,-0.5540279746055603,0.09666383266448975,0.5574221014976501,-0.0033486655447632074,-0.3210422694683075,-0.06035223230719566,0.5543868541717529,0.6094021201133728,0.2556997835636139,0.4786379635334015,0.4292359948158264,-0.3265448212623596,0.04899412393569946,0.4917512536048889,-1.1309148073196411,-0.5065527558326721,-0.7608033418655396,0.147862508893013,0.13593778014183044,0.39828047156333923,-0.28475698828697205,-0.0807085633277893,-0.5964204668998718,-0.2894887328147888,-0.020872993394732475,0.13188301026821136,0.38005486130714417,0.4969795346260071,0.5717677474021912,0.7432876229286194,-0.4028305411338806,-0.6246048212051392,0.47527340054512024,0.5970536470413208,-0.30089256167411804,-0.05530531331896782,-0.2688148021697998,0.6364954710006714,0.37265101075172424,0.5815147161483765,-0.350161612033844,0.19852495193481445,-0.13809694349765778,0.2194371223449707,-0.4028398096561432,-0.08065301924943924,0.7774210572242737,-0.3288760185241699,0.05748599022626877,-0.31139469146728516,0.23917606472969055,-0.3568010926246643,0.1992173194885254,0.10943685472011566,0.376640647649765,-0.08813752979040146,-0.2793470025062561,-1.1041699647903442,-0.6468858122825623,-0.018859224393963814,0.47495463490486145,-0.46383124589920044,0.04411156848073006,-0.1337435096502304,-0.2648039162158966,-1.2122631072998047,-0.24149669706821442,0.0051557389087975025,0.169891357421875,0.6742338538169861,-0.2402162253856659,0.16225044429302216,0.17258556187152863,-0.09926360100507736,-0.14262987673282623,-0.00912623293697834,-0.02926088310778141,0.131485715508461,-0.34209638833999634,-0.6729570031166077,0.4638628363609314,-0.6177125573158264,0.16800956428050995,-0.12893234193325043,-0.37443071603775024,-0.4301226735115051,-0.6492246389389038,-0.6880965232849121,-0.13862739503383636,0.27621862292289734,0.16595342755317688,-0.04998268559575081,0.047806985676288605,0.07906274497509003,-0.19737276434898376,0.2344685047864914,-0.7219113707542419,0.0947716236114502,0.30830976366996765,-0.15827342867851257,0.12446241825819016,-0.06680746376514435,-0.69204181432724,-0.09553401172161102,-0.1823764592409134,-0.3485519587993622,-0.28945425152778625,-0.09598355740308762,0.05012020841240883,0.1183229312300682,0.6960463523864746,-0.03801960498094559,-0.16029830276966095,0.10588861256837845,0.07330849021673203,0.24087336659431458,0.5877622961997986,0.44379445910453796,-0.17087765038013458,-0.1985747516155243,-0.2103971391916275,0.49884381890296936,0.3981620669364929,-0.4980464279651642,-0.06242205575108528,-0.2433273047208786,0.09865893423557281,0.38428398966789246,-0.4490985572338104,-0.04858693107962608,0.055084459483623505,0.5652283430099487,0.10908620804548264,0.4355883002281189,0.489664763212204,-0.35382866859436035,-0.6618562936782837,-0.2774604856967926,-0.12857243418693542,-0.32179927825927734,0.06028534844517708,-0.28948572278022766,0.228628009557724,-0.010540487244725227,0.42249542474746704,-0.6885339021682739,0.3894968032836914,-0.23503552377223969,-0.29027843475341797,-0.1969570815563
202,0.2716953456401825,0.0744878277182579,-0.541286051273346,0.061627861112356186,-0.07427655905485153,-0.1666872501373291,-0.33925214409828186,-0.4131779074668884,-0.1113651916384697,-0.9221831560134888,-0.7618919610977173,-0.12339460849761963,0.16053859889507294,0.11736509948968887,-0.44852396845817566,-0.5015748739242554,-0.34559038281440735,-0.20842424035072327,1.5283548831939697,0.6495966911315918,0.37010619044303894,1.0158330202102661,-0.6882320046424866,-0.14081844687461853,-0.28634029626846313,-0.25794336199760437,0.6416827440261841,0.6321927309036255,0.5005248188972473,0.7902107238769531,0.6415908932685852,0.26437488198280334,0.49416327476501465,0.3526041507720947,-0.5043381452560425,-0.1065618246793747,0.3336275517940521,0.048981040716171265,0.3086264133453369,0.13828492164611816,0.25052839517593384,-0.03406331315636635,-0.11697522550821304,0.19562925398349762,-0.02777908183634281,-0.29367318749427795,-0.2487489879131317,0.34076806902885437,0.33187103271484375,0.21295860409736633,-0.26522132754325867,0.3742654323577881,-0.008183690719306469,-0.018883705139160156,-0.12140800058841705,-0.19533288478851318,0.024712085723876953,0.2337304800748825,-0.6195656061172485,0.2786346673965454,-0.12573401629924774,0.21495799720287323,-0.14248785376548767,-0.028164101764559746,0.028098242357373238,-0.3465019762516022,-0.24627400934696198,-0.6070567965507507]' headers: Connection: - keep-alive Content-Type: - application/json Date: - Wed, 22 Jan 2025 18:24:44 GMT Transfer-Encoding: - chunked access-control-allow-credentials: - 'true' access-control-expose-headers: - x-compute-type, x-compute-time server: - uvicorn vary: - Accept-Encoding, Origin, Access-Control-Request-Method, Access-Control-Request-Headers - origin, access-control-request-method, access-control-request-headers x-compute-characters: - '24' x-compute-time: - '0.013' x-compute-type: - cpu x-proxied-host: - internal.api-inference.huggingface.co x-proxied-path: - / x-request-id: - 6cNDfB x-sha: - 39d5b65549dbfc88a4c56fc853a8b7242873d583 status: code: 200 message: OK version: 1 huggingface_hub-0.31.1/tests/cassettes/InferenceApiTest.test_simple_inference.yaml000066400000000000000000000224161500667546600305060ustar00rootroot00000000000000interactions: - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate Connection: - keep-alive X-Amzn-Trace-Id: - 26fac15b-0ae6-43db-8d40-1192c1fe92a1 method: GET uri: https://huggingface.co/api/models/bert-base-uncased response: body: string: Temporary Redirect. 
Redirecting to /api/models/google-bert/bert-base-uncased headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '76' Content-Type: - text/plain; charset=utf-8 Date: - Wed, 22 Jan 2025 18:24:45 GMT Location: - /api/models/google-bert/bert-base-uncased Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin, Accept Via: - 1.1 b2ba040f19ad0239b9239a26b1640b9e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - Wa1k6mDMbt66t5osRpB-sR497I2c1T97iJ2K8CGBVCD__ic8mEeFyQ== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-679137ed-006e422f6ee95900016b4c8e;26fac15b-0ae6-43db-8d40-1192c1fe92a1 cross-origin-opener-policy: - same-origin status: code: 307 message: Temporary Redirect - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate Connection: - keep-alive X-Amzn-Trace-Id: - 26fac15b-0ae6-43db-8d40-1192c1fe92a1 method: GET uri: https://huggingface.co/api/models/google-bert/bert-base-uncased response: body: string: '{"_id":"621ffdc036468d709f174338","id":"google-bert/bert-base-uncased","private":false,"pipeline_tag":"fill-mask","library_name":"transformers","tags":["transformers","pytorch","tf","jax","rust","coreml","onnx","safetensors","bert","fill-mask","exbert","en","dataset:bookcorpus","dataset:wikipedia","arxiv:1810.04805","license:apache-2.0","autotrain_compatible","endpoints_compatible","region:us"],"downloads":72376843,"likes":2065,"modelId":"google-bert/bert-base-uncased","author":"google-bert","sha":"86b5e0934494bd15c9632b12f734a8a67f723594","lastModified":"2024-02-19T11:06:12.000Z","gated":false,"inference":"warm","disabled":false,"mask_token":"[MASK]","widgetData":[{"text":"Paris is the [MASK] of France."},{"text":"The goal of life is 
[MASK]."}],"model-index":null,"config":{"architectures":["BertForMaskedLM"],"model_type":"bert","tokenizer_config":{}},"cardData":{"language":"en","tags":["exbert"],"license":"apache-2.0","datasets":["bookcorpus","wikipedia"]},"transformersInfo":{"auto_model":"AutoModelForMaskedLM","pipeline_tag":"fill-mask","processor":"AutoTokenizer"},"siblings":[{"rfilename":".gitattributes"},{"rfilename":"LICENSE"},{"rfilename":"README.md"},{"rfilename":"config.json"},{"rfilename":"coreml/fill-mask/float32_model.mlpackage/Data/com.apple.CoreML/model.mlmodel"},{"rfilename":"coreml/fill-mask/float32_model.mlpackage/Data/com.apple.CoreML/weights/weight.bin"},{"rfilename":"coreml/fill-mask/float32_model.mlpackage/Manifest.json"},{"rfilename":"flax_model.msgpack"},{"rfilename":"model.onnx"},{"rfilename":"model.safetensors"},{"rfilename":"pytorch_model.bin"},{"rfilename":"rust_model.ot"},{"rfilename":"tf_model.h5"},{"rfilename":"tokenizer.json"},{"rfilename":"tokenizer_config.json"},{"rfilename":"vocab.txt"}],"spaces":["mteb/leaderboard","microsoft/HuggingGPT","Vision-CAIR/minigpt4","lnyan/stablediffusion-infinity","multimodalart/latentdiffusion","Salesforce/BLIP","mrfakename/MeloTTS","shi-labs/Versatile-Diffusion","yizhangliu/Grounded-Segment-Anything","cvlab/zero123-live","xinyu1205/recognize-anything","AIGC-Audio/AudioGPT","hilamanor/audioEditing","Audio-AGI/AudioSep","jadechoghari/OpenMusic","m-ric/chunk_visualizer","DAMO-NLP-SG/Video-LLaMA","gligen/demo","declare-lab/mustango","Yiwen-ntu/MeshAnything","shgao/EditAnything","LiruiZhao/Diffree","exbert-project/exbert","Vision-CAIR/MiniGPT-v2","Yuliang/ECON","THUdyh/Oryx","IDEA-Research/Grounded-SAM","Awiny/Image2Paragraph","ShilongLiu/Grounding_DINO_demo","eswardivi/Podcastify","liuyuan-pal/SyncDreamer","haotiz/glip-zeroshot-demo","nateraw/lavila","sam-hq-team/sam-hq","abyildirim/inst-inpaint","TencentARC/BrushEdit","merve/Grounding_DINO_demo","Yiwen-ntu/MeshAnythingV2","Pinwheel/GLIP-BLIP-Object-Detection-VQA","Junfeng5/GLEE_demo","shi-labs/Matting-Anything","fffiloni/Video-Matting-Anything","linfanluntan/Grounded-SAM","magicr/BuboGPT","Nick088/Audio-SR","OpenGVLab/InternGPT","clip-italian/clip-italian-demo","hongfz16/3DTopia","Vision-CAIR/MiniGPT4-video","yenniejun/tokenizers-languages","mmlab-ntu/relate-anything-model","nikigoli/countgd","byeongjun-park/HarmonyView","keras-io/bert-semantic-similarity","MirageML/sjc","amphion/PicoAudio","NAACL2022/CLIP-Caption-Reward","society-ethics/model-card-regulatory-check","fffiloni/miniGPT4-Video-Zero","Gladiator/Text-Summarizer","fffiloni/vta-ldm","milyiyo/reimagine-it","AIGC-Audio/AudioLCM","ethanchern/Anole","ysharma/text-to-image-to-video","OpenGVLab/VideoChatGPT","avid-ml/bias-detection","LittleFrog/IntrinsicAnything","RitaParadaRamos/SmallCapDemo","llizhx/TinyGPT-V","acmc/whatsapp-chats-finetuning-formatter","kaushalya/medclip-roco","AIGC-Audio/Make_An_Audio","flax-community/koclip","sasha/BiasDetection","TencentARC/VLog","ynhe/AskAnything","Pusheen/LoCo","pseudolab/AI_Tutor_BERT","ZebangCheng/Emotion-LLaMA","sonalkum/GAMA","flax-community/clip-reply-demo","SeViLA/SeViLA","PSLD/PSLD","AnimaLab/bias-test-gpt-pairs","optimum/auto-benchmark","Volkopat/SegmentAnythingxGroundingDINO","thewhole/GaussianDreamer_Demo","CosmoAI/BhagwatGeeta","codelion/Grounding_DINO_demo","phyloforfun/VoucherVision","wenkai/FAPM_demo","flosstradamus/FluxMusicGUI","AILab-CVC/SEED-LLaMA","ALM/CALM","tornadoslims/instruct-pix2pix","MykolaL/evp","zdou0830/desco","attention-refocusing/Attention-refocusing","sasha/WinoBiasCheck"],"createdAt
":"2022-03-02T23:29:04.000Z","safetensors":{"parameters":{"F32":110106428},"total":110106428},"usedStorage":12904182200}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '4420' Content-Type: - application/json; charset=utf-8 Date: - Wed, 22 Jan 2025 18:24:45 GMT ETag: - W/"1144-0lLZ4rZ6fR3gbhjw06/nTxkknu0" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 b2ba040f19ad0239b9239a26b1640b9e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - nDRIVcY91T1kKrdUdduDGqo4u74TtN2muKi6_VOvWf1Bn8-qzD15TQ== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-679137ed-5b402e715ad9fc2745fe095d;26fac15b-0ae6-43db-8d40-1192c1fe92a1 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"options": {"wait_for_model": true, "use_gpu": false}, "inputs": "Hi, I think [MASK]\u00a0is cool"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate Connection: - keep-alive Content-Length: - '100' Content-Type: - application/json X-Amzn-Trace-Id: - c5d6dbee-a766-43de-af2a-801dce1c3d8c method: POST uri: https://api-inference.huggingface.co/pipeline/fill-mask/bert-base-uncased response: body: string: '[{"score":0.800057053565979,"token":2023,"token_str":"this","sequence":"hi, i think this is cool"},{"score":0.05881328508257866,"token":2008,"token_str":"that","sequence":"hi, i think that is cool"},{"score":0.05029096454381943,"token":2009,"token_str":"it","sequence":"hi, i think it is cool"},{"score":0.01182125136256218,"token":2673,"token_str":"everything","sequence":"hi, i think everything is cool"},{"score":0.008278430439531803,"token":2002,"token_str":"he","sequence":"hi, i think he is cool"}]' headers: Connection: - keep-alive Content-Type: - application/json Date: - Wed, 22 Jan 2025 18:24:45 GMT Transfer-Encoding: - chunked access-control-allow-credentials: - 'true' access-control-expose-headers: - x-compute-type, x-compute-time server: - uvicorn vary: - Origin, Access-Control-Request-Method, Access-Control-Request-Headers - origin, access-control-request-method, access-control-request-headers x-compute-characters: - '26' x-compute-time: - '0.033' x-compute-type: - cpu x-proxied-host: - internal.api-inference.huggingface.co x-proxied-path: - / x-request-id: - NCN3_s x-sha: - 86b5e0934494bd15c9632b12f734a8a67f723594 status: code: 200 message: OK version: 1 huggingface_hub-0.31.1/tests/cassettes/InferenceApiTest.test_text_to_image.yaml000066400000000000000000000151071500667546600300260ustar00rootroot00000000000000interactions: - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate Connection: - keep-alive X-Amzn-Trace-Id: - f41cda26-6bb8-42f0-80a1-da64d4c707e6 method: GET uri: https://huggingface.co/api/models/stabilityai/stable-diffusion-2-1 response: body: string: 
'{"_id":"638f7ae36c25af4071044105","id":"stabilityai/stable-diffusion-2-1","private":false,"pipeline_tag":"text-to-image","library_name":"diffusers","tags":["diffusers","safetensors","stable-diffusion","text-to-image","arxiv:2112.10752","arxiv:2202.00512","arxiv:1910.09700","license:openrail++","autotrain_compatible","endpoints_compatible","diffusers:StableDiffusionPipeline","region:us"],"downloads":1015720,"likes":3928,"modelId":"stabilityai/stable-diffusion-2-1","author":"stabilityai","sha":"5cae40e6a2745ae2b01ad92ae5043f95f23644d6","lastModified":"2023-07-05T16:19:17.000Z","gated":false,"inference":"warm","disabled":false,"model-index":null,"config":{"diffusers":{"_class_name":"StableDiffusionPipeline"}},"cardData":{"license":"openrail++","tags":["stable-diffusion","text-to-image"],"pinned":true},"siblings":[{"rfilename":".gitattributes"},{"rfilename":"README.md"},{"rfilename":"feature_extractor/preprocessor_config.json"},{"rfilename":"model_index.json"},{"rfilename":"scheduler/scheduler_config.json"},{"rfilename":"text_encoder/config.json"},{"rfilename":"text_encoder/model.fp16.safetensors"},{"rfilename":"text_encoder/model.safetensors"},{"rfilename":"text_encoder/pytorch_model.bin"},{"rfilename":"text_encoder/pytorch_model.fp16.bin"},{"rfilename":"tokenizer/merges.txt"},{"rfilename":"tokenizer/special_tokens_map.json"},{"rfilename":"tokenizer/tokenizer_config.json"},{"rfilename":"tokenizer/vocab.json"},{"rfilename":"unet/config.json"},{"rfilename":"unet/diffusion_pytorch_model.bin"},{"rfilename":"unet/diffusion_pytorch_model.fp16.bin"},{"rfilename":"unet/diffusion_pytorch_model.fp16.safetensors"},{"rfilename":"unet/diffusion_pytorch_model.safetensors"},{"rfilename":"v2-1_768-ema-pruned.ckpt"},{"rfilename":"v2-1_768-ema-pruned.safetensors"},{"rfilename":"v2-1_768-nonema-pruned.ckpt"},{"rfilename":"v2-1_768-nonema-pruned.safetensors"},{"rfilename":"vae/config.json"},{"rfilename":"vae/diffusion_pytorch_model.bin"},{"rfilename":"vae/diffusion_pytorch_model.fp16.bin"},{"rfilename":"vae/diffusion_pytorch_model.fp16.safetensors"},{"rfilename":"vae/diffusion_pytorch_model.safetensors"}],"spaces":["microsoft/HuggingGPT","multimodalart/dreambooth-training","kadirnar/Video-Diffusion-WebUI","aliabid94/AutoGPT","VAST-AI/CharacterGen","PAIR/StreamingT2V","JingyeChen22/TextDiffuser","shgao/EditAnything","Nymbo/HH-ImgGen","ChenyangSi/FreeU","garibida/ReNoise-Inversion","vorstcavry/ai","MirageML/dreambooth","TencentARC/ColorFlow","tetrisd/Diffusion-Attentive-Attribution-Maps","baulab/ConceptSliders","kamiyamai/stable-diffusion-webui","jeasinema/UltraEdit-SD3","trysem/SD-2.1-Img2Img","multimodalart/civitai-to-hf","ennov8ion/3dart-Models","Nymbo/image_gen_supaqueue","kxic/EscherNet","Truepic/watermarked-content-credentials","cownclown/Image-and-3D-Model-Creator","carloscar/stable-diffusion-webui-controlnet-docker","Komorebizyd/DrawApp","prs-eth/rollingdepth","Truepic/ai-content-credentials","ennov8ion/comicbook-models","Nick088/stable-diffusion-arena","SUPERSHANKY/Finetuned_Diffusion_Max","nasttam/Image-and-3D-Model-Creator","IAmXenos21/stable-diffusion-webui-VORST2","svjack/AIIDiffusion","AlStable/AlPrompt","Fabrice-TIERCELIN/Text-to-Audio","Nymbo/Flood","EPFL-VILAB/ViPer","showlab/Show-o","yuan2023/Stable-Diffusion-ControlNet-WebUI","estusgroup/ai-qr-code-generator-beta-v2","decodemai/Stable-Diffusion-Ads","akhaliq/webui-orangemixs","gaspar-avit/Movie_Poster_Generator","xnetba/text2image","Make-A-Protagonist/Make-A-Protagonist-inference","yuan2023/stable-diffusion-webui-controlnet-docker","SVGRender/Di
ffSketcher","samthakur/stable-diffusion-2.1","taesiri/HuggingGPT-Lite","mindtube/Diffusion50XX","kamwoh/dreamcreature","sky24h/Stable-Makeup-unofficial","AI-ML-API-tutorials/ai-sticker-maker","rhfeiyang/Art-Free-Diffusion","YeOldHermit/StableDiffusion_AnythingV3_ModelCamenduru","Datasculptor/ImageGPT","Omnibus/Video-Diffusion-WebUI","Adam111/stable-diffusion-webui","vs4vijay/stable-diffusion","Yasu55/stable-diffusion-webui","ennov8ion/stablediffusion-models","sagarkarn/text2image","JoPmt/Multi-SD_Cntrl_Cny_Pse_Img2Img","JoPmt/Txt-to-video","imseldrith/Text-to-Image2","duchaba/sd_prompt_helper","Crossper6/stable-diffusion-webui","bobu5/SD-webui-controlnet-docker","JoPmt/Vid2Vid_Cntrl_Canny_Multi_SD","lianzhou/stable-diffusion-webui","Missinginaction/stablediffusionwithnofilter","achyuth1344/stable-diffusion-webui","ennov8ion/Scifi-Models","ennov8ion/semirealistic-models","ennov8ion/FantasyArt-Models","ennov8ion/dreamlike-models","meowingamogus69/stable-diffusion-webui-controlnet-docker","Ababababababbababa/SD-2.1-Img2Img","noes14155/img_All_models","Dagfinn1962/prodia2","tokeron/DiffusionLens","dilightnet/DiLightNet","Nymbo/Game-Creator","JoPmt/Img2Img_SD_Control_Canny_Pose_Multi","nihun/image-gen","AnimeStudio/anime-models","Emma02/LVM","pieeetre/stable-diffusion-webui","VincentZB/Stable-Diffusion-ControlNet-WebUI","luluneko1/stable-diffusion-webui","MathysL/AutoGPT4","voltcutter/stable-diffusion-webui","Duskfallcrew/newdreambooth-toclone","Dao3/Top-20-Models","Proveedy/dreambooth-trainingv15","MetaWabbit/Auto-GPT","Shriharsh/Text_To_Image","pikto/Diffuser"],"createdAt":"2022-12-06T17:24:51.000Z","usedStorage":57147037677}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '5204' Content-Type: - application/json; charset=utf-8 Date: - Wed, 22 Jan 2025 18:24:45 GMT ETag: - W/"1454-32uEi6TyuaGiZXENOHB0lYqMKCI" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 9737f42d74643b8e3ceb7ecfa2015ed2.cloudfront.net (CloudFront) X-Amz-Cf-Id: - LJEGaCxU02yCLV4NyGkElvZH2shCTe358L8HGoP9_OmFrHvegRzRkw== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-679137ed-284f9ab8779e96556218657d;f41cda26-6bb8-42f0-80a1-da64d4c707e6 cross-origin-opener-policy: - same-origin status: code: 200 message: OK version: 1 huggingface_hub-0.31.1/tests/cassettes/InferenceApiTest.test_text_to_image_raw_response.yaml000066400000000000000000000151071500667546600326150ustar00rootroot00000000000000interactions: - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate Connection: - keep-alive X-Amzn-Trace-Id: - b96ba40c-6425-4b42-acaa-0be1248fbde9 method: GET uri: https://huggingface.co/api/models/stabilityai/stable-diffusion-2-1 response: body: string: 
'{"_id":"638f7ae36c25af4071044105","id":"stabilityai/stable-diffusion-2-1","private":false,"pipeline_tag":"text-to-image","library_name":"diffusers","tags":["diffusers","safetensors","stable-diffusion","text-to-image","arxiv:2112.10752","arxiv:2202.00512","arxiv:1910.09700","license:openrail++","autotrain_compatible","endpoints_compatible","diffusers:StableDiffusionPipeline","region:us"],"downloads":1015720,"likes":3928,"modelId":"stabilityai/stable-diffusion-2-1","author":"stabilityai","sha":"5cae40e6a2745ae2b01ad92ae5043f95f23644d6","lastModified":"2023-07-05T16:19:17.000Z","gated":false,"inference":"warm","disabled":false,"model-index":null,"config":{"diffusers":{"_class_name":"StableDiffusionPipeline"}},"cardData":{"license":"openrail++","tags":["stable-diffusion","text-to-image"],"pinned":true},"siblings":[{"rfilename":".gitattributes"},{"rfilename":"README.md"},{"rfilename":"feature_extractor/preprocessor_config.json"},{"rfilename":"model_index.json"},{"rfilename":"scheduler/scheduler_config.json"},{"rfilename":"text_encoder/config.json"},{"rfilename":"text_encoder/model.fp16.safetensors"},{"rfilename":"text_encoder/model.safetensors"},{"rfilename":"text_encoder/pytorch_model.bin"},{"rfilename":"text_encoder/pytorch_model.fp16.bin"},{"rfilename":"tokenizer/merges.txt"},{"rfilename":"tokenizer/special_tokens_map.json"},{"rfilename":"tokenizer/tokenizer_config.json"},{"rfilename":"tokenizer/vocab.json"},{"rfilename":"unet/config.json"},{"rfilename":"unet/diffusion_pytorch_model.bin"},{"rfilename":"unet/diffusion_pytorch_model.fp16.bin"},{"rfilename":"unet/diffusion_pytorch_model.fp16.safetensors"},{"rfilename":"unet/diffusion_pytorch_model.safetensors"},{"rfilename":"v2-1_768-ema-pruned.ckpt"},{"rfilename":"v2-1_768-ema-pruned.safetensors"},{"rfilename":"v2-1_768-nonema-pruned.ckpt"},{"rfilename":"v2-1_768-nonema-pruned.safetensors"},{"rfilename":"vae/config.json"},{"rfilename":"vae/diffusion_pytorch_model.bin"},{"rfilename":"vae/diffusion_pytorch_model.fp16.bin"},{"rfilename":"vae/diffusion_pytorch_model.fp16.safetensors"},{"rfilename":"vae/diffusion_pytorch_model.safetensors"}],"spaces":["microsoft/HuggingGPT","multimodalart/dreambooth-training","kadirnar/Video-Diffusion-WebUI","aliabid94/AutoGPT","VAST-AI/CharacterGen","PAIR/StreamingT2V","JingyeChen22/TextDiffuser","shgao/EditAnything","Nymbo/HH-ImgGen","ChenyangSi/FreeU","garibida/ReNoise-Inversion","vorstcavry/ai","MirageML/dreambooth","TencentARC/ColorFlow","tetrisd/Diffusion-Attentive-Attribution-Maps","baulab/ConceptSliders","kamiyamai/stable-diffusion-webui","jeasinema/UltraEdit-SD3","trysem/SD-2.1-Img2Img","multimodalart/civitai-to-hf","ennov8ion/3dart-Models","Nymbo/image_gen_supaqueue","kxic/EscherNet","Truepic/watermarked-content-credentials","cownclown/Image-and-3D-Model-Creator","carloscar/stable-diffusion-webui-controlnet-docker","Komorebizyd/DrawApp","prs-eth/rollingdepth","Truepic/ai-content-credentials","ennov8ion/comicbook-models","Nick088/stable-diffusion-arena","SUPERSHANKY/Finetuned_Diffusion_Max","nasttam/Image-and-3D-Model-Creator","IAmXenos21/stable-diffusion-webui-VORST2","svjack/AIIDiffusion","AlStable/AlPrompt","Fabrice-TIERCELIN/Text-to-Audio","Nymbo/Flood","EPFL-VILAB/ViPer","showlab/Show-o","yuan2023/Stable-Diffusion-ControlNet-WebUI","estusgroup/ai-qr-code-generator-beta-v2","decodemai/Stable-Diffusion-Ads","akhaliq/webui-orangemixs","gaspar-avit/Movie_Poster_Generator","xnetba/text2image","Make-A-Protagonist/Make-A-Protagonist-inference","yuan2023/stable-diffusion-webui-controlnet-docker","SVGRender/Di
ffSketcher","samthakur/stable-diffusion-2.1","taesiri/HuggingGPT-Lite","mindtube/Diffusion50XX","kamwoh/dreamcreature","sky24h/Stable-Makeup-unofficial","AI-ML-API-tutorials/ai-sticker-maker","rhfeiyang/Art-Free-Diffusion","YeOldHermit/StableDiffusion_AnythingV3_ModelCamenduru","Datasculptor/ImageGPT","Omnibus/Video-Diffusion-WebUI","Adam111/stable-diffusion-webui","vs4vijay/stable-diffusion","Yasu55/stable-diffusion-webui","ennov8ion/stablediffusion-models","sagarkarn/text2image","JoPmt/Multi-SD_Cntrl_Cny_Pse_Img2Img","JoPmt/Txt-to-video","imseldrith/Text-to-Image2","duchaba/sd_prompt_helper","Crossper6/stable-diffusion-webui","bobu5/SD-webui-controlnet-docker","JoPmt/Vid2Vid_Cntrl_Canny_Multi_SD","lianzhou/stable-diffusion-webui","Missinginaction/stablediffusionwithnofilter","achyuth1344/stable-diffusion-webui","ennov8ion/Scifi-Models","ennov8ion/semirealistic-models","ennov8ion/FantasyArt-Models","ennov8ion/dreamlike-models","meowingamogus69/stable-diffusion-webui-controlnet-docker","Ababababababbababa/SD-2.1-Img2Img","noes14155/img_All_models","Dagfinn1962/prodia2","tokeron/DiffusionLens","dilightnet/DiLightNet","Nymbo/Game-Creator","JoPmt/Img2Img_SD_Control_Canny_Pose_Multi","nihun/image-gen","AnimeStudio/anime-models","Emma02/LVM","pieeetre/stable-diffusion-webui","VincentZB/Stable-Diffusion-ControlNet-WebUI","luluneko1/stable-diffusion-webui","MathysL/AutoGPT4","voltcutter/stable-diffusion-webui","Duskfallcrew/newdreambooth-toclone","Dao3/Top-20-Models","Proveedy/dreambooth-trainingv15","MetaWabbit/Auto-GPT","Shriharsh/Text_To_Image","pikto/Diffuser"],"createdAt":"2022-12-06T17:24:51.000Z","usedStorage":57147037677}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '5204' Content-Type: - application/json; charset=utf-8 Date: - Wed, 22 Jan 2025 18:24:46 GMT ETag: - W/"1454-32uEi6TyuaGiZXENOHB0lYqMKCI" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 e47c282d2c53705a367f9e376a2eab28.cloudfront.net (CloudFront) X-Amz-Cf-Id: - _zL1bIJbONpKnGKV11WalfKg3AseDTIy4apS8njN3TF1VHNNm7qBxQ== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-679137ee-6c94201668eccebe49f4c219;b96ba40c-6425-4b42-acaa-0be1248fbde9 cross-origin-opener-policy: - same-origin status: code: 200 message: OK version: 1 huggingface_hub-0.31.1/tests/cassettes/TestSpaceAPIProduction.test_manage_secrets.yaml000066400000000000000000000503501500667546600312610ustar00rootroot00000000000000interactions: - request: body: '{"name": "tmp_test_space", "organization": "user", "private": true, "type": "space", "sdk": "gradio"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '104' Content-Type: - application/json X-Amzn-Trace-Id: - c9be3bc3-c902-4af2-98ce-49f29dd89efd method: POST uri: https://huggingface.co/api/repos/create response: body: string: '{"error":"You already created this space repo","url":"https://huggingface.co/spaces/user/tmp_test_space"}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - 
X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '108' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:11 GMT ETag: - W/"6c-cLXSdVZqrgsnfZ2Xm7BqF4YqbPo" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - aH7hYCC_MTDl-sikHXAQZwwl8H7ZkAHuU1Raj65PT1M9ZqDdlUMHtg== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Error from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56b-7044bad340f6f3b2760f4124;c9be3bc3-c902-4af2-98ce-49f29dd89efd cross-origin-opener-policy: - same-origin status: code: 409 message: Conflict - request: body: '{"files": [{"path": "app.py", "sample": "CmltcG9ydCBncmFkaW8gYXMgZ3IKCgpkZWYgZ3JlZXQobmFtZSk6CiAgICByZXR1cm4gIkhlbGxvICIgKyBuYW1lICsgIiEhIgoKaWZhY2UgPSBnci5JbnRlcmZhY2UoZm49Z3JlZXQsIGlucHV0cz0idGV4dCIsIG91dHB1dHM9InRleHQiKQppZmFjZS5sYXVuY2goKQo=", "size": 152}]}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '262' Content-Type: - application/json X-Amzn-Trace-Id: - 96319374-fbbe-42ed-b621-7552a91114f5 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/preupload/main response: body: string: '{"files":[{"path":"app.py","uploadMode":"regular","shouldIgnore":false}],"commitOid":"8aa4cd78db279c7d673ade7a798dd4adff562cbb"}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '128' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:11 GMT ETag: - W/"80-9/S6gfzEBsCFgJIVX0Zf1/sEtmg" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 4SXTzf9atTyWrsSHYbbYLKULBj-VoshO6iwrAen7eSOqozorgP99IQ== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56b-49212cca2d70b5a51205c99a;96319374-fbbe-42ed-b621-7552a91114f5 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive X-Amzn-Trace-Id: - 86a2f7da-ff0f-41f2-9025-a87e405216b4 method: GET uri: https://huggingface.co/api/spaces/user/tmp_test_space/revision/main?expand=xetEnabled response: body: string: '{"_id":"67cae5358ccdbd832011c081","id":"user/tmp_test_space","xetEnabled":false}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '83' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:11 GMT ETag: - W/"53-a2Em1NBwfk1o0EIzZGOT899ueUk" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 
hATAPZaLcQEsftKG9LsDN0rOcZxEGFk4ndyGXoxsVLxzSaOziI9HLA== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56b-4ccabd6c4a7d63f0790a5d8e;86a2f7da-ff0f-41f2-9025-a87e405216b4 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "header", "value": {"summary": "Upload app.py with huggingface_hub", "description": ""}} {"key": "file", "value": {"content": "CmltcG9ydCBncmFkaW8gYXMgZ3IKCgpkZWYgZ3JlZXQobmFtZSk6CiAgICByZXR1cm4gIkhlbGxvICIgKyBuYW1lICsgIiEhIgoKaWZhY2UgPSBnci5JbnRlcmZhY2UoZm49Z3JlZXQsIGlucHV0cz0idGV4dCIsIG91dHB1dHM9InRleHQiKQppZmFjZS5sYXVuY2goKQo=", "path": "app.py", "encoding": "base64"}} ' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '383' Content-Type: - application/x-ndjson X-Amzn-Trace-Id: - 03a62a5f-cd11-48d4-aad2-4b47cbecedec method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/commit/main response: body: string: '{"success":true,"commitOid":"855efb9561990d02233a49204227f93ba5cfaa81","commitUrl":"https://huggingface.co/spaces/user/tmp_test_space/commit/855efb9561990d02233a49204227f93ba5cfaa81","hookOutput":""}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '202' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:12 GMT ETag: - W/"ca-iQts/ibleiMnAtiU0ILntNCQMhM" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - Pe-v9zhEQlLPUtsJdofOm6Sx2Y4cHG64XD0txMcnixqoCvWQsYUDbw== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56c-45d0c68012abaf6223afe331;03a62a5f-cd11-48d4-aad2-4b47cbecedec cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "foo", "value": "123"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '30' Content-Type: - application/json X-Amzn-Trace-Id: - d44a034d-b2e2-43e7-8504-7c4a7b6a6f15 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/secrets response: body: string: '{}' headers: Connection: - keep-alive Content-Length: - '2' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:12 GMT ETag: - W/"2-vyGp6PvFo4RvsFtPoIWeCReyIC8" Referrer-Policy: - strict-origin-when-cross-origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - -Z2SKEQXC4d3RJ9p-LV4slzus6Qd-MNY97DSl1EL2-6lsTYHQ3a8hQ== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56c-77406d6558571d27093d9032;d44a034d-b2e2-43e7-8504-7c4a7b6a6f15 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "token", "value": "hf_api_123456"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '42' Content-Type: - application/json X-Amzn-Trace-Id: - bf7a9f6b-a4a2-471a-9a0e-8c0d9f28e5f8 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/secrets response: body: string: '{}' headers: Connection: - 
keep-alive Content-Length: - '2' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:13 GMT ETag: - W/"2-vyGp6PvFo4RvsFtPoIWeCReyIC8" Referrer-Policy: - strict-origin-when-cross-origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - PfnZ8uwcAcpYTlIsO5Wv6-EK3kZBPdPDjy8uf_RcDB0-Yyj5K6ic5w== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56d-7423833c6ca909b27c519585;bf7a9f6b-a4a2-471a-9a0e-8c0d9f28e5f8 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "gh_api_key", "value": "******"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '40' Content-Type: - application/json X-Amzn-Trace-Id: - 754a6946-e9c9-4f00-b5d8-00a0409265e9 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/secrets response: body: string: '{}' headers: Connection: - keep-alive Content-Length: - '2' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:13 GMT ETag: - W/"2-vyGp6PvFo4RvsFtPoIWeCReyIC8" Referrer-Policy: - strict-origin-when-cross-origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - vZ_aoXKkuQAEkCLmr8wT9bEETyAwg7SyUcsugmFzfyjMi2p7fJm0Lw== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56d-630a8f1c1fd4e3755c433323;754a6946-e9c9-4f00-b5d8-00a0409265e9 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "bar", "value": "123", "description": "This is a secret"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '65' Content-Type: - application/json X-Amzn-Trace-Id: - 6dd3747b-f44c-4b95-a9d4-0c7c0ccb55a4 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/secrets response: body: string: '{}' headers: Connection: - keep-alive Content-Length: - '2' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:13 GMT ETag: - W/"2-vyGp6PvFo4RvsFtPoIWeCReyIC8" Referrer-Policy: - strict-origin-when-cross-origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - ZJxk2O9KM1gKhOrg9CEKUTCBFyP-vqQBQGyIDz4FyzINYz50A8H8Zg== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56d-0d991be52616ccd52f81462b;6dd3747b-f44c-4b95-a9d4-0c7c0ccb55a4 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "foo", "value": "456"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '30' Content-Type: - application/json X-Amzn-Trace-Id: - 23a9e1fa-b4b7-402e-81ea-5ffee1e1fe40 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/secrets response: body: string: '{}' headers: Connection: - keep-alive Content-Length: - '2' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:13 GMT ETag: - W/"2-vyGp6PvFo4RvsFtPoIWeCReyIC8" Referrer-Policy: - strict-origin-when-cross-origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - ff4tYN8pDZCRzOyHxiX4htBPouIMsQgaOI8Tu6CXb1ZCv66tOQp40Q== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - 
Root=1-67cae56d-203ef959504ba4523d8d2080;23a9e1fa-b4b7-402e-81ea-5ffee1e1fe40 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "foo", "value": "789", "description": "This is a secret"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '65' Content-Type: - application/json X-Amzn-Trace-Id: - a658ade9-d7b5-4348-a55a-7cc63920ac12 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/secrets response: body: string: '{}' headers: Connection: - keep-alive Content-Length: - '2' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:13 GMT ETag: - W/"2-vyGp6PvFo4RvsFtPoIWeCReyIC8" Referrer-Policy: - strict-origin-when-cross-origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - wvPMPDDGhrYi1XzKNk6W1yKMnywEDoO7GUutB158Hb6hoA9bDoE2EQ== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56d-7f7ec498356cf25d7891fdb2;a658ade9-d7b5-4348-a55a-7cc63920ac12 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "bar", "value": "456", "description": "This is another secret"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '71' Content-Type: - application/json X-Amzn-Trace-Id: - 280fc309-a810-485a-875e-ef801cef4a06 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/secrets response: body: string: '{}' headers: Connection: - keep-alive Content-Length: - '2' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:14 GMT ETag: - W/"2-vyGp6PvFo4RvsFtPoIWeCReyIC8" Referrer-Policy: - strict-origin-when-cross-origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - MtyENUIdc-DDlS2r4LecZeN7xdY687cS99iDyk_AljAcMdDxfkvESQ== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56e-6c1831633adba4df35d727bf;280fc309-a810-485a-875e-ef801cef4a06 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "gh_api_key"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '21' Content-Type: - application/json X-Amzn-Trace-Id: - 6a5ccc7d-dad7-4260-9db5-432d85667736 method: DELETE uri: https://huggingface.co/api/spaces/user/tmp_test_space/secrets response: body: string: '{}' headers: Connection: - keep-alive Content-Length: - '2' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:14 GMT ETag: - W/"2-vyGp6PvFo4RvsFtPoIWeCReyIC8" Referrer-Policy: - strict-origin-when-cross-origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - FPgiwymoD902XfsaHSJ5PeBg0cTa0wzCclmMCk4W2oMIlowB5ATzxw== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56e-2b3eb53a6d6ada9c42857e7d;6a5ccc7d-dad7-4260-9db5-432d85667736 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "missing_key"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '22' Content-Type: - application/json X-Amzn-Trace-Id: - 0c4e1e88-7136-4776-b5e3-20553214d278 method: DELETE uri: https://huggingface.co/api/spaces/user/tmp_test_space/secrets response: body: string: '{}' 
headers: Connection: - keep-alive Content-Length: - '2' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:14 GMT ETag: - W/"2-vyGp6PvFo4RvsFtPoIWeCReyIC8" Referrer-Policy: - strict-origin-when-cross-origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - Dc-eyr_Pot418vnzFWwJBlOYn_cGX6tglUd7tpXgr4v7bLwkjR42_A== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56e-519f0665135ac4485aae313e;0c4e1e88-7136-4776-b5e3-20553214d278 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"name": "tmp_test_space", "organization": "user", "type": "space"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '70' Content-Type: - application/json X-Amzn-Trace-Id: - e9b0afaa-3fbf-46f9-98e8-e4fe67829ca6 method: DELETE uri: https://huggingface.co/api/repos/delete response: body: string: OK headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '2' Content-Type: - text/plain; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:15 GMT ETag: - W/"2-nOO9QiTIwXgNtWtBJezz8kv3SLc" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 1YKXYUszejurZEoZbaR8bzf4gBLOutzPs95GXjhlwfC0w_lZGNpnmA== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56e-159b91f5249e2e672237e7a9;e9b0afaa-3fbf-46f9-98e8-e4fe67829ca6 cross-origin-opener-policy: - same-origin status: code: 200 message: OK version: 1 huggingface_hub-0.31.1/tests/cassettes/TestSpaceAPIProduction.test_manage_variables.yaml000066400000000000000000000613631500667546600315670ustar00rootroot00000000000000interactions: - request: body: '{"name": "tmp_test_space", "organization": "user", "private": true, "type": "space", "sdk": "gradio"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '104' Content-Type: - application/json X-Amzn-Trace-Id: - 47a4cc5f-9a1b-44b2-9cef-901111fd2498 method: POST uri: https://huggingface.co/api/repos/create response: body: string: '{"url":"https://huggingface.co/spaces/user/tmp_test_space","name":"user/tmp_test_space","id":"67cae56f9b4f4ee4713c83e7"}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '126' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:16 GMT ETag: - W/"7e-IIZfkz6oRU0wTdbGC2eUKrWYXDw" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - eqpp6gXTuHTKzoytrR-zThBLi_XqvA0R-mC6pfbbkhalO8mlbt6Low== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae56f-2255614a05efec98258484db;47a4cc5f-9a1b-44b2-9cef-901111fd2498 cross-origin-opener-policy: - 
same-origin status: code: 200 message: OK - request: body: '{"files": [{"path": "app.py", "sample": "CmltcG9ydCBncmFkaW8gYXMgZ3IKCgpkZWYgZ3JlZXQobmFtZSk6CiAgICByZXR1cm4gIkhlbGxvICIgKyBuYW1lICsgIiEhIgoKaWZhY2UgPSBnci5JbnRlcmZhY2UoZm49Z3JlZXQsIGlucHV0cz0idGV4dCIsIG91dHB1dHM9InRleHQiKQppZmFjZS5sYXVuY2goKQo=", "size": 152}]}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '262' Content-Type: - application/json X-Amzn-Trace-Id: - a674d7b9-a3ce-4402-a8f4-ae05e73a8578 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/preupload/main response: body: string: '{"files":[{"path":"app.py","uploadMode":"regular","shouldIgnore":false}],"commitOid":"c28c3378a2354b5aa3b65c9652ed8ed8941ba3d5"}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '128' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:16 GMT ETag: - W/"80-kRxMLDjU5j+SSNqA7hsLAnQGOHQ" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - J3WLQ_RzQt5sPpotjNWKtNlmPzid0SzaB-uhWA8Wwjwx3WiAf0WU0w== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae570-790fb2dd21f9cad57389b7fd;a674d7b9-a3ce-4402-a8f4-ae05e73a8578 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive X-Amzn-Trace-Id: - 02a4c7dc-c8b8-4b16-8021-83edf7b5ecae method: GET uri: https://huggingface.co/api/spaces/user/tmp_test_space/revision/main?expand=xetEnabled response: body: string: '{"_id":"67cae56f9b4f4ee4713c83e7","id":"user/tmp_test_space","xetEnabled":false}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '83' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:17 GMT ETag: - W/"53-KcEoMEXQ1qRAHs8f1K9d+H6xV7w" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - mFd5RLMJhi2obWQ-QkhDBmc75IiBBox5ZQvIvc2T8Cecz-G3vJQynA== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae571-344d912007f9a51620a0fca7;02a4c7dc-c8b8-4b16-8021-83edf7b5ecae cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "header", "value": {"summary": "Upload app.py with huggingface_hub", "description": ""}} {"key": "file", "value": {"content": "CmltcG9ydCBncmFkaW8gYXMgZ3IKCgpkZWYgZ3JlZXQobmFtZSk6CiAgICByZXR1cm4gIkhlbGxvICIgKyBuYW1lICsgIiEhIgoKaWZhY2UgPSBnci5JbnRlcmZhY2UoZm49Z3JlZXQsIGlucHV0cz0idGV4dCIsIG91dHB1dHM9InRleHQiKQppZmFjZS5sYXVuY2goKQo=", "path": "app.py", "encoding": "base64"}} ' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '383' Content-Type: - 
application/x-ndjson X-Amzn-Trace-Id: - b0f4a049-91f2-4a65-9eec-9f75b5a07872 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/commit/main response: body: string: '{"success":true,"commitOid":"eff8e0eb38c673ede5f5d4c4d106255900fa39e2","commitUrl":"https://huggingface.co/spaces/user/tmp_test_space/commit/eff8e0eb38c673ede5f5d4c4d106255900fa39e2","hookOutput":""}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '202' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:18 GMT ETag: - W/"ca-iLOqNyjYlzhNCQkFbLydVgeRbf4" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - xuZonJ3zlvIeaBbYm2gx4mROKU5nvjpWgraQpLXjoAyhRA1XuRnzHw== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae571-573543ba4c3e13ec17141d90;b0f4a049-91f2-4a65-9eec-9f75b5a07872 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive X-Amzn-Trace-Id: - bf9d247f-1b2e-4454-acf7-aa972cd07b55 method: GET uri: https://huggingface.co/api/spaces/user/tmp_test_space/variables response: body: string: '{}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '2' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:18 GMT ETag: - W/"2-vyGp6PvFo4RvsFtPoIWeCReyIC8" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - a2QqDwiAwKKod4_2b-xu79UmL_Kc271A63ScnYMkR9s9yUhdPNikcg== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae572-6ae7936b53e3166a22040d10;bf9d247f-1b2e-4454-acf7-aa972cd07b55 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "foo", "value": "123"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '30' Content-Type: - application/json X-Amzn-Trace-Id: - c08ae5a9-0683-4d1d-ac9c-8ce685520f3d method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/variables response: body: string: '{"foo":{"value":"123","updatedAt":"2025-03-07T12:24:18.657Z"}}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '62' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:18 GMT ETag: - W/"3e-q4uUCpjRuU299XB4L0PIX635r20" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) 
X-Amz-Cf-Id: - 8YXx2Dhq-vkxGSoArPsfuHdW5Fajil-0986X799w2pVg3ZtDfMTBkg== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae572-62408b0012d48183413b41e6;c08ae5a9-0683-4d1d-ac9c-8ce685520f3d cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "MODEL_REPO_ID", "value": "user/repo"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '46' Content-Type: - application/json X-Amzn-Trace-Id: - 6727a16f-2d47-45f3-843b-e54989ad0580 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/variables response: body: string: '{"foo":{"value":"123","updatedAt":"2025-03-07T12:24:18.657Z"},"MODEL_REPO_ID":{"value":"user/repo","updatedAt":"2025-03-07T12:24:19.140Z"}}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '139' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:19 GMT ETag: - W/"8b-0w1ouMDC0O3knwFRG8mKuf2z1gU" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - x9C51EPVQRKMu0X9269jE9kRV1TI3E1WxIAvgpzZX4UoGBIQnSmUjg== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae572-0e816e0d22d9de186c24f5f5;6727a16f-2d47-45f3-843b-e54989ad0580 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "MODEL_PAPER", "value": "arXiv", "description": "found it there"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '73' Content-Type: - application/json X-Amzn-Trace-Id: - ac16a5ea-5d18-47b7-98c3-dc4129acfcf5 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/variables response: body: string: '{"foo":{"value":"123","updatedAt":"2025-03-07T12:24:18.657Z"},"MODEL_REPO_ID":{"value":"user/repo","updatedAt":"2025-03-07T12:24:19.140Z"},"MODEL_PAPER":{"value":"arXiv","description":"found it there","updatedAt":"2025-03-07T12:24:19.337Z"}}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '241' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:19 GMT ETag: - W/"f1-bEfkBM1h+HtLdktXu58IvCW0uyE" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 8bg9Y0pAFQBq9RtFl3WHSGkvOXjwi35BJRm5xLn6s4wzyhg0mhMnKg== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae573-2205bd7c46758c396c705b7a;ac16a5ea-5d18-47b7-98c3-dc4129acfcf5 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "foo", "value": "456"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '30' Content-Type: - application/json 
X-Amzn-Trace-Id: - 699803d2-ae72-49ae-9712-e576db0751ea method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/variables response: body: string: '{"foo":{"value":"456","updatedAt":"2025-03-07T12:24:19.651Z"},"MODEL_REPO_ID":{"value":"user/repo","updatedAt":"2025-03-07T12:24:19.140Z"},"MODEL_PAPER":{"value":"arXiv","description":"found it there","updatedAt":"2025-03-07T12:24:19.337Z"}}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '241' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:19 GMT ETag: - W/"f1-PE5/Di6GQPKkXc5Dc7tQBbWdbSE" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - ajAUhuKkNRQmjXeSetA9dc0TwaMHDX7zIx4nkz8O80cuxYrvqtW5Bg== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae573-613368fc40b0a0e05c3a03ea;699803d2-ae72-49ae-9712-e576db0751ea cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "foo", "value": "456", "description": "updated description"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '68' Content-Type: - application/json X-Amzn-Trace-Id: - a934728a-ddd4-4369-a0e4-6cdffa21758c method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/variables response: body: string: '{"foo":{"value":"456","description":"updated description","updatedAt":"2025-03-07T12:24:19.875Z"},"MODEL_REPO_ID":{"value":"user/repo","updatedAt":"2025-03-07T12:24:19.140Z"},"MODEL_PAPER":{"value":"arXiv","description":"found it there","updatedAt":"2025-03-07T12:24:19.337Z"}}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '277' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:19 GMT ETag: - W/"115-eatWlFhuJXDKdFCN6DDHH8ZrEa0" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - -eXNt-VkjkjAhxHxFRgmHOEgWq0AHNxJ5A7v3ejo1tj6PfHfGs6WRg== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae573-36676ba6315a70c6651b6ecb;a934728a-ddd4-4369-a0e4-6cdffa21758c cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "gh_api_key"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '21' Content-Type: - application/json X-Amzn-Trace-Id: - 41cc0e9b-386c-4960-9b0c-fefffe97ea62 method: DELETE uri: https://huggingface.co/api/spaces/user/tmp_test_space/variables response: body: string: '{"foo":{"value":"456","description":"updated description","updatedAt":"2025-03-07T12:24:19.875Z"},"MODEL_REPO_ID":{"value":"user/repo","updatedAt":"2025-03-07T12:24:19.140Z"},"MODEL_PAPER":{"value":"arXiv","description":"found it 
there","updatedAt":"2025-03-07T12:24:19.337Z"}}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '277' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:20 GMT ETag: - W/"115-eatWlFhuJXDKdFCN6DDHH8ZrEa0" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 3V7eqdB1eXIzDFvSAo4lEw42b3cS46NwFu3TFiTzGa17Pi3g7gbLtA== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae574-643f6a7e075723b3013b1364;41cc0e9b-386c-4960-9b0c-fefffe97ea62 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "missing_key"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '22' Content-Type: - application/json X-Amzn-Trace-Id: - ec63b32a-6bcc-4a86-8a9c-6f7beda90418 method: DELETE uri: https://huggingface.co/api/spaces/user/tmp_test_space/variables response: body: string: '{"foo":{"value":"456","description":"updated description","updatedAt":"2025-03-07T12:24:19.875Z"},"MODEL_REPO_ID":{"value":"user/repo","updatedAt":"2025-03-07T12:24:19.140Z"},"MODEL_PAPER":{"value":"arXiv","description":"found it there","updatedAt":"2025-03-07T12:24:19.337Z"}}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '277' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:20 GMT ETag: - W/"115-eatWlFhuJXDKdFCN6DDHH8ZrEa0" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - nvIM3SyyDIDQyq6II7z0MYhkPao7GWEuvtbeJGArLb96ivflvXPHQw== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae574-3f86dee741149c407e6df30c;ec63b32a-6bcc-4a86-8a9c-6f7beda90418 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive X-Amzn-Trace-Id: - 1b8a4629-6629-48df-a512-dcff8d947cb4 method: GET uri: https://huggingface.co/api/spaces/user/tmp_test_space/variables response: body: string: '{"foo":{"value":"456","description":"updated description","updatedAt":"2025-03-07T12:24:19.875Z"},"MODEL_REPO_ID":{"value":"user/repo","updatedAt":"2025-03-07T12:24:19.140Z"},"MODEL_PAPER":{"value":"arXiv","description":"found it there","updatedAt":"2025-03-07T12:24:19.337Z"}}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '277' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:20 GMT ETag: - 
W/"115-eatWlFhuJXDKdFCN6DDHH8ZrEa0" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - gp59kY8IyTOe7uOCAPRyk7aqNXqXzjX7YF374voAc3cHzsm4jJzJRA== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae574-7b50faed54abeb620f129366;1b8a4629-6629-48df-a512-dcff8d947cb4 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"name": "tmp_test_space", "organization": "user", "type": "space"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '70' Content-Type: - application/json X-Amzn-Trace-Id: - 09422d1d-6beb-47cd-96e9-334e1d23e667 method: DELETE uri: https://huggingface.co/api/repos/delete response: body: string: OK headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '2' Content-Type: - text/plain; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:20 GMT ETag: - W/"2-nOO9QiTIwXgNtWtBJezz8kv3SLc" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 5cb605e8100138acccc04f094724133e.cloudfront.net (CloudFront) X-Amz-Cf-Id: - k3vFxzVrf5jBRs4IIh2ZTzOLk4fpLYBov63R9icnaBQaOtgQJuwwLw== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae574-21fe270c44ec0a8c51b83eb2;09422d1d-6beb-47cd-96e9-334e1d23e667 cross-origin-opener-policy: - same-origin status: code: 200 message: OK version: 1 huggingface_hub-0.31.1/tests/cassettes/TestSpaceAPIProduction.test_pause_and_restart_space.yaml000066400000000000000000000517571500667546600331730ustar00rootroot00000000000000interactions: - request: body: '{"name": "tmp_test_space", "organization": "user", "private": true, "type": "space", "sdk": "gradio"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '104' Content-Type: - application/json X-Amzn-Trace-Id: - a86bc51b-b47e-426f-a35f-240a57a69ba0 method: POST uri: https://huggingface.co/api/repos/create response: body: string: '{"url":"https://huggingface.co/spaces/user/tmp_test_space","name":"user/tmp_test_space","id":"67cae57522fde9195aaf95c5"}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '126' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:22 GMT ETag: - W/"7e-cANDOJXnR7JqsqX4FY/+yV6Z80g" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - ltslbWIDbeto-y9731xrW2NKtvntvbrk0DkxHcPJqcMdWG1iTY-v4w== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae575-057288fc692431ca313fa387;a86bc51b-b47e-426f-a35f-240a57a69ba0 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"files": [{"path": "app.py", "sample": 
"CmltcG9ydCBncmFkaW8gYXMgZ3IKCgpkZWYgZ3JlZXQobmFtZSk6CiAgICByZXR1cm4gIkhlbGxvICIgKyBuYW1lICsgIiEhIgoKaWZhY2UgPSBnci5JbnRlcmZhY2UoZm49Z3JlZXQsIGlucHV0cz0idGV4dCIsIG91dHB1dHM9InRleHQiKQppZmFjZS5sYXVuY2goKQo=", "size": 152}]}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '262' Content-Type: - application/json X-Amzn-Trace-Id: - 06cce17b-d0eb-4018-9104-e57f84a6b1c7 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/preupload/main response: body: string: '{"files":[{"path":"app.py","uploadMode":"regular","shouldIgnore":false}],"commitOid":"d052d85daa4e615c3554cecaa7dbdc2e4d179929"}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '128' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:22 GMT ETag: - W/"80-o37TMCJion9tXEpXCGSfmyWO+H4" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 6QbhLOXmcZfCV6aRrJ1wdoLR0nSreMTPdHza4E50w-VY0L08kovFKA== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae576-298749c90215b9e265f7a281;06cce17b-d0eb-4018-9104-e57f84a6b1c7 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive X-Amzn-Trace-Id: - 59b8bb46-367a-4e3d-b796-4511a701a469 method: GET uri: https://huggingface.co/api/spaces/user/tmp_test_space/revision/main?expand=xetEnabled response: body: string: '{"_id":"67cae57522fde9195aaf95c5","id":"user/tmp_test_space","xetEnabled":false}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '83' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:22 GMT ETag: - W/"53-FIrrE87MlsfOWdXe5c0aZBcYc5M" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - YZbiyM_Rm3gBuDSprem8uIdDugxCCJzil9NXuGXdSmjYcJzDjSbEpA== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae576-58adf5e47599e8a908829681;59b8bb46-367a-4e3d-b796-4511a701a469 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "header", "value": {"summary": "Upload app.py with huggingface_hub", "description": ""}} {"key": "file", "value": {"content": "CmltcG9ydCBncmFkaW8gYXMgZ3IKCgpkZWYgZ3JlZXQobmFtZSk6CiAgICByZXR1cm4gIkhlbGxvICIgKyBuYW1lICsgIiEhIgoKaWZhY2UgPSBnci5JbnRlcmZhY2UoZm49Z3JlZXQsIGlucHV0cz0idGV4dCIsIG91dHB1dHM9InRleHQiKQppZmFjZS5sYXVuY2goKQo=", "path": "app.py", "encoding": "base64"}} ' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '383' Content-Type: - application/x-ndjson X-Amzn-Trace-Id: - ad37bd5d-59ad-4aba-bb5b-ee2a5c004855 method: POST uri: 
https://huggingface.co/api/spaces/user/tmp_test_space/commit/main response: body: string: '{"success":true,"commitOid":"c52ce5e45984b9378ef6247fe7a45ed72b096b24","commitUrl":"https://huggingface.co/spaces/user/tmp_test_space/commit/c52ce5e45984b9378ef6247fe7a45ed72b096b24","hookOutput":""}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '202' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:23 GMT ETag: - W/"ca-wapu7c2+EVB7QHYtNqK8Q/f985Y" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - dFoaB0Y-sS8CnESrd49aWe7jj8pVe0NLvV4M07gn21isBrsUjwrwag== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae576-1430fb4e73a8e5c63ded648d;ad37bd5d-59ad-4aba-bb5b-ee2a5c004855 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"files": [{"path": "app.py", "sample": "", "size": 0}]}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '56' Content-Type: - application/json X-Amzn-Trace-Id: - 48734f70-485c-4e46-a6ba-7c5dddac5545 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/preupload/main response: body: string: '{"files":[{"path":"app.py","uploadMode":"regular","shouldIgnore":false,"oid":"1a42034850e2f5213e1a89a887747b667dc9d125"}],"commitOid":"c52ce5e45984b9378ef6247fe7a45ed72b096b24"}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '177' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:23 GMT ETag: - W/"b1-IARwrXroBnMWqawtpuKn+D8vBgU" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - Zgq9WfyFTJthPNuYLd2ok-oGQ1Laviez3l6g-hKX_D1MGMdoeZbr6A== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae577-599e5f4d1d0a6c486864d32e;48734f70-485c-4e46-a6ba-7c5dddac5545 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive X-Amzn-Trace-Id: - c53b6463-dce9-45db-a4f1-347c5f447ae1 method: GET uri: https://huggingface.co/api/spaces/user/tmp_test_space/revision/main?expand=xetEnabled response: body: string: '{"_id":"67cae57522fde9195aaf95c5","id":"user/tmp_test_space","xetEnabled":false}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '83' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:23 GMT ETag: - 
W/"53-FIrrE87MlsfOWdXe5c0aZBcYc5M" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - uILqgA1fn8pYLBT269WjhgA-frxwjbbJ6AE5Baad1jLu-uP7ye_PEg== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae577-350598925a39e6d302576c0c;c53b6463-dce9-45db-a4f1-347c5f447ae1 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "header", "value": {"summary": "Upload app.py with huggingface_hub", "description": ""}} {"key": "file", "value": {"content": "", "path": "app.py", "encoding": "base64"}} ' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '179' Content-Type: - application/x-ndjson X-Amzn-Trace-Id: - 39fc2f2d-ae6a-41c4-8fd3-1ad093f7fb6d method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/commit/main response: body: string: '{"success":true,"commitOid":"d37235f0eaf17cd7e86cc944e5487a3570056b04","commitUrl":"https://huggingface.co/spaces/user/tmp_test_space/commit/d37235f0eaf17cd7e86cc944e5487a3570056b04","hookOutput":""}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '202' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:24 GMT ETag: - W/"ca-KkEZAwv8/Q7fV9Kl05Z7vhmmTgU" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - qmQPq_yQ9Gzs-in0oFdpoByZVn5vMfmVzupHM_5weN2b2TBqi9pNjw== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae577-61fe6831284a00c21e5f093f;39fc2f2d-ae6a-41c4-8fd3-1ad093f7fb6d cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive X-Amzn-Trace-Id: - 1b845141-f4a8-40de-acbe-7c9c08993cc8 method: GET uri: https://huggingface.co/api/spaces/user/tmp_test_space/runtime response: body: string: '{"stage":"BUILDING","hardware":{"current":null,"requested":"cpu-basic"},"storage":null,"gcTimeout":172800,"replicas":{"requested":1},"devMode":false,"domains":[{"domain":"user-tmp-test-space.hf.space","stage":"READY"}]}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '222' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:25 GMT ETag: - W/"de-+RBXGi1JeTfRdHwxH3iLUmZUjr0" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - mxNUus7YVqDX2BgB59vLKLzFV-uZiaPCUEHB6OGVERj0JeczSbf9DQ== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae578-23edd72a19764b8a21eccc05;1b845141-f4a8-40de-acbe-7c9c08993cc8 cross-origin-opener-policy: - same-origin status: 
code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '0' X-Amzn-Trace-Id: - c89896dc-5bdb-4968-abcf-6b6dd93fd1c6 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/pause response: body: string: '{"stage":"PAUSED","hardware":{"current":null,"requested":"cpu-basic"},"storage":null,"gcTimeout":172800,"replicas":{"requested":1},"devMode":false,"domains":[{"domain":"user-tmp-test-space.hf.space","stage":"READY"}]}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '220' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:25 GMT ETag: - W/"dc-0cOOI26eshHyYgeIJy7pudLBiCs" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - Vaf1rVJi7OQ-3L2J6Ue_ufnOpqa73Qcif2uQzfjYFAdQJrj3OG2XJw== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae579-3cfda27777a71d5b604ff856;c89896dc-5bdb-4968-abcf-6b6dd93fd1c6 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '0' X-Amzn-Trace-Id: - 2bb06c3c-52f8-41ed-8e00-6c13f2d0c22e method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/restart response: body: string: '{"stage":"BUILDING","hardware":{"current":null,"requested":"cpu-basic"},"storage":null,"gcTimeout":172800,"replicas":{"requested":1},"devMode":false,"domains":[{"domain":"user-tmp-test-space.hf.space","stage":"READY"}]}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '222' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:25 GMT ETag: - W/"de-+RBXGi1JeTfRdHwxH3iLUmZUjr0" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 40udxw9EAJeiZntMDziLw_FMqrjTQXhtcz5ufqgj45BAOkub3-DTMA== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae579-40acb3ca5d7781964f1a2181;2bb06c3c-52f8-41ed-8e00-6c13f2d0c22e cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive X-Amzn-Trace-Id: - 5589740f-cde4-4bc2-a85c-ca33b9a5fd2d method: GET uri: https://huggingface.co/api/spaces/user/tmp_test_space/runtime response: body: string: '{"stage":"BUILDING","hardware":{"current":null,"requested":"cpu-basic"},"storage":null,"gcTimeout":172800,"replicas":{"requested":1},"devMode":false,"domains":[{"domain":"user-tmp-test-space.hf.space","stage":"READY"}]}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - 
X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '222' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:26 GMT ETag: - W/"de-+RBXGi1JeTfRdHwxH3iLUmZUjr0" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 8fxvdFL7hd6ZCU3oMeW-vxlAC9e3Z3g85NHrfcFlnhb3GLVTR7SMsA== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57a-254a6c2d154805431ec92977;5589740f-cde4-4bc2-a85c-ca33b9a5fd2d cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"name": "tmp_test_space", "organization": "user", "type": "space"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '70' Content-Type: - application/json X-Amzn-Trace-Id: - d783a055-0dc6-458a-a4d7-73a3d551c741 method: DELETE uri: https://huggingface.co/api/repos/delete response: body: string: OK headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '2' Content-Type: - text/plain; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:26 GMT ETag: - W/"2-nOO9QiTIwXgNtWtBJezz8kv3SLc" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 aad5d23429e63574c684a22d6a0313f0.cloudfront.net (CloudFront) X-Amz-Cf-Id: - mMN-FE-9HxpXFjuPoT6HMJxJPbpHPwYpCpGkpfy_8fs4oUpbTB3n4g== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57a-2ef2666e084d6e0a2a395796;d783a055-0dc6-458a-a4d7-73a3d551c741 cross-origin-opener-policy: - same-origin status: code: 200 message: OK version: 1 huggingface_hub-0.31.1/tests/cassettes/TestSpaceAPIProduction.test_space_runtime.yaml000066400000000000000000000252141500667546600311400ustar00rootroot00000000000000interactions: - request: body: '{"name": "tmp_test_space", "organization": "user", "private": true, "type": "space", "sdk": "gradio"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '104' Content-Type: - application/json X-Amzn-Trace-Id: - dce0b1eb-2033-4828-bde7-0c350982d090 method: POST uri: https://huggingface.co/api/repos/create response: body: string: '{"url":"https://huggingface.co/spaces/user/tmp_test_space","name":"user/tmp_test_space","id":"67cae57bc3f95553ede73be7"}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '126' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:27 GMT ETag: - W/"7e-gtR2sRmrCIGWw3Q/NRHwpzFc1y0" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 e47c282d2c53705a367f9e376a2eab28.cloudfront.net (CloudFront) X-Amz-Cf-Id: - Y90ILOyKIRNKrjO7hdzXuPn0hIe2EW5yuEeoG46wLCHRXC7TkKv-wQ== X-Amz-Cf-Pop: - 
CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57a-4b4e5c5b432e5b3c0e788d01;dce0b1eb-2033-4828-bde7-0c350982d090 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"files": [{"path": "app.py", "sample": "CmltcG9ydCBncmFkaW8gYXMgZ3IKCgpkZWYgZ3JlZXQobmFtZSk6CiAgICByZXR1cm4gIkhlbGxvICIgKyBuYW1lICsgIiEhIgoKaWZhY2UgPSBnci5JbnRlcmZhY2UoZm49Z3JlZXQsIGlucHV0cz0idGV4dCIsIG91dHB1dHM9InRleHQiKQppZmFjZS5sYXVuY2goKQo=", "size": 152}]}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '262' Content-Type: - application/json X-Amzn-Trace-Id: - 9885712b-69e9-48d8-8010-2eab63ee50b8 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/preupload/main response: body: string: '{"files":[{"path":"app.py","uploadMode":"regular","shouldIgnore":false}],"commitOid":"d02801a18c74e2875b33127a993615a36eb7448b"}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '128' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:28 GMT ETag: - W/"80-wh4U6aPRbJd/LbtLDKJsYOQ33R8" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 e47c282d2c53705a367f9e376a2eab28.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 58BAAB6MNc-tQjWz9QNWKUfQEaj8ENSBz7RJyVHldTFtLMsvpimi-g== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57c-013d82071f6e1128724ce17a;9885712b-69e9-48d8-8010-2eab63ee50b8 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive X-Amzn-Trace-Id: - 616ebcf7-e610-496c-b4f0-40a438b22494 method: GET uri: https://huggingface.co/api/spaces/user/tmp_test_space/revision/main?expand=xetEnabled response: body: string: '{"_id":"67cae57bc3f95553ede73be7","id":"user/tmp_test_space","xetEnabled":false}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '83' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:28 GMT ETag: - W/"53-r0lG4ObhJFKVnhyKtQ899fj4eDw" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 e47c282d2c53705a367f9e376a2eab28.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 2P2xhj59xnkLbs4kavlFidr_5x8EVhXMu_ieGtpyK2XZPSARzPHN4w== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57c-230f25335970111108e17684;616ebcf7-e610-496c-b4f0-40a438b22494 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "header", "value": {"summary": "Upload app.py with huggingface_hub", "description": ""}} {"key": "file", "value": {"content": 
"CmltcG9ydCBncmFkaW8gYXMgZ3IKCgpkZWYgZ3JlZXQobmFtZSk6CiAgICByZXR1cm4gIkhlbGxvICIgKyBuYW1lICsgIiEhIgoKaWZhY2UgPSBnci5JbnRlcmZhY2UoZm49Z3JlZXQsIGlucHV0cz0idGV4dCIsIG91dHB1dHM9InRleHQiKQppZmFjZS5sYXVuY2goKQo=", "path": "app.py", "encoding": "base64"}} ' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '383' Content-Type: - application/x-ndjson X-Amzn-Trace-Id: - e6008894-c8c9-46f8-b312-fe122fb2a865 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/commit/main response: body: string: '{"success":true,"commitOid":"403adf5c8b48aecc6b4e5d3bfc402587153a707a","commitUrl":"https://huggingface.co/spaces/user/tmp_test_space/commit/403adf5c8b48aecc6b4e5d3bfc402587153a707a","hookOutput":""}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '202' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:28 GMT ETag: - W/"ca-CjZ13AP//EebkDr6mTGQvoAfv9k" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 e47c282d2c53705a367f9e376a2eab28.cloudfront.net (CloudFront) X-Amz-Cf-Id: - qYf8mYLedX8_1Ad_I0WOrQKks4fy_QE06sR0XLUJyrqshZJS-Y25WQ== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57c-6da554ae1ba9088c56d60ab1;e6008894-c8c9-46f8-b312-fe122fb2a865 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive X-Amzn-Trace-Id: - 3fca2004-7797-4321-9981-8d2b491cf804 method: GET uri: https://huggingface.co/api/spaces/user/tmp_test_space/runtime response: body: string: '{"stage":"BUILDING","hardware":{"current":null,"requested":"cpu-basic"},"storage":null,"gcTimeout":172800,"replicas":{"requested":1},"devMode":false,"domains":[{"domain":"user-tmp-test-space.hf.space","stage":"READY"}]}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '222' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:28 GMT ETag: - W/"de-+RBXGi1JeTfRdHwxH3iLUmZUjr0" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 e47c282d2c53705a367f9e376a2eab28.cloudfront.net (CloudFront) X-Amz-Cf-Id: - Bv6K4IAykvk23sjvTLP_Hw7THKNjVElwAGrLR1rO6wd-PvBAFPHAQw== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57c-57dd5e08323750a80cc4d371;3fca2004-7797-4321-9981-8d2b491cf804 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"name": "tmp_test_space", "organization": "user", "type": "space"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '70' Content-Type: - application/json X-Amzn-Trace-Id: - 6c21467f-21f0-4540-901b-8ccffdc38aba method: DELETE uri: https://huggingface.co/api/repos/delete response: body: string: OK headers: Access-Control-Allow-Origin: - 
https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '2' Content-Type: - text/plain; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:29 GMT ETag: - W/"2-nOO9QiTIwXgNtWtBJezz8kv3SLc" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 e47c282d2c53705a367f9e376a2eab28.cloudfront.net (CloudFront) X-Amz-Cf-Id: - a7cMrRuhNNV4QjDn2npcfEr4tCp9tO698-5gAGH44hYpOnfo0pleNA== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57d-6203d6c110ca8209526862d6;6c21467f-21f0-4540-901b-8ccffdc38aba cross-origin-opener-policy: - same-origin status: code: 200 message: OK version: 1 huggingface_hub-0.31.1/tests/cassettes/TestSpaceAPIProduction.test_static_space_runtime.yaml000066400000000000000000000250471500667546600325130ustar00rootroot00000000000000interactions: - request: body: '{"name": "tmp_test_space", "organization": "user", "private": true, "type": "space", "sdk": "gradio"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '104' Content-Type: - application/json X-Amzn-Trace-Id: - 4549789f-7f26-4688-9868-9f1bc01c974d method: POST uri: https://huggingface.co/api/repos/create response: body: string: '{"url":"https://huggingface.co/spaces/user/tmp_test_space","name":"user/tmp_test_space","id":"67cae57dfb88e1c3d1fec3f5"}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '126' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:30 GMT ETag: - W/"7e-Za5bJCkGRDGiUaqxqd5bgQlWb0s" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 0a58752d78fb248f2488304f0f93599a.cloudfront.net (CloudFront) X-Amz-Cf-Id: - 4n7fqC7W63JdXvQadMSji8v52mrgQrUJsxb-xGjFDs1z8p32gkbADg== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57d-32f50faf78820a041d5de551;4549789f-7f26-4688-9868-9f1bc01c974d cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"files": [{"path": "app.py", "sample": "CmltcG9ydCBncmFkaW8gYXMgZ3IKCgpkZWYgZ3JlZXQobmFtZSk6CiAgICByZXR1cm4gIkhlbGxvICIgKyBuYW1lICsgIiEhIgoKaWZhY2UgPSBnci5JbnRlcmZhY2UoZm49Z3JlZXQsIGlucHV0cz0idGV4dCIsIG91dHB1dHM9InRleHQiKQppZmFjZS5sYXVuY2goKQo=", "size": 152}]}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '262' Content-Type: - application/json X-Amzn-Trace-Id: - 77fe438e-3344-4de8-9a38-40309aef3c92 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/preupload/main response: body: string: '{"files":[{"path":"app.py","uploadMode":"regular","shouldIgnore":false}],"commitOid":"47ad63ebd52a0886c638a6445aa05d6fe5cb0836"}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - 
X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '128' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:30 GMT ETag: - W/"80-s4MwkC8gi+pv72Az3k1g5l7bhpQ" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 0a58752d78fb248f2488304f0f93599a.cloudfront.net (CloudFront) X-Amz-Cf-Id: - imIuYRyXozoYv6n3ga8exJUEUUPskH4CXHda1C8pG_oa9Vgfi90pzQ== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57e-732a1ad93e26bfa548f5013e;77fe438e-3344-4de8-9a38-40309aef3c92 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive X-Amzn-Trace-Id: - 798c63f1-e894-4be3-844b-cdd4f782f19e method: GET uri: https://huggingface.co/api/spaces/user/tmp_test_space/revision/main?expand=xetEnabled response: body: string: '{"_id":"67cae57dfb88e1c3d1fec3f5","id":"user/tmp_test_space","xetEnabled":false}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '83' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:30 GMT ETag: - W/"53-mxGtYvhlZ5YLSyPSosOnT6ZX0iI" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 0a58752d78fb248f2488304f0f93599a.cloudfront.net (CloudFront) X-Amz-Cf-Id: - sGYTzm-Xp-C72Q4k6RuLfZGzOR0cyAWxAhi1P9DzSbgFSh4zWGZXYg== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57e-2772093b1fff549c1db58234;798c63f1-e894-4be3-844b-cdd4f782f19e cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"key": "header", "value": {"summary": "Upload app.py with huggingface_hub", "description": ""}} {"key": "file", "value": {"content": "CmltcG9ydCBncmFkaW8gYXMgZ3IKCgpkZWYgZ3JlZXQobmFtZSk6CiAgICByZXR1cm4gIkhlbGxvICIgKyBuYW1lICsgIiEhIgoKaWZhY2UgPSBnci5JbnRlcmZhY2UoZm49Z3JlZXQsIGlucHV0cz0idGV4dCIsIG91dHB1dHM9InRleHQiKQppZmFjZS5sYXVuY2goKQo=", "path": "app.py", "encoding": "base64"}} ' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '383' Content-Type: - application/x-ndjson X-Amzn-Trace-Id: - ff93d09b-63d6-4527-90a8-fb753acf0862 method: POST uri: https://huggingface.co/api/spaces/user/tmp_test_space/commit/main response: body: string: '{"success":true,"commitOid":"18404cdff7a1dbb7890b52aa94d6f6db6e7b580a","commitUrl":"https://huggingface.co/spaces/user/tmp_test_space/commit/18404cdff7a1dbb7890b52aa94d6f6db6e7b580a","hookOutput":""}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '202' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:31 GMT ETag: - W/"ca-oCTIzMc2Cy1da0xNjSczfJ7lxFo" Referrer-Policy: - 
strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 0a58752d78fb248f2488304f0f93599a.cloudfront.net (CloudFront) X-Amz-Cf-Id: - HCFbwV029QrXZm6V5Nyfhd-CsxzWGal5HMlsHG0sEQEeeGZRY2cufQ== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57f-3d688bc439de5a007a6b0774;ff93d09b-63d6-4527-90a8-fb753acf0862 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: null headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive X-Amzn-Trace-Id: - 667eaa7d-1339-46c0-82ce-c0bd167b6fe1 method: GET uri: https://huggingface.co/api/spaces/victor/static-space/runtime response: body: string: '{"stage":"RUNNING","hardware":{"current":null,"requested":null},"storage":null,"replicas":{"requested":1,"current":1}}' headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '118' Content-Type: - application/json; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:31 GMT ETag: - W/"76-P0FPsIJ0y4/N4wjDHf89CLxBVFg" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 0a58752d78fb248f2488304f0f93599a.cloudfront.net (CloudFront) X-Amz-Cf-Id: - FX2xaJMG8GeIQbEev3sNjve3z__q3GoBVSs_-WM2rxqQAoda0aSDww== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57f-381dd26e43c8857376b052e8;667eaa7d-1339-46c0-82ce-c0bd167b6fe1 cross-origin-opener-policy: - same-origin status: code: 200 message: OK - request: body: '{"name": "tmp_test_space", "organization": "user", "type": "space"}' headers: Accept: - '*/*' Accept-Encoding: - gzip, deflate, br Connection: - keep-alive Content-Length: - '70' Content-Type: - application/json X-Amzn-Trace-Id: - 19e2c3c5-6fad-4f35-8791-f2d124301c44 method: DELETE uri: https://huggingface.co/api/repos/delete response: body: string: OK headers: Access-Control-Allow-Origin: - https://huggingface.co Access-Control-Expose-Headers: - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash Connection: - keep-alive Content-Length: - '2' Content-Type: - text/plain; charset=utf-8 Date: - Fri, 07 Mar 2025 12:24:32 GMT ETag: - W/"2-nOO9QiTIwXgNtWtBJezz8kv3SLc" Referrer-Policy: - strict-origin-when-cross-origin Vary: - Origin Via: - 1.1 0a58752d78fb248f2488304f0f93599a.cloudfront.net (CloudFront) X-Amz-Cf-Id: - KAeQud5wozopUAy_MGd2FCVGhJvdnBLTNTSKri5T-uOFkKzLIVHc6g== X-Amz-Cf-Pop: - CDG52-P4 X-Cache: - Miss from cloudfront X-Powered-By: - huggingface-moon X-Request-Id: - Root=1-67cae57f-305cba7f65a0591268a9898e;19e2c3c5-6fad-4f35-8791-f2d124301c44 cross-origin-opener-policy: - same-origin status: code: 200 message: OK version: 1 huggingface_hub-0.31.1/tests/conftest.py000066400000000000000000000061261500667546600202660ustar00rootroot00000000000000import os import shutil from typing import Generator import pytest from _pytest.fixtures import SubRequest import huggingface_hub from huggingface_hub import constants from huggingface_hub.utils import SoftTemporaryDirectory, logging from .testing_utils import set_write_permission_and_retry @pytest.fixture(autouse=True, scope="function") def 
patch_constants(mocker): with SoftTemporaryDirectory() as cache_dir: mocker.patch.object(constants, "HF_HOME", cache_dir) mocker.patch.object(constants, "HF_HUB_CACHE", os.path.join(cache_dir, "hub")) mocker.patch.object(constants, "HF_XET_CACHE", os.path.join(cache_dir, "xet")) mocker.patch.object(constants, "HUGGINGFACE_HUB_CACHE", os.path.join(cache_dir, "hub")) mocker.patch.object(constants, "HF_ASSETS_CACHE", os.path.join(cache_dir, "assets")) mocker.patch.object(constants, "HF_TOKEN_PATH", os.path.join(cache_dir, "token")) mocker.patch.object(constants, "HF_STORED_TOKENS_PATH", os.path.join(cache_dir, "stored_tokens")) yield logger = logging.get_logger(__name__) @pytest.fixture def fx_cache_dir(request: SubRequest) -> Generator[None, None, None]: """Add a `cache_dir` attribute pointing to a temporary directory in tests. Example: ```py @pytest.mark.usefixtures("fx_cache_dir") class TestWithCache(unittest.TestCase): cache_dir: Path def test_cache_dir(self) -> None: self.assertTrue(self.cache_dir.is_dir()) ``` """ with SoftTemporaryDirectory() as cache_dir: request.cls.cache_dir = cache_dir yield # TemporaryDirectory is not super robust on Windows when a git repository is # cloned in it. See https://www.scivision.dev/python-tempfile-permission-error-windows/. shutil.rmtree(cache_dir, onerror=set_write_permission_and_retry) @pytest.fixture(autouse=True) def disable_symlinks_on_windows_ci(monkeypatch: pytest.MonkeyPatch) -> None: class FakeSymlinkDict(dict): def __contains__(self, __o: object) -> bool: return True # consider any `cache_dir` to be already checked def __getitem__(self, __key: str) -> bool: return False # symlinks are never supported if os.name == "nt" and os.environ.get("DISABLE_SYMLINKS_IN_WINDOWS_TESTS"): monkeypatch.setattr( huggingface_hub.file_download, "_are_symlinks_supported_in_dir", FakeSymlinkDict(), ) @pytest.fixture(autouse=True) def disable_experimental_warnings(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(huggingface_hub.constants, "HF_HUB_DISABLE_EXPERIMENTAL_WARNING", True) @pytest.fixture(scope="module") def vcr_config(): return { "filter_headers": ["authorization", "user-agent", "cookie"], "ignore_localhost": True, "path_transformer": lambda path: path + ".yaml", } @pytest.fixture(autouse=True) def clear_lru_cache(): from huggingface_hub.inference._providers.hf_inference import _check_supported_task _check_supported_task.cache_clear() yield _check_supported_task.cache_clear() huggingface_hub-0.31.1/tests/fixtures/000077500000000000000000000000001500667546600177335ustar00rootroot00000000000000huggingface_hub-0.31.1/tests/fixtures/cards/000077500000000000000000000000001500667546600210275ustar00rootroot00000000000000huggingface_hub-0.31.1/tests/fixtures/cards/sample_datasetcard_simple.md000066400000000000000000000006161500667546600265450ustar00rootroot00000000000000--- language: - en license: - bsd-3-clause annotations_creators: - crowdsourced - expert-generated language_creators: - found multilinguality: - monolingual size_categories: - n<1K task_categories: - image-segmentation task_ids: - semantic-segmentation pretty_name: Sample Segmentation --- # Dataset Card for Sample Segmentation This is a sample dataset card for a semantic segmentation dataset. 
huggingface_hub-0.31.1/tests/fixtures/cards/sample_datasetcard_template.md000066400000000000000000000001311500667546600270570ustar00rootroot00000000000000--- {card_data} --- # {{ pretty_name | default("Dataset Name", true)}} {{ some_data }} huggingface_hub-0.31.1/tests/fixtures/cards/sample_invalid_card_data.md000066400000000000000000000002261500667546600263220ustar00rootroot00000000000000--- [] --- # invalid-card-data This card should fail when trying to load it in because the card data between the `---` is a list instead of a dict. huggingface_hub-0.31.1/tests/fixtures/cards/sample_invalid_model_index.md000066400000000000000000000007001500667546600267040ustar00rootroot00000000000000--- language: en license: mit library_name: timm tags: - pytorch - image-classification datasets: - beans metrics: - acc model-index: - name: my-cool-model results: - task: type: image-classification metrics: - type: acc value: 0.9 --- # Invalid Model Index In this example, the model index does not define a dataset field. In this case, we'll still initialize CardData, but will leave model-index/eval_results out of it. huggingface_hub-0.31.1/tests/fixtures/cards/sample_no_metadata.md000066400000000000000000000002241500667546600251640ustar00rootroot00000000000000# MyCoolModel In this example, we don't have any metadata at the top of the file. In cases like these, `CardData` should be instantiated as empty. huggingface_hub-0.31.1/tests/fixtures/cards/sample_simple.md000066400000000000000000000003521500667546600242030ustar00rootroot00000000000000--- language: - en license: mit library_name: pytorch-lightning tags: - pytorch - image-classification datasets: - beans metrics: - acc --- # my-cool-model ## Model description You can embed local or remote images using `![](...)` huggingface_hub-0.31.1/tests/fixtures/cards/sample_simple_model_index.md000066400000000000000000000013341500667546600265530ustar00rootroot00000000000000--- language: en license: mit library_name: timm tags: - pytorch - image-classification datasets: - beans metrics: - accuracy model-index: - name: my-cool-model results: - task: type: image-classification dataset: type: beans name: Beans metrics: - type: accuracy value: 0.9 - task: type: image-classification dataset: type: beans name: Beans config: default split: test revision: 5503434ddd753f426f4b38109466949a1217c2bb args: date: 20220120 metrics: - type: f1 value: 0.66 --- # my-cool-model ## Model description This is a test model card with multiple evaluations across different (dataset, metric) configurations. huggingface_hub-0.31.1/tests/fixtures/cards/sample_template.md000066400000000000000000000001311500667546600245200ustar00rootroot00000000000000--- {{card_data}} --- # {{ model_name | default("MyModelName", true)}} {{ some_data }} huggingface_hub-0.31.1/tests/fixtures/cards/sample_windows_line_breaks.md000066400000000000000000000004251500667546600267430ustar00rootroot00000000000000--- license: mit language: eo thumbnail: https://huggingface.co/blog/assets/01_how-to-train/EsperBERTo-thumbnail-v2.png widget: - text: "Jen la komenco de bela ." - text: "Uno du " - text: "Jen finiĝas bela ." 
--- # Hello old Windows line breaks huggingface_hub-0.31.1/tests/fixtures/empty.txt000066400000000000000000000000001500667546600216200ustar00rootroot00000000000000huggingface_hub-0.31.1/tests/test_auth.py000066400000000000000000000115441500667546600204410ustar00rootroot00000000000000import os import tempfile from unittest.mock import patch import pytest from huggingface_hub import constants from huggingface_hub._login import _login, _set_active_token, auth_switch, logout from huggingface_hub.utils._auth import _get_token_by_name, _get_token_from_file, _save_token, get_stored_tokens from .testing_constants import ENDPOINT_STAGING, OTHER_TOKEN, TOKEN @pytest.fixture(autouse=True) def use_tmp_file_paths(): """ Fixture to temporarily override HF_TOKEN_PATH, HF_STORED_TOKENS_PATH, and ENDPOINT. This fixture patches the constants in the huggingface_hub module to use the specified paths and the staging endpoint. It also ensures that the files are deleted after all tests in the module are completed. """ with tempfile.TemporaryDirectory() as tmp_hf_home: hf_token_path = os.path.join(tmp_hf_home, "token") hf_stored_tokens_path = os.path.join(tmp_hf_home, "stored_tokens") with patch.multiple( constants, HF_TOKEN_PATH=hf_token_path, HF_STORED_TOKENS_PATH=hf_stored_tokens_path, ENDPOINT=ENDPOINT_STAGING, ): yield class TestGetTokenByName: def test_get_existing_token(self): _save_token(TOKEN, "test_token") token = _get_token_by_name("test_token") assert token == TOKEN def test_get_non_existent_token(self): assert _get_token_by_name("non_existent") is None class TestSaveToken: def test_save_new_token(self): _save_token(TOKEN, "new_token") stored_tokens = get_stored_tokens() assert "new_token" in stored_tokens assert stored_tokens["new_token"] == TOKEN def test_overwrite_existing_token(self): _save_token(TOKEN, "test_token") _save_token("new_token", "test_token") assert _get_token_by_name("test_token") == "new_token" class TestSetActiveToken: def test_set_active_token_success(self): _save_token(TOKEN, "test_token") _set_active_token("test_token", add_to_git_credential=False) assert _get_token_from_file() == TOKEN def test_set_active_token_non_existent(self): non_existent_token = "non_existent" with pytest.raises(ValueError, match="Token non_existent not found in .*"): _set_active_token(non_existent_token, add_to_git_credential=False) class TestLogin: @patch( "huggingface_hub.hf_api.whoami", return_value={ "auth": { "accessToken": { "displayName": "test_token", "role": "write", "createdAt": "2024-01-01T00:00:00.000Z", } } }, ) def test_login_success(self, mock_whoami): _login(TOKEN, add_to_git_credential=False) assert _get_token_by_name("test_token") == TOKEN assert _get_token_from_file() == TOKEN class TestLogout: def test_logout_deletes_files(self): _save_token(TOKEN, "test_token") _set_active_token("test_token", add_to_git_credential=False) assert os.path.exists(constants.HF_TOKEN_PATH) assert os.path.exists(constants.HF_STORED_TOKENS_PATH) logout() # Check that both files are deleted assert not os.path.exists(constants.HF_TOKEN_PATH) assert not os.path.exists(constants.HF_STORED_TOKENS_PATH) def test_logout_specific_token(self): # Create two tokens _save_token(TOKEN, "token_1") _save_token(OTHER_TOKEN, "token_2") logout("token_1") # Check that HF_STORED_TOKENS_PATH still exists assert os.path.exists(constants.HF_STORED_TOKENS_PATH) # Check that token_1 is removed stored_tokens = get_stored_tokens() assert "token_1" not in stored_tokens assert "token_2" in stored_tokens def 
test_logout_active_token(self): _save_token(TOKEN, "active_token") _set_active_token("active_token", add_to_git_credential=False) logout("active_token") # Check that both files are deleted assert not os.path.exists(constants.HF_TOKEN_PATH) stored_tokens = get_stored_tokens() assert "active_token" not in stored_tokens class TestAuthSwitch: def test_auth_switch_existing_token(self): # Add two access tokens _save_token(TOKEN, "test_token_1") _save_token(OTHER_TOKEN, "test_token_2") # Set `test_token_1` as the active token _set_active_token("test_token_1", add_to_git_credential=False) # Switch to `test_token_2` auth_switch("test_token_2", add_to_git_credential=False) assert _get_token_from_file() == OTHER_TOKEN def test_auth_switch_nonexisting_token(self): with patch("huggingface_hub.utils._auth._get_token_by_name", return_value=None): with pytest.raises(ValueError): auth_switch("nonexistent_token") huggingface_hub-0.31.1/tests/test_auth_cli.py000066400000000000000000000117671500667546600212770ustar00rootroot00000000000000import logging import os import tempfile from unittest.mock import patch import pytest from pytest import CaptureFixture, LogCaptureFixture from huggingface_hub import constants from huggingface_hub.commands.user import AuthListCommand, AuthSwitchCommand, LoginCommand, LogoutCommand from .testing_constants import ENDPOINT_STAGING from .testing_utils import assert_in_logs # fixtures & constants MOCK_TOKEN = "hf_1234" @pytest.fixture(autouse=True) def use_tmp_file_paths(): """ Fixture to temporarily override HF_TOKEN_PATH, HF_STORED_TOKENS_PATH, and ENDPOINT. """ with tempfile.TemporaryDirectory() as tmp_hf_home: hf_token_path = os.path.join(tmp_hf_home, "token") hf_stored_tokens_path = os.path.join(tmp_hf_home, "stored_tokens") with patch.multiple( constants, HF_TOKEN_PATH=hf_token_path, HF_STORED_TOKENS_PATH=hf_stored_tokens_path, ENDPOINT=ENDPOINT_STAGING, ): yield @pytest.fixture def mock_whoami_api_call(): MOCK_WHOAMI_RESPONSE = { "auth": { "accessToken": { "displayName": "test_token", "role": "write", "createdAt": "2024-01-01T00:00:00.000Z", } } } with patch("huggingface_hub.hf_api.whoami", return_value=MOCK_WHOAMI_RESPONSE): yield @pytest.fixture def mock_stored_tokens(): """Mock stored tokens.""" stored_tokens = { "token1": "hf_1234", "token2": "hf_5678", "active_token": "hf_9012", } with patch("huggingface_hub._login.get_stored_tokens", return_value=stored_tokens): with patch("huggingface_hub.utils._auth.get_stored_tokens", return_value=stored_tokens): yield stored_tokens def test_login_command_basic(mock_whoami_api_call, caplog: LogCaptureFixture): """Test basic login command execution.""" caplog.set_level(logging.INFO) args = type("Args", (), {"token": MOCK_TOKEN, "add_to_git_credential": False})() cmd = LoginCommand(args) cmd.run() assert_in_logs(caplog, "Login successful") assert_in_logs(caplog, "Token is valid") assert_in_logs(caplog, "The current active token is: `test_token`") def test_login_command_with_git(mock_whoami_api_call, caplog: LogCaptureFixture): """Test login command with git credential option.""" caplog.set_level(logging.INFO) args = type("Args", (), {"token": MOCK_TOKEN, "add_to_git_credential": True})() cmd = LoginCommand(args) with patch("huggingface_hub._login._is_git_credential_helper_configured", return_value=True): with patch("huggingface_hub.utils.set_git_credential"): cmd.run() assert_in_logs(caplog, "Login successful") assert_in_logs(caplog, "Your token has been saved in your configured git credential helpers") def 
test_logout_specific_token(mock_stored_tokens, caplog: LogCaptureFixture): """Test logout command for a specific token.""" caplog.set_level(logging.INFO) args = type("Args", (), {"token_name": "token1"})() cmd = LogoutCommand(args) cmd.run() assert_in_logs(caplog, "Successfully logged out from access token: token1") def test_logout_active_token(mock_stored_tokens, caplog: LogCaptureFixture): """Test logout command for active token.""" caplog.set_level(logging.INFO) with patch("huggingface_hub._login._get_token_from_file", return_value="hf_9012"): args = type("Args", (), {"token_name": "active_token"})() cmd = LogoutCommand(args) cmd.run() assert_in_logs(caplog, "Successfully logged out from access token: active_token") assert_in_logs(caplog, "Active token 'active_token' has been deleted") def test_logout_all_tokens(mock_stored_tokens, caplog: LogCaptureFixture): """Test logout command for all tokens.""" caplog.set_level(logging.INFO) args = type("Args", (), {"token_name": None})() cmd = LogoutCommand(args) cmd.run() assert_in_logs(caplog, "Successfully logged out from all access tokens") def test_switch_token(mock_stored_tokens, caplog: LogCaptureFixture): """Test switching between tokens.""" caplog.set_level(logging.INFO) args = type("Args", (), {"token_name": "token1", "add_to_git_credential": False})() cmd = AuthSwitchCommand(args) cmd.run() assert_in_logs(caplog, "The current active token is: token1") def test_switch_nonexistent_token(mock_stored_tokens): """Test switching to a non-existent token.""" args = type("Args", (), {"token_name": "nonexistent", "add_to_git_credential": False})() cmd = AuthSwitchCommand(args) with pytest.raises(ValueError, match="Access token nonexistent not found"): cmd.run() def test_list_tokens(mock_stored_tokens, capsys: CaptureFixture): """Test listing tokens command.""" args = type("Args", (), {})() cmd = AuthListCommand(args) cmd.run() captured = capsys.readouterr() assert "token1" in captured.out assert "hf_****1234" in captured.out assert "token2" in captured.out huggingface_hub-0.31.1/tests/test_cache_layout.py000066400000000000000000000345051500667546600221420ustar00rootroot00000000000000import os import time import unittest from io import BytesIO from huggingface_hub import HfApi, hf_hub_download, snapshot_download from huggingface_hub.errors import EntryNotFoundError from huggingface_hub.utils import SoftTemporaryDirectory, logging from .testing_constants import ENDPOINT_STAGING, TOKEN from .testing_utils import ( repo_name, with_production_testing, xfail_on_windows, ) logger = logging.get_logger(__name__) MODEL_IDENTIFIER = "hf-internal-testing/hfh-cache-layout" def get_file_contents(path): with open(path) as f: content = f.read() return content @with_production_testing class CacheFileLayoutHfHubDownload(unittest.TestCase): @xfail_on_windows(reason="Symlinks are deactivated in Windows tests.") def test_file_downloaded_in_cache(self): for revision, expected_reference in ( (None, "main"), ("file-2", "file-2"), ): with self.subTest(revision), SoftTemporaryDirectory() as cache: hf_hub_download( MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision=revision, ) expected_directory_name = f"models--{MODEL_IDENTIFIER.replace('/', '--')}" expected_path = os.path.join(cache, expected_directory_name) refs = os.listdir(os.path.join(expected_path, "refs")) snapshots = os.listdir(os.path.join(expected_path, "snapshots")) # Only reference should be the expected one. 
self.assertListEqual(refs, [expected_reference]) with open(os.path.join(expected_path, "refs", expected_reference)) as f: snapshot_name = f.readline().strip() # The `main` reference should point to the only snapshot we have downloaded self.assertListEqual(snapshots, [snapshot_name]) snapshot_path = os.path.join(expected_path, "snapshots", snapshot_name) snapshot_content = os.listdir(snapshot_path) # Only a single file in the snapshot self.assertEqual(len(snapshot_content), 1) snapshot_content_path = os.path.join(snapshot_path, snapshot_content[0]) # The snapshot content should link to a blob self.assertTrue(os.path.islink(snapshot_content_path)) resolved_blob_relative = os.readlink(snapshot_content_path) resolved_blob_absolute = os.path.normpath(os.path.join(snapshot_path, resolved_blob_relative)) with open(resolved_blob_absolute) as f: blob_contents = f.readline().strip() # The contents of the file should be 'File 0'. self.assertEqual(blob_contents, "File 0") def test_no_exist_file_is_cached(self): revisions = [None, "file-2"] expected_references = ["main", "file-2"] for revision, expected_reference in zip(revisions, expected_references): with self.subTest(revision), SoftTemporaryDirectory() as cache: filename = "this_does_not_exist.txt" with self.assertRaises(EntryNotFoundError): # The file does not exist, so we get an exception. hf_hub_download(MODEL_IDENTIFIER, filename, cache_dir=cache, revision=revision) expected_directory_name = f"models--{MODEL_IDENTIFIER.replace('/', '--')}" expected_path = os.path.join(cache, expected_directory_name) refs = os.listdir(os.path.join(expected_path, "refs")) no_exist_snapshots = os.listdir(os.path.join(expected_path, ".no_exist")) # Only reference should be `main`. self.assertListEqual(refs, [expected_reference]) with open(os.path.join(expected_path, "refs", expected_reference)) as f: snapshot_name = f.readline().strip() # The `main` reference should point to the only snapshot we have downloaded self.assertListEqual(no_exist_snapshots, [snapshot_name]) no_exist_path = os.path.join(expected_path, ".no_exist", snapshot_name) no_exist_content = os.listdir(no_exist_path) # Only a single file in the no_exist snapshot self.assertEqual(len(no_exist_content), 1) # The no_exist content should be our file self.assertEqual(no_exist_content[0], filename) with open(os.path.join(no_exist_path, filename)) as f: content = f.read().strip() # The contents of the file should be empty. self.assertEqual(content, "") def test_file_download_happens_once(self): # Tests that a file is only downloaded once if it's not updated. with SoftTemporaryDirectory() as cache: path = hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) creation_time_0 = os.path.getmtime(path) time.sleep(2) path = hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) creation_time_1 = os.path.getmtime(path) self.assertEqual(creation_time_0, creation_time_1) @xfail_on_windows(reason="Symlinks are deactivated in Windows tests.") def test_file_download_happens_once_intra_revision(self): # Tests that a file is only downloaded once if it's not updated, even across different revisions. 
with SoftTemporaryDirectory() as cache: path = hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) creation_time_0 = os.path.getmtime(path) time.sleep(2) path = hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2") creation_time_1 = os.path.getmtime(path) self.assertEqual(creation_time_0, creation_time_1) @xfail_on_windows(reason="Symlinks are deactivated in Windows tests.") def test_multiple_refs_for_same_file(self): with SoftTemporaryDirectory() as cache: hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2") expected_directory_name = f"models--{MODEL_IDENTIFIER.replace('/', '--')}" expected_path = os.path.join(cache, expected_directory_name) refs = os.listdir(os.path.join(expected_path, "refs")) refs.sort() snapshots = os.listdir(os.path.join(expected_path, "snapshots")) snapshots.sort() # Directory should contain two revisions self.assertListEqual(refs, ["file-2", "main"]) refs_contents = [get_file_contents(os.path.join(expected_path, "refs", f)) for f in refs] refs_contents.sort() # snapshots directory should contain two snapshots self.assertListEqual(refs_contents, snapshots) snapshot_links = [ os.readlink(os.path.join(expected_path, "snapshots", filename, "file_0.txt")) for filename in snapshots ] # All snapshot links should point to the same file. self.assertEqual(*snapshot_links) @with_production_testing class CacheFileLayoutSnapshotDownload(unittest.TestCase): @xfail_on_windows(reason="Symlinks are deactivated in Windows tests.") def test_file_downloaded_in_cache(self): with SoftTemporaryDirectory() as cache: snapshot_download(MODEL_IDENTIFIER, cache_dir=cache) expected_directory_name = f"models--{MODEL_IDENTIFIER.replace('/', '--')}" expected_path = os.path.join(cache, expected_directory_name) refs = os.listdir(os.path.join(expected_path, "refs")) snapshots = os.listdir(os.path.join(expected_path, "snapshots")) snapshots.sort() # Directory should contain two revisions self.assertListEqual(refs, ["main"]) ref_content = get_file_contents(os.path.join(expected_path, "refs", refs[0])) # snapshots directory should contain two snapshots self.assertListEqual([ref_content], snapshots) snapshot_path = os.path.join(expected_path, "snapshots", snapshots[0]) files_in_snapshot = os.listdir(snapshot_path) snapshot_links = [os.readlink(os.path.join(snapshot_path, filename)) for filename in files_in_snapshot] resolved_snapshot_links = [os.path.normpath(os.path.join(snapshot_path, link)) for link in snapshot_links] self.assertTrue(all([os.path.isfile(link) for link in resolved_snapshot_links])) @xfail_on_windows(reason="Symlinks are deactivated in Windows tests.") def test_file_downloaded_in_cache_several_revisions(self): with SoftTemporaryDirectory() as cache: snapshot_download(MODEL_IDENTIFIER, cache_dir=cache, revision="file-3") snapshot_download(MODEL_IDENTIFIER, cache_dir=cache, revision="file-2") expected_directory_name = f"models--{MODEL_IDENTIFIER.replace('/', '--')}" expected_path = os.path.join(cache, expected_directory_name) refs = os.listdir(os.path.join(expected_path, "refs")) refs.sort() snapshots = os.listdir(os.path.join(expected_path, "snapshots")) snapshots.sort() # Directory should contain two revisions self.assertListEqual(refs, ["file-2", "file-3"]) refs_content = [get_file_contents(os.path.join(expected_path, "refs", ref)) for ref in refs] refs_content.sort() # snapshots directory should contain two snapshots 
self.assertListEqual(refs_content, snapshots) snapshots_paths = [os.path.join(expected_path, "snapshots", s) for s in snapshots] files_in_snapshots = {s: os.listdir(s) for s in snapshots_paths} links_in_snapshots = { k: [os.readlink(os.path.join(k, _v)) for _v in v] for k, v in files_in_snapshots.items() } resolved_snapshots_links = { k: [os.path.normpath(os.path.join(k, link)) for link in v] for k, v in links_in_snapshots.items() } all_links = [b for a in resolved_snapshots_links.values() for b in a] all_unique_links = set(all_links) # [ 100] . # ├── [ 140] blobs # │ ├── [ 7] 4475433e279a71203927cbe80125208a3b5db560 # │ ├── [ 7] 50fcd26d6ce3000f9d5f12904e80eccdc5685dd1 # │ ├── [ 7] 80146afc836c60e70ba67933fec439ab05b478f6 # │ ├── [ 7] 8cf9e18f080becb674b31c21642538269fe886a4 # │ └── [1.1K] ac481c8eb05e4d2496fbe076a38a7b4835dd733d # ├── [ 80] refs # │ ├── [ 40] file-2 # │ └── [ 40] file-3 # └── [ 80] snapshots # ├── [ 120] 5e23cb3ae7f904919a442e1b27dcddae6c6bc292 # │ ├── [ 52] file_0.txt -> ../../blobs/80146afc836c60e70ba67933fec439ab05b478f6 # │ ├── [ 52] file_1.txt -> ../../blobs/50fcd26d6ce3000f9d5f12904e80eccdc5685dd1 # │ ├── [ 52] file_2.txt -> ../../blobs/4475433e279a71203927cbe80125208a3b5db560 # │ └── [ 52] .gitattributes -> ../../blobs/ac481c8eb05e4d2496fbe076a38a7b4835dd733d # └── [ 120] 78aa2ebdb60bba086496a8792ba506e58e587b4c # ├── [ 52] file_0.txt -> ../../blobs/80146afc836c60e70ba67933fec439ab05b478f6 # ├── [ 52] file_1.txt -> ../../blobs/50fcd26d6ce3000f9d5f12904e80eccdc5685dd1 # ├── [ 52] file_3.txt -> ../../blobs/8cf9e18f080becb674b31c21642538269fe886a4 # └── [ 52] .gitattributes -> ../../blobs/ac481c8eb05e4d2496fbe076a38a7b4835dd733d # Across the two revisions, there should be 8 total links self.assertEqual(len(all_links), 8) # Across the two revisions, there should only be 5 unique files. 
self.assertEqual(len(all_unique_links), 5) class ReferenceUpdates(unittest.TestCase): _api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) def test_update_reference(self): repo_id = self._api.create_repo(repo_name(), exist_ok=True).repo_id try: self._api.upload_file(path_or_fileobj=BytesIO(b"Some string"), path_in_repo="file.txt", repo_id=repo_id) with SoftTemporaryDirectory() as cache: hf_hub_download(repo_id, "file.txt", cache_dir=cache) expected_directory_name = f"models--{repo_id.replace('/', '--')}" expected_path = os.path.join(cache, expected_directory_name) refs = os.listdir(os.path.join(expected_path, "refs")) # Directory should contain two revisions self.assertListEqual(refs, ["main"]) initial_ref_content = get_file_contents(os.path.join(expected_path, "refs", refs[0])) # Upload a new file on the same branch self._api.upload_file( path_or_fileobj=BytesIO(b"Some new string"), path_in_repo="file.txt", repo_id=repo_id, ) hf_hub_download(repo_id, "file.txt", cache_dir=cache) final_ref_content = get_file_contents(os.path.join(expected_path, "refs", refs[0])) # The `main` reference should point to two different, but existing snapshots which contain # a 'file.txt' self.assertNotEqual(initial_ref_content, final_ref_content) self.assertTrue(os.path.isdir(os.path.join(expected_path, "snapshots", initial_ref_content))) self.assertTrue( os.path.isfile(os.path.join(expected_path, "snapshots", initial_ref_content, "file.txt")) ) self.assertTrue(os.path.isdir(os.path.join(expected_path, "snapshots", final_ref_content))) self.assertTrue( os.path.isfile(os.path.join(expected_path, "snapshots", final_ref_content, "file.txt")) ) except Exception: raise finally: self._api.delete_repo(repo_id) huggingface_hub-0.31.1/tests/test_cache_no_symlinks.py000066400000000000000000000172061500667546600231710ustar00rootroot00000000000000import unittest import warnings from pathlib import Path from unittest.mock import Mock, patch import pytest from huggingface_hub import hf_hub_download, scan_cache_dir from huggingface_hub.constants import CONFIG_NAME, HF_HUB_CACHE from huggingface_hub.file_download import are_symlinks_supported from .testing_utils import DUMMY_MODEL_ID, with_production_testing @with_production_testing @pytest.mark.usefixtures("fx_cache_dir") class TestCacheLayoutIfSymlinksNotSupported(unittest.TestCase): cache_dir: Path @patch( "huggingface_hub.file_download._are_symlinks_supported_in_dir", {HF_HUB_CACHE: True}, ) def test_are_symlinks_supported_default(self) -> None: self.assertTrue(are_symlinks_supported()) @patch("huggingface_hub.file_download.os.symlink") @patch("huggingface_hub.file_download._are_symlinks_supported_in_dir", {}) def test_are_symlinks_supported_windows_specific_dir(self, mock_symlink: Mock) -> None: mock_symlink.side_effect = [OSError(), None] # First dir not supported then yes this_dir = Path(__file__).parent # First time in `this_dir`: warning is raised with self.assertWarns(UserWarning): self.assertFalse(are_symlinks_supported(this_dir)) with warnings.catch_warnings(): # Assert no warnings raised # Taken from https://stackoverflow.com/a/45671804 warnings.simplefilter("error") # Second time in `this_dir` but with absolute path: value is still cached self.assertFalse(are_symlinks_supported(this_dir.absolute())) # Try with another directory: symlinks are supported, no warnings self.assertTrue(are_symlinks_supported()) # True @patch("huggingface_hub.file_download.are_symlinks_supported") def test_download_no_symlink_new_file(self, mock_are_symlinks_supported: Mock) -> None: 
mock_are_symlinks_supported.return_value = False filepath = Path( hf_hub_download( DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=self.cache_dir, local_files_only=False, ) ) # Not a symlink ! self.assertFalse(filepath.is_symlink()) self.assertTrue(filepath.is_file()) # Blobs directory is empty self.assertEqual(len(list((Path(filepath).parents[2] / "blobs").glob("*"))), 0) @patch("huggingface_hub.file_download.are_symlinks_supported") def test_download_no_symlink_existing_file(self, mock_are_symlinks_supported: Mock) -> None: mock_are_symlinks_supported.return_value = True filepath = Path( hf_hub_download( DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=self.cache_dir, local_files_only=False, ) ) self.assertTrue(filepath.is_symlink()) blob_path = filepath.resolve() self.assertTrue(blob_path.is_file()) # Delete file in snapshot filepath.unlink() # Re-download but symlinks are not supported anymore (example: not an admin) mock_are_symlinks_supported.return_value = False new_filepath = Path( hf_hub_download( DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=self.cache_dir, local_files_only=False, ) ) # File exist but is not a symlink self.assertFalse(new_filepath.is_symlink()) self.assertTrue(new_filepath.is_file()) # Blob file still exists as well (has not been deleted) # => duplicate file on disk self.assertTrue(blob_path.is_file()) @patch("huggingface_hub.file_download.are_symlinks_supported") def test_scan_and_delete_cache_no_symlinks(self, mock_are_symlinks_supported: Mock) -> None: """Test scan_cache_dir works as well when cache-system doesn't use symlinks.""" OLDER_REVISION = "44c70f043cfe8162efc274ff531575e224a0e6f0" # Symlinks not supported mock_are_symlinks_supported.return_value = False # Download config.json from main hf_hub_download( DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=self.cache_dir, ) # Download README.md from main hf_hub_download( DUMMY_MODEL_ID, filename="README.md", cache_dir=self.cache_dir, ) # Download config.json from older revision hf_hub_download( DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=self.cache_dir, revision=OLDER_REVISION, ) # Now symlinks work: user has rerun the script as admin mock_are_symlinks_supported.return_value = True # Download merges.txt from older revision with symlinks hf_hub_download( DUMMY_MODEL_ID, filename="merges.txt", cache_dir=self.cache_dir, revision=OLDER_REVISION, ) # Scan cache directory report = scan_cache_dir(self.cache_dir) # 1 repo found, no warnings self.assertEqual(len(report.repos), 1) self.assertEqual(len(report.warnings), 0) repo = list(report.repos)[0] # 2 revisions found self.assertEqual(len(repo.revisions), 2) self.assertEqual(repo.nb_files, 4) self.assertEqual(len(repo.refs), 1) # only `main` main_revision = repo.refs["main"] main_ref = main_revision.commit_hash older_revision = [rev for rev in repo.revisions if rev is not main_revision][0] # 2 files in `main` revisions, both are not symlinks self.assertEqual(main_revision.nb_files, 2) for file in main_revision.files: # No symlinks means the files are in the snapshot dir itself self.assertTrue(main_revision.snapshot_path in file.blob_path.parents) # 2 files in older revision, only 1 as symlink for file in older_revision.files: if file.file_name == CONFIG_NAME: # In snapshot dir self.assertTrue(older_revision.snapshot_path in file.blob_path.parents) else: # In blob dir self.assertFalse(older_revision.snapshot_path in file.blob_path.parents) self.assertTrue("blobs" in str(file.blob_path)) # Since files are not shared (README.md is duplicated in cache), the total 
size # of the repo is the sum of each revision size. If symlinks were used, the total # size of the repo would be lower. self.assertEqual(repo.size_on_disk, main_revision.size_on_disk + older_revision.size_on_disk) # Test delete repo strategy strategy_delete_repo = report.delete_revisions(main_ref, OLDER_REVISION) self.assertEqual(strategy_delete_repo.expected_freed_size, repo.size_on_disk) self.assertEqual(len(strategy_delete_repo.blobs), 0) self.assertEqual(len(strategy_delete_repo.snapshots), 0) self.assertEqual(len(strategy_delete_repo.refs), 0) self.assertEqual(len(strategy_delete_repo.repos), 1) # Test delete older revision strategy strategy_delete_revision = report.delete_revisions(OLDER_REVISION) self.assertEqual( strategy_delete_revision.blobs, {file.blob_path for file in older_revision.files}, ) self.assertEqual(strategy_delete_revision.snapshots, {older_revision.snapshot_path}) self.assertEqual(len(strategy_delete_revision.refs), 0) self.assertEqual(len(strategy_delete_revision.repos), 0) strategy_delete_revision.execute() # Execute without error huggingface_hub-0.31.1/tests/test_cli.py000066400000000000000000001017511500667546600202470ustar00rootroot00000000000000import os import unittest import warnings from argparse import ArgumentParser, Namespace from contextlib import contextmanager from pathlib import Path from typing import Generator from unittest.mock import Mock, patch from huggingface_hub.commands.delete_cache import DeleteCacheCommand from huggingface_hub.commands.download import DownloadCommand from huggingface_hub.commands.repo_files import DeleteFilesSubCommand, RepoFilesCommand from huggingface_hub.commands.scan_cache import ScanCacheCommand from huggingface_hub.commands.tag import TagCommands from huggingface_hub.commands.upload import UploadCommand from huggingface_hub.errors import RevisionNotFoundError from huggingface_hub.utils import SoftTemporaryDirectory, capture_output from .testing_utils import DUMMY_MODEL_ID class TestCacheCommand(unittest.TestCase): def setUp(self) -> None: """ Set up scan-cache/delete-cache commands as in `src/huggingface_hub/commands/huggingface_cli.py`. 
""" self.parser = ArgumentParser("huggingface-cli", usage="huggingface-cli []") commands_parser = self.parser.add_subparsers() ScanCacheCommand.register_subcommand(commands_parser) DeleteCacheCommand.register_subcommand(commands_parser) def test_scan_cache_basic(self) -> None: """Test `huggingface-cli scan-cache`.""" args = self.parser.parse_args(["scan-cache"]) self.assertEqual(args.dir, None) self.assertEqual(args.verbose, 0) self.assertEqual(args.func, ScanCacheCommand) def test_scan_cache_verbose(self) -> None: """Test `huggingface-cli scan-cache -v`.""" args = self.parser.parse_args(["scan-cache", "-v"]) self.assertEqual(args.dir, None) self.assertEqual(args.verbose, 1) self.assertEqual(args.func, ScanCacheCommand) def test_scan_cache_with_dir(self) -> None: """Test `huggingface-cli scan-cache --dir something`.""" args = self.parser.parse_args(["scan-cache", "--dir", "something"]) self.assertEqual(args.dir, "something") self.assertEqual(args.verbose, 0) self.assertEqual(args.func, ScanCacheCommand) def test_scan_cache_ultra_verbose(self) -> None: """Test `huggingface-cli scan-cache -vvv`.""" args = self.parser.parse_args(["scan-cache", "-vvv"]) self.assertEqual(args.dir, None) self.assertEqual(args.verbose, 3) self.assertEqual(args.func, ScanCacheCommand) def test_delete_cache_with_dir(self) -> None: """Test `huggingface-cli delete-cache --dir something`.""" args = self.parser.parse_args(["delete-cache", "--dir", "something"]) self.assertEqual(args.dir, "something") self.assertEqual(args.func, DeleteCacheCommand) class TestUploadCommand(unittest.TestCase): def setUp(self) -> None: """ Set up CLI as in `src/huggingface_hub/commands/huggingface_cli.py`. """ self.parser = ArgumentParser("huggingface-cli", usage="huggingface-cli []") commands_parser = self.parser.add_subparsers() UploadCommand.register_subcommand(commands_parser) def test_upload_basic(self) -> None: """Test `huggingface-cli upload my-folder to dummy-repo`.""" cmd = UploadCommand(self.parser.parse_args(["upload", DUMMY_MODEL_ID, "my-folder"])) self.assertEqual(cmd.repo_id, DUMMY_MODEL_ID) self.assertEqual(cmd.local_path, "my-folder") self.assertEqual(cmd.path_in_repo, ".") # implicit self.assertEqual(cmd.repo_type, "model") self.assertEqual(cmd.revision, None) self.assertEqual(cmd.include, None) self.assertEqual(cmd.exclude, None) self.assertEqual(cmd.delete, None) self.assertEqual(cmd.commit_message, None) self.assertEqual(cmd.commit_description, None) self.assertEqual(cmd.create_pr, False) self.assertEqual(cmd.every, None) self.assertEqual(cmd.api.token, None) self.assertEqual(cmd.quiet, False) def test_upload_with_wildcard(self) -> None: """Test uploading files using wildcard patterns.""" with tmp_current_directory() as cache_dir: # Create test files (Path(cache_dir) / "model1.safetensors").touch() (Path(cache_dir) / "model2.safetensors").touch() (Path(cache_dir) / "model.bin").touch() (Path(cache_dir) / "config.json").touch() # Test basic wildcard pattern cmd = UploadCommand(self.parser.parse_args(["upload", DUMMY_MODEL_ID, "*.safetensors"])) self.assertEqual(cmd.local_path, ".") self.assertEqual(cmd.include, "*.safetensors") self.assertEqual(cmd.path_in_repo, ".") self.assertEqual(cmd.repo_id, DUMMY_MODEL_ID) # Test wildcard pattern with specific directory subdir = Path(cache_dir) / "subdir" subdir.mkdir() (subdir / "special.safetensors").touch() cmd = UploadCommand(self.parser.parse_args(["upload", DUMMY_MODEL_ID, "subdir/*.safetensors"])) self.assertEqual(cmd.local_path, ".") self.assertEqual(cmd.include, 
"subdir/*.safetensors") self.assertEqual(cmd.path_in_repo, ".") # Test error when using wildcard with --include with self.assertRaises(ValueError): UploadCommand( self.parser.parse_args(["upload", DUMMY_MODEL_ID, "*.safetensors", "--include", "*.json"]) ) # Test error when using wildcard with explicit path_in_repo with self.assertRaises(ValueError): UploadCommand(self.parser.parse_args(["upload", DUMMY_MODEL_ID, "*.safetensors", "models/"])) def test_upload_with_all_options(self) -> None: """Test `huggingface-cli upload my-file to dummy-repo with all options selected`.""" cmd = UploadCommand( self.parser.parse_args( [ "upload", DUMMY_MODEL_ID, "my-file", "data/", "--repo-type", "dataset", "--revision", "v1.0.0", "--include", "*.json", "*.yaml", "--exclude", "*.log", "*.txt", "--delete", "*.config", "*.secret", "--commit-message", "My commit message", "--commit-description", "My commit description", "--create-pr", "--every", "5", "--token", "my-token", "--quiet", ] ) ) self.assertEqual(cmd.repo_id, DUMMY_MODEL_ID) self.assertEqual(cmd.local_path, "my-file") self.assertEqual(cmd.path_in_repo, "data/") self.assertEqual(cmd.repo_type, "dataset") self.assertEqual(cmd.revision, "v1.0.0") self.assertEqual(cmd.include, ["*.json", "*.yaml"]) self.assertEqual(cmd.exclude, ["*.log", "*.txt"]) self.assertEqual(cmd.delete, ["*.config", "*.secret"]) self.assertEqual(cmd.commit_message, "My commit message") self.assertEqual(cmd.commit_description, "My commit description") self.assertEqual(cmd.create_pr, True) self.assertEqual(cmd.every, 5) self.assertEqual(cmd.api.token, "my-token") self.assertEqual(cmd.quiet, True) def test_upload_implicit_local_path_when_folder_exists(self) -> None: with tmp_current_directory() as cache_dir: folder_path = Path(cache_dir) / "my-cool-model" folder_path.mkdir() cmd = UploadCommand(self.parser.parse_args(["upload", "my-cool-model"])) # A folder with the same name as the repo exists => upload it at the root of the repo self.assertEqual(cmd.local_path, "my-cool-model") self.assertEqual(cmd.path_in_repo, ".") def test_upload_implicit_local_path_when_file_exists(self) -> None: with tmp_current_directory() as cache_dir: folder_path = Path(cache_dir) / "my-cool-model" folder_path.touch() cmd = UploadCommand(self.parser.parse_args(["upload", "my-cool-model"])) # A file with the same name as the repo exists => upload it at the root of the repo self.assertEqual(cmd.local_path, "my-cool-model") self.assertEqual(cmd.path_in_repo, "my-cool-model") def test_upload_implicit_local_path_when_org_repo(self) -> None: with tmp_current_directory() as cache_dir: folder_path = Path(cache_dir) / "my-cool-model" folder_path.mkdir() cmd = UploadCommand(self.parser.parse_args(["upload", "my-cool-org/my-cool-model"])) # A folder with the same name as the repo exists => upload it at the root of the repo self.assertEqual(cmd.local_path, "my-cool-model") self.assertEqual(cmd.path_in_repo, ".") def test_upload_implicit_local_path_otherwise(self) -> None: # No folder or file has the same name as the repo => raise exception with self.assertRaises(ValueError): with tmp_current_directory(): UploadCommand(self.parser.parse_args(["upload", "my-cool-model"])) def test_upload_explicit_local_path_to_folder_implicit_path_in_repo(self) -> None: with tmp_current_directory() as cache_dir: folder_path = Path(cache_dir) / "path" / "to" / "folder" folder_path.mkdir(parents=True, exist_ok=True) cmd = UploadCommand(self.parser.parse_args(["upload", "my-repo", "./path/to/folder"])) self.assertEqual(cmd.local_path, 
"./path/to/folder") self.assertEqual(cmd.path_in_repo, ".") # Always upload the folder at the root of the repo def test_upload_explicit_local_path_to_file_implicit_path_in_repo(self) -> None: with tmp_current_directory() as cache_dir: file_path = Path(cache_dir) / "path" / "to" / "file.txt" file_path.parent.mkdir(parents=True, exist_ok=True) file_path.touch() cmd = UploadCommand(self.parser.parse_args(["upload", "my-repo", "./path/to/file.txt"])) self.assertEqual(cmd.local_path, "./path/to/file.txt") self.assertEqual(cmd.path_in_repo, "file.txt") # If a file, upload it at the root of the repo and keep name def test_upload_explicit_paths(self) -> None: cmd = UploadCommand(self.parser.parse_args(["upload", "my-repo", "./path/to/folder", "data/"])) self.assertEqual(cmd.local_path, "./path/to/folder") self.assertEqual(cmd.path_in_repo, "data/") def test_every_must_be_positive(self) -> None: with self.assertRaises(ValueError): UploadCommand(self.parser.parse_args(["upload", DUMMY_MODEL_ID, ".", "--every", "0"])) with self.assertRaises(ValueError): UploadCommand(self.parser.parse_args(["upload", DUMMY_MODEL_ID, ".", "--every", "-10"])) def test_every_as_int(self) -> None: cmd = UploadCommand(self.parser.parse_args(["upload", DUMMY_MODEL_ID, ".", "--every", "10"])) self.assertEqual(cmd.every, 10) def test_every_as_float(self) -> None: cmd = UploadCommand(self.parser.parse_args(["upload", DUMMY_MODEL_ID, ".", "--every", "0.5"])) self.assertEqual(cmd.every, 0.5) @patch("huggingface_hub.commands.upload.HfApi.repo_info") @patch("huggingface_hub.commands.upload.HfApi.upload_folder") @patch("huggingface_hub.commands.upload.HfApi.create_repo") def test_upload_folder_mock(self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock) -> None: with SoftTemporaryDirectory() as cache_dir: cache_path = cache_dir.absolute().as_posix() cmd = UploadCommand( self.parser.parse_args( ["upload", "my-model", cache_path, ".", "--private", "--include", "*.json", "--delete", "*.json"] ) ) cmd.run() create_mock.assert_called_once_with( repo_id="my-model", repo_type="model", exist_ok=True, private=True, space_sdk=None ) upload_mock.assert_called_once_with( folder_path=cache_path, path_in_repo=".", repo_id=create_mock.return_value.repo_id, repo_type="model", revision=None, commit_message=None, commit_description=None, create_pr=False, allow_patterns=["*.json"], ignore_patterns=None, delete_patterns=["*.json"], ) @patch("huggingface_hub.commands.upload.HfApi.repo_info") @patch("huggingface_hub.commands.upload.HfApi.upload_file") @patch("huggingface_hub.commands.upload.HfApi.create_repo") def test_upload_file_mock(self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock) -> None: with SoftTemporaryDirectory() as cache_dir: file_path = Path(cache_dir) / "file.txt" file_path.write_text("content") cmd = UploadCommand( self.parser.parse_args( ["upload", "my-dataset", str(file_path), "logs/file.txt", "--repo-type", "dataset", "--create-pr"] ) ) cmd.run() create_mock.assert_called_once_with( repo_id="my-dataset", repo_type="dataset", exist_ok=True, private=False, space_sdk=None ) upload_mock.assert_called_once_with( path_or_fileobj=str(file_path), path_in_repo="logs/file.txt", repo_id=create_mock.return_value.repo_id, repo_type="dataset", revision=None, commit_message=None, commit_description=None, create_pr=True, ) @patch("huggingface_hub.commands.upload.HfApi.repo_info") @patch("huggingface_hub.commands.upload.HfApi.upload_file") @patch("huggingface_hub.commands.upload.HfApi.create_repo") def 
test_upload_file_no_revision_mock(self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock) -> None: with SoftTemporaryDirectory() as cache_dir: file_path = Path(cache_dir) / "file.txt" file_path.write_text("content") cmd = UploadCommand(self.parser.parse_args(["upload", "my-model", str(file_path), "logs/file.txt"])) cmd.run() # Revision not specified => no need to check repo_info_mock.assert_not_called() @patch("huggingface_hub.commands.upload.HfApi.create_branch") @patch("huggingface_hub.commands.upload.HfApi.repo_info") @patch("huggingface_hub.commands.upload.HfApi.upload_file") @patch("huggingface_hub.commands.upload.HfApi.create_repo") def test_upload_file_with_revision_mock( self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock, create_branch_mock: Mock ) -> None: repo_info_mock.side_effect = RevisionNotFoundError("revision not found") with SoftTemporaryDirectory() as cache_dir: file_path = Path(cache_dir) / "file.txt" file_path.write_text("content") cmd = UploadCommand( self.parser.parse_args( ["upload", "my-model", str(file_path), "logs/file.txt", "--revision", "my-branch"] ) ) cmd.run() # Revision specified => check that it exists repo_info_mock.assert_called_once_with( repo_id=create_mock.return_value.repo_id, repo_type="model", revision="my-branch" ) # Revision does not exist => create it create_branch_mock.assert_called_once_with( repo_id=create_mock.return_value.repo_id, repo_type="model", branch="my-branch", exist_ok=True ) @patch("huggingface_hub.commands.upload.HfApi.repo_info") @patch("huggingface_hub.commands.upload.HfApi.upload_file") @patch("huggingface_hub.commands.upload.HfApi.create_repo") def test_upload_file_revision_and_create_pr_mock( self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock ) -> None: with SoftTemporaryDirectory() as cache_dir: file_path = Path(cache_dir) / "file.txt" file_path.write_text("content") cmd = UploadCommand( self.parser.parse_args( ["upload", "my-model", str(file_path), "logs/file.txt", "--revision", "my-branch", "--create-pr"] ) ) cmd.run() # Revision specified but --create-pr => no need to check repo_info_mock.assert_not_called() @patch("huggingface_hub.commands.upload.HfApi.create_repo") def test_upload_missing_path(self, create_mock: Mock) -> None: cmd = UploadCommand(self.parser.parse_args(["upload", "my-model", "/path/to/missing_file", "logs/file.txt"])) with self.assertRaises(FileNotFoundError): cmd.run() # File/folder does not exist locally # Repo creation happens before the check create_mock.assert_not_called() class TestDownloadCommand(unittest.TestCase): def setUp(self) -> None: """ Set up CLI as in `src/huggingface_hub/commands/huggingface_cli.py`. 
""" self.parser = ArgumentParser("huggingface-cli", usage="huggingface-cli []") commands_parser = self.parser.add_subparsers() DownloadCommand.register_subcommand(commands_parser) def test_download_basic(self) -> None: """Test `huggingface-cli download dummy-repo`.""" args = self.parser.parse_args(["download", DUMMY_MODEL_ID]) self.assertEqual(args.repo_id, DUMMY_MODEL_ID) self.assertEqual(len(args.filenames), 0) self.assertEqual(args.repo_type, "model") self.assertIsNone(args.revision) self.assertIsNone(args.include) self.assertIsNone(args.exclude) self.assertIsNone(args.cache_dir) self.assertIsNone(args.local_dir) self.assertFalse(args.force_download) self.assertFalse(args.resume_download) self.assertIsNone(args.token) self.assertFalse(args.quiet) self.assertEqual(args.func, DownloadCommand) def test_download_with_all_options(self) -> None: """Test `huggingface-cli download dummy-repo` with all options selected.""" args = self.parser.parse_args( [ "download", DUMMY_MODEL_ID, "--repo-type", "dataset", "--revision", "v1.0.0", "--include", "*.json", "*.yaml", "--exclude", "*.log", "*.txt", "--force-download", "--cache-dir", "/tmp", "--resume-download", "--token", "my-token", "--quiet", "--local-dir", ".", "--max-workers", "4", ] ) self.assertEqual(args.repo_id, DUMMY_MODEL_ID) self.assertEqual(args.repo_type, "dataset") self.assertEqual(args.revision, "v1.0.0") self.assertEqual(args.include, ["*.json", "*.yaml"]) self.assertEqual(args.exclude, ["*.log", "*.txt"]) self.assertTrue(args.force_download) self.assertEqual(args.cache_dir, "/tmp") self.assertEqual(args.local_dir, ".") self.assertTrue(args.resume_download) self.assertEqual(args.token, "my-token") self.assertTrue(args.quiet) self.assertEqual(args.max_workers, 4) self.assertEqual(args.func, DownloadCommand) @patch("huggingface_hub.commands.download.hf_hub_download") def test_download_file_from_revision(self, mock: Mock) -> None: args = Namespace( token="hf_****", repo_id="author/dataset", filenames=["README.md"], repo_type="dataset", revision="refs/pr/1", include=None, exclude=None, force_download=False, resume_download=False, cache_dir=None, local_dir=".", local_dir_use_symlinks=None, quiet=False, max_workers=8, ) # Output path is printed to terminal once run is completed with capture_output() as output: DownloadCommand(args).run() self.assertRegex(output.getvalue(), r"") mock.assert_called_once_with( repo_id="author/dataset", repo_type="dataset", revision="refs/pr/1", filename="README.md", cache_dir=None, resume_download=None, force_download=False, token="hf_****", local_dir=".", library_name="huggingface-cli", ) @patch("huggingface_hub.commands.download.snapshot_download") def test_download_multiple_files(self, mock: Mock) -> None: args = Namespace( token="hf_****", repo_id="author/model", filenames=["README.md", "config.json"], repo_type="model", revision=None, include=None, exclude=None, force_download=True, resume_download=True, cache_dir=None, local_dir="/path/to/dir", local_dir_use_symlinks=None, quiet=False, max_workers=8, ) DownloadCommand(args).run() # Use `snapshot_download` to ensure all files comes from same revision mock.assert_called_once_with( repo_id="author/model", repo_type="model", revision=None, allow_patterns=["README.md", "config.json"], ignore_patterns=None, resume_download=True, force_download=True, cache_dir=None, token="hf_****", local_dir="/path/to/dir", library_name="huggingface-cli", max_workers=8, ) @patch("huggingface_hub.commands.download.snapshot_download") def test_download_with_patterns(self, mock: 
Mock) -> None: args = Namespace( token=None, repo_id="author/model", filenames=[], repo_type="model", revision=None, include=["*.json"], exclude=["data/*"], force_download=True, resume_download=True, cache_dir=None, quiet=False, local_dir=None, local_dir_use_symlinks=None, max_workers=8, ) DownloadCommand(args).run() # Use `snapshot_download` to ensure all files comes from same revision mock.assert_called_once_with( repo_id="author/model", repo_type="model", revision=None, allow_patterns=["*.json"], ignore_patterns=["data/*"], resume_download=True, force_download=True, cache_dir=None, local_dir=None, token=None, library_name="huggingface-cli", max_workers=8, ) @patch("huggingface_hub.commands.download.snapshot_download") def test_download_with_ignored_patterns(self, mock: Mock) -> None: args = Namespace( token=None, repo_id="author/model", filenames=["README.md", "config.json"], repo_type="model", revision=None, include=["*.json"], exclude=["data/*"], force_download=True, resume_download=True, cache_dir=None, quiet=False, local_dir=None, local_dir_use_symlinks=None, max_workers=8, ) with self.assertWarns(UserWarning): # warns that patterns are ignored DownloadCommand(args).run() mock.assert_called_once_with( repo_id="author/model", repo_type="model", revision=None, allow_patterns=["README.md", "config.json"], # `filenames` has priority over the patterns ignore_patterns=None, # cleaned up resume_download=True, force_download=True, cache_dir=None, token=None, local_dir=None, library_name="huggingface-cli", max_workers=8, ) # Same but quiet (no warnings) args.quiet = True with warnings.catch_warnings(): # Taken from https://docs.pytest.org/en/latest/how-to/capture-warnings.html#additional-use-cases-of-warnings-in-tests warnings.simplefilter("error") DownloadCommand(args).run() class TestTagCommands(unittest.TestCase): def setUp(self) -> None: """ Set up CLI as in `src/huggingface_hub/commands/huggingface_cli.py`. 
""" self.parser = ArgumentParser("huggingface-cli", usage="huggingface-cli []") commands_parser = self.parser.add_subparsers() TagCommands.register_subcommand(commands_parser) def test_tag_create_basic(self) -> None: args = self.parser.parse_args(["tag", DUMMY_MODEL_ID, "1.0", "-m", "My tag message"]) self.assertEqual(args.repo_id, DUMMY_MODEL_ID) self.assertEqual(args.tag, "1.0") self.assertIsNotNone(args.message) self.assertIsNone(args.revision) self.assertIsNone(args.token) self.assertEqual(args.repo_type, "model") self.assertFalse(args.yes) def test_tag_create_with_all_options(self) -> None: args = self.parser.parse_args( [ "tag", DUMMY_MODEL_ID, "1.0", "--message", "My tag message", "--revision", "v1.0.0", "--token", "my-token", "--repo-type", "dataset", "--yes", ] ) self.assertEqual(args.repo_id, DUMMY_MODEL_ID) self.assertEqual(args.tag, "1.0") self.assertEqual(args.message, "My tag message") self.assertEqual(args.revision, "v1.0.0") self.assertEqual(args.token, "my-token") self.assertEqual(args.repo_type, "dataset") self.assertTrue(args.yes) def test_tag_list_basic(self) -> None: args = self.parser.parse_args(["tag", "--list", DUMMY_MODEL_ID]) self.assertEqual(args.repo_id, DUMMY_MODEL_ID) self.assertIsNone(args.token) self.assertEqual(args.repo_type, "model") def test_tag_delete_basic(self) -> None: args = self.parser.parse_args(["tag", "--delete", DUMMY_MODEL_ID, "1.0"]) self.assertEqual(args.repo_id, DUMMY_MODEL_ID) self.assertEqual(args.tag, "1.0") self.assertIsNone(args.token) self.assertEqual(args.repo_type, "model") self.assertFalse(args.yes) @contextmanager def tmp_current_directory() -> Generator[str, None, None]: """Change current directory to a tmp dir and revert back when exiting.""" with SoftTemporaryDirectory() as tmp_dir: cwd = os.getcwd() os.chdir(tmp_dir) try: yield tmp_dir except: raise finally: os.chdir(cwd) class TestRepoFilesCommand(unittest.TestCase): def setUp(self) -> None: """ Set up CLI as in `src/huggingface_hub/commands/huggingface_cli.py`. 
""" self.parser = ArgumentParser("huggingface-cli", usage="huggingface-cli []") commands_parser = self.parser.add_subparsers() RepoFilesCommand.register_subcommand(commands_parser) @patch("huggingface_hub.commands.repo_files.HfApi.delete_files") def test_delete(self, delete_files_mock: Mock) -> None: fixtures = [ { "input_args": [ "repo-files", DUMMY_MODEL_ID, "delete", "*", ], "delete_files_args": { "delete_patterns": [ "*", ], "repo_id": DUMMY_MODEL_ID, "repo_type": "model", "revision": None, "commit_message": None, "commit_description": None, "create_pr": False, }, }, { "input_args": [ "repo-files", DUMMY_MODEL_ID, "delete", "file.txt", ], "delete_files_args": { "delete_patterns": [ "file.txt", ], "repo_id": DUMMY_MODEL_ID, "repo_type": "model", "revision": None, "commit_message": None, "commit_description": None, "create_pr": False, }, }, { "input_args": [ "repo-files", DUMMY_MODEL_ID, "delete", "folder/", ], "delete_files_args": { "delete_patterns": [ "folder/", ], "repo_id": DUMMY_MODEL_ID, "repo_type": "model", "revision": None, "commit_message": None, "commit_description": None, "create_pr": False, }, }, { "input_args": [ "repo-files", DUMMY_MODEL_ID, "delete", "file1.txt", "folder/", "file2.txt", ], "delete_files_args": { "delete_patterns": [ "file1.txt", "folder/", "file2.txt", ], "repo_id": DUMMY_MODEL_ID, "repo_type": "model", "revision": None, "commit_message": None, "commit_description": None, "create_pr": False, }, }, { "input_args": [ "repo-files", DUMMY_MODEL_ID, "delete", "file.txt *", "*.json", "folder/*.parquet", ], "delete_files_args": { "delete_patterns": [ "file.txt *", "*.json", "folder/*.parquet", ], "repo_id": DUMMY_MODEL_ID, "repo_type": "model", "revision": None, "commit_message": None, "commit_description": None, "create_pr": False, }, }, { "input_args": [ "repo-files", DUMMY_MODEL_ID, "delete", "file.txt *", "--revision", "test_revision", "--repo-type", "dataset", "--commit-message", "My commit message", "--commit-description", "My commit description", "--create-pr", ], "delete_files_args": { "delete_patterns": [ "file.txt *", ], "repo_id": DUMMY_MODEL_ID, "repo_type": "dataset", "revision": "test_revision", "commit_message": "My commit message", "commit_description": "My commit description", "create_pr": True, }, }, ] for expected in fixtures: # subTest is similar to pytest.mark.parametrize, but using the unittest # framework with self.subTest(expected): delete_files_args = expected["delete_files_args"] cmd = DeleteFilesSubCommand(self.parser.parse_args(expected["input_args"])) cmd.run() if delete_files_args is None: assert delete_files_mock.call_count == 0 else: assert delete_files_mock.call_count == 1 # Inspect the captured calls _, kwargs = delete_files_mock.call_args_list[0] assert kwargs == delete_files_args delete_files_mock.reset_mock() huggingface_hub-0.31.1/tests/test_command_delete_cache.py000066400000000000000000000515461500667546600235710ustar00rootroot00000000000000import os import unittest from pathlib import Path from tempfile import mkstemp from unittest.mock import Mock, patch from InquirerPy.base.control import Choice from InquirerPy.separator import Separator from huggingface_hub.commands.delete_cache import ( _CANCEL_DELETION_STR, DeleteCacheCommand, _ask_for_confirmation_no_tui, _get_expectations_str, _get_tui_choices_from_scan, _manual_review_no_tui, _read_manual_review_tmp_file, ) from huggingface_hub.utils import SoftTemporaryDirectory, capture_output from .testing_utils import handle_injection class 
TestDeleteCacheHelpers(unittest.TestCase): def test_get_tui_choices_from_scan_empty(self) -> None: choices = _get_tui_choices_from_scan(repos={}, preselected=[], sort_by=None) self.assertEqual(len(choices), 1) self.assertIsInstance(choices[0], Choice) self.assertEqual(choices[0].value, _CANCEL_DELETION_STR) self.assertTrue(len(choices[0].name) != 0) # Something displayed to the user self.assertFalse(choices[0].enabled) def test_get_tui_choices_from_scan_with_preselection(self) -> None: choices = _get_tui_choices_from_scan( repos=_get_cache_mock().repos, preselected=[ "dataset_revision_hash_id", # dataset_1 is preselected "a_revision_id_that_does_not_exist", # unknown but will not complain "older_hash_id", # only the oldest revision from model_2 ], sort_by=None, # Don't sort to maintain original order ) self.assertEqual(len(choices), 8) # Item to cancel everything self.assertIsInstance(choices[0], Choice) self.assertEqual(choices[0].value, _CANCEL_DELETION_STR) self.assertTrue(len(choices[0].name) != 0) self.assertFalse(choices[0].enabled) # Dataset repo separator self.assertIsInstance(choices[1], Separator) self.assertEqual(choices[1]._line, "\nDataset dummy_dataset (8M, used 2 weeks ago)") # Only revision of `dummy_dataset` self.assertIsInstance(choices[2], Choice) self.assertEqual(choices[2].value, "dataset_revision_hash_id") self.assertEqual( choices[2].name, # truncated hash id + detached + last modified "dataset_: (detached) # modified 1 day ago", ) self.assertTrue(choices[2].enabled) # preselected # Model `dummy_model` separator self.assertIsInstance(choices[3], Separator) self.assertEqual(choices[3]._line, "\nModel dummy_model (1.4K, used 2 years ago)") # Recent revision of `dummy_model` (appears first due to sorting by last_modified) self.assertIsInstance(choices[4], Choice) self.assertEqual(choices[4].value, "recent_hash_id") self.assertEqual(choices[4].name, "recent_h: main # modified 2 years ago") self.assertFalse(choices[4].enabled) # Oldest revision of `dummy_model` self.assertIsInstance(choices[5], Choice) self.assertEqual(choices[5].value, "older_hash_id") self.assertEqual(choices[5].name, "older_ha: (detached) # modified 3 years ago") self.assertTrue(choices[5].enabled) # preselected # Model `gpt2` separator self.assertIsInstance(choices[6], Separator) self.assertEqual(choices[6]._line, "\nModel gpt2 (3.6G, used 2 hours ago)") # Only revision of `gpt2` self.assertIsInstance(choices[7], Choice) self.assertEqual(choices[7].value, "abcdef123456789") self.assertEqual(choices[7].name, "abcdef12: main, refs/pr/1 # modified 2 years ago") self.assertFalse(choices[7].enabled) def test_get_tui_choices_from_scan_with_sort_size(self) -> None: """Test sorting by size.""" choices = _get_tui_choices_from_scan(repos=_get_cache_mock().repos, preselected=[], sort_by="size") # Verify repo order: gpt2 (3.6G) -> dummy_dataset (8M) -> dummy_model (1.4K) self.assertIsInstance(choices[1], Separator) self.assertIn("gpt2", choices[1]._line) self.assertIsInstance(choices[3], Separator) self.assertIn("dummy_dataset", choices[3]._line) self.assertIsInstance(choices[5], Separator) self.assertIn("dummy_model", choices[5]._line) def test_get_expectations_str_on_no_deletion_item(self) -> None: """Test `_get_instructions` when `_CANCEL_DELETION_STR` is passed.""" self.assertEqual( _get_expectations_str( hf_cache_info=Mock(), selected_hashes=["hash_1", _CANCEL_DELETION_STR, "hash_2"], ), "Nothing will be deleted.", ) def test_get_expectations_str_with_selection(self) -> None: """Test `_get_instructions` with 
2 revisions selected.""" strategy_mock = Mock() strategy_mock.expected_freed_size_str = "5.1M" cache_mock = Mock() cache_mock.delete_revisions.return_value = strategy_mock self.assertEqual( _get_expectations_str( hf_cache_info=cache_mock, selected_hashes=["hash_1", "hash_2"], ), "2 revisions selected counting for 5.1M.", ) cache_mock.delete_revisions.assert_called_once_with("hash_1", "hash_2") def test_read_manual_review_tmp_file(self) -> None: """Test `_read_manual_review_tmp_file`.""" with SoftTemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir) / "file.txt" with tmp_path.open("w") as f: f.writelines( [ "# something commented out\n", "###\n", "\n\n\n\n", # some empty lines " # Something commented out after spaces\n", "a_revision_hash\n", "a_revision_hash_with_a_comment # 2 years ago\n", " a_revision_hash_after_spaces\n", " a_revision_hash_with_a_comment_after_spaces # 2years ago\n", " # hash_commented_out # 2 years ago\n", "a_revision_hash\n", # Duplicate "", # empty line ] ) # Only non-commented lines are returned # Order is kept and lines are not de-duplicated self.assertListEqual( _read_manual_review_tmp_file(tmp_path), [ "a_revision_hash", "a_revision_hash_with_a_comment", "a_revision_hash_after_spaces", "a_revision_hash_with_a_comment_after_spaces", "a_revision_hash", ], ) @patch("huggingface_hub.commands.delete_cache.input") @patch("huggingface_hub.commands.delete_cache.mkstemp") def test_manual_review_no_tui(self, mock_mkstemp: Mock, mock_input: Mock) -> None: # Mock file creation so that we know the file location in test fd, tmp_path = mkstemp() mock_mkstemp.return_value = fd, tmp_path # Mock cache cache_mock = _get_cache_mock() # Mock input from user def _input_answers(): self.assertTrue(os.path.isfile(tmp_path)) # not deleted yet with open(tmp_path) as f: content = f.read() self.assertTrue(content.startswith("# INSTRUCTIONS")) # older_hash_id is not commented self.assertIn("\n older_hash_id # Refs: (detached)", content) # same for abcdef123456789 self.assertIn("\n abcdef123456789 # Refs: main, refs/pr/1", content) # dataset revision is not preselected self.assertIn("# dataset_revision_hash_id", content) # same for recent_hash_id self.assertIn("# recent_hash_id", content) # Select dataset revision content = content.replace("# dataset_revision_hash_id", "dataset_revision_hash_id") # Deselect abcdef123456789 content = content.replace("abcdef123456789", "# abcdef123456789") with open(tmp_path, "w") as f: f.write(content) yield "no" # User edited the file and want to see the strategy diff yield "y" # User confirms mock_input.side_effect = _input_answers() # Run manual review with capture_output() as output: selected_hashes = _manual_review_no_tui( hf_cache_info=cache_mock, preselected=["abcdef123456789", "older_hash_id"], sort_by=None ) # Tmp file has been created but is now deleted mock_mkstemp.assert_called_once_with(suffix=".txt") self.assertFalse(os.path.isfile(tmp_path)) # now deleted # User changed the selection self.assertListEqual(selected_hashes, ["dataset_revision_hash_id", "older_hash_id"]) # Check printed instructions printed = output.getvalue() self.assertTrue(printed.startswith("TUI is disabled. In order to")) # ... 
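        # The printed instructions must reference the temporary review file so the
        # user knows which file to edit.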
self.assertIn(tmp_path, printed) # Check input called twice self.assertEqual(mock_input.call_count, 2) @patch("huggingface_hub.commands.delete_cache.input") def test_ask_for_confirmation_no_tui(self, mock_input: Mock) -> None: """Test `_ask_for_confirmation_no_tui`.""" # Answer yes mock_input.side_effect = ("y",) value = _ask_for_confirmation_no_tui("custom message 1", default=True) mock_input.assert_called_with("custom message 1 (Y/n) ") self.assertTrue(value) # Answer no mock_input.side_effect = ("NO",) value = _ask_for_confirmation_no_tui("custom message 2", default=True) mock_input.assert_called_with("custom message 2 (Y/n) ") self.assertFalse(value) # Answer invalid, then default mock_input.side_effect = ("foo", "") with capture_output() as output: value = _ask_for_confirmation_no_tui("custom message 3", default=False) mock_input.assert_called_with("custom message 3 (y/N) ") self.assertFalse(value) self.assertEqual( output.getvalue(), "Invalid input. Must be one of ('y', 'yes', '1', 'n', 'no', '0', '')\n", ) def test_get_tui_choices_from_scan_with_different_sorts(self) -> None: """Test different sorting modes.""" cache_mock = _get_cache_mock() # Test size sorting (largest first) - order: gpt2 (3.6G) -> dummy_dataset (8M) -> dummy_model (1.4K) size_choices = _get_tui_choices_from_scan(cache_mock.repos, [], sort_by="size") # Separators at positions 1, 3, 5 self.assertIsInstance(size_choices[1], Separator) self.assertIn("gpt2", size_choices[1]._line) self.assertIsInstance(size_choices[3], Separator) self.assertIn("dummy_dataset", size_choices[3]._line) self.assertIsInstance(size_choices[5], Separator) self.assertIn("dummy_model", size_choices[5]._line) # Test alphabetical sorting - order: dummy_dataset -> dummy_model -> gpt2 alpha_choices = _get_tui_choices_from_scan(cache_mock.repos, [], sort_by="alphabetical") # Separators at positions 1, 3, 6 (dummy_model has 2 revisions) self.assertIsInstance(alpha_choices[1], Separator) self.assertIn("dummy_dataset", alpha_choices[1]._line) self.assertIsInstance(alpha_choices[3], Separator) self.assertIn("dummy_model", alpha_choices[3]._line) self.assertIsInstance(alpha_choices[6], Separator) self.assertIn("gpt2", alpha_choices[6]._line) # Test lastUpdated sorting - order: dummy_dataset (1 day) -> gpt2 (2 years) -> dummy_model (3 years) updated_choices = _get_tui_choices_from_scan(cache_mock.repos, [], sort_by="lastUpdated") # Separators at positions 1, 3, 5 self.assertIsInstance(updated_choices[1], Separator) self.assertIn("dummy_dataset", updated_choices[1]._line) self.assertIsInstance(updated_choices[3], Separator) self.assertIn("gpt2", updated_choices[3]._line) self.assertIsInstance(updated_choices[5], Separator) self.assertIn("dummy_model", updated_choices[5]._line) # Test lastUsed sorting - order: gpt2 (2h) -> dummy_dataset (2w) -> dummy_model (2y) used_choices = _get_tui_choices_from_scan(cache_mock.repos, [], sort_by="lastUsed") # Separators at positions 1, 3, 5 self.assertIsInstance(used_choices[1], Separator) self.assertIn("gpt2", used_choices[1]._line) self.assertIsInstance(used_choices[3], Separator) self.assertIn("dummy_dataset", used_choices[3]._line) self.assertIsInstance(used_choices[5], Separator) self.assertIn("dummy_model", used_choices[5]._line) @patch("huggingface_hub.commands.delete_cache._ask_for_confirmation_no_tui") @patch("huggingface_hub.commands.delete_cache._get_expectations_str") @patch("huggingface_hub.commands.delete_cache.inquirer.confirm") @patch("huggingface_hub.commands.delete_cache._manual_review_tui") 
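# Every helper used by `DeleteCacheCommand.run()` is patched at class level;
# `handle_injection` then passes each mock to the test methods by argument name.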
@patch("huggingface_hub.commands.delete_cache._manual_review_no_tui") @patch("huggingface_hub.commands.delete_cache.scan_cache_dir") @handle_injection class TestMockedDeleteCacheCommand(unittest.TestCase): """Test case with a patched `DeleteCacheCommand` to test `.run()` without testing the manual review. """ args: Mock command: DeleteCacheCommand def setUp(self) -> None: self.args = Mock() self.args.sort = None self.command = DeleteCacheCommand(self.args) def test_run_and_delete_with_tui( self, mock_scan_cache_dir: Mock, mock__manual_review_tui: Mock, mock__get_expectations_str: Mock, mock_confirm: Mock, ) -> None: """Test command run with a mocked manual review step.""" # Mock return values mock__manual_review_tui.return_value = ["hash_1", "hash_2"] mock__get_expectations_str.return_value = "Will delete A and B." mock_confirm.return_value.execute.return_value = True mock_scan_cache_dir.return_value = _get_cache_mock() # Run self.command.disable_tui = False with capture_output() as output: self.command.run() # Step 1: scan mock_scan_cache_dir.assert_called_once_with(self.args.dir) cache_mock = mock_scan_cache_dir.return_value # Step 2: manual review mock__manual_review_tui.assert_called_once_with(cache_mock, preselected=[], sort_by=None) # Step 3: ask confirmation mock__get_expectations_str.assert_called_once_with(cache_mock, ["hash_1", "hash_2"]) mock_confirm.assert_called_once_with("Will delete A and B. Confirm deletion ?", default=True) mock_confirm().execute.assert_called_once_with() # Step 4: delete cache_mock.delete_revisions.assert_called_once_with("hash_1", "hash_2") strategy_mock = cache_mock.delete_revisions.return_value strategy_mock.execute.assert_called_once_with() # Check output self.assertEqual( output.getvalue(), "Start deletion.\nDone. Deleted 0 repo(s) and 0 revision(s) for a total of 5.1M.\n", ) def test_run_nothing_selected_with_tui(self, mock__manual_review_tui: Mock) -> None: """Test command run but nothing is selected in manual review.""" # Mock return value mock__manual_review_tui.return_value = [] # Run self.command.disable_tui = False with capture_output() as output: self.command.run() # Check output self.assertEqual(output.getvalue(), "Deletion is cancelled. Do nothing.\n") def test_run_stuff_selected_but_cancel_item_as_well_with_tui(self, mock__manual_review_tui: Mock) -> None: """Test command run when some are selected but "cancel item" as well.""" # Mock return value mock__manual_review_tui.return_value = [ "hash_1", "hash_2", _CANCEL_DELETION_STR, ] # Run self.command.disable_tui = False with capture_output() as output: self.command.run() # Check output self.assertEqual(output.getvalue(), "Deletion is cancelled. Do nothing.\n") def test_run_and_delete_no_tui( self, mock_scan_cache_dir: Mock, mock__manual_review_no_tui: Mock, mock__get_expectations_str: Mock, mock__ask_for_confirmation_no_tui: Mock, ) -> None: """Test command run with a mocked manual review step.""" # Mock return values mock__manual_review_no_tui.return_value = ["hash_1", "hash_2"] mock__get_expectations_str.return_value = "Will delete A and B." 
mock__ask_for_confirmation_no_tui.return_value.return_value = True mock_scan_cache_dir.return_value = _get_cache_mock() # Run self.command.disable_tui = True with capture_output() as output: self.command.run() # Step 1: scan mock_scan_cache_dir.assert_called_once_with(self.args.dir) cache_mock = mock_scan_cache_dir.return_value # Step 2: manual review mock__manual_review_no_tui.assert_called_once_with(cache_mock, preselected=[], sort_by=None) # Step 3: ask confirmation mock__get_expectations_str.assert_called_once_with(cache_mock, ["hash_1", "hash_2"]) mock__ask_for_confirmation_no_tui.assert_called_once_with("Will delete A and B. Confirm deletion ?") # Step 4: delete cache_mock.delete_revisions.assert_called_once_with("hash_1", "hash_2") strategy_mock = cache_mock.delete_revisions.return_value strategy_mock.execute.assert_called_once_with() # Check output self.assertEqual( output.getvalue(), "Start deletion.\nDone. Deleted 0 repo(s) and 0 revision(s) for a total of 5.1M.\n", ) def test_run_with_sorting(self): """Test command run with sorting enabled.""" self.args.sort = "size" self.command = DeleteCacheCommand(self.args) mock_scan_cache_dir = Mock() mock_scan_cache_dir.return_value = _get_cache_mock() with patch("huggingface_hub.commands.delete_cache.scan_cache_dir", mock_scan_cache_dir): with patch("huggingface_hub.commands.delete_cache._manual_review_tui") as mock_review: self.command.disable_tui = False self.command.run() mock_review.assert_called_once_with(mock_scan_cache_dir.return_value, preselected=[], sort_by="size") def _get_cache_mock() -> Mock: # First model with 1 revision model_1 = Mock() model_1.repo_type = "model" model_1.repo_id = "gpt2" model_1.size_on_disk_str = "3.6G" model_1.last_accessed = 1660000000 model_1.last_accessed_str = "2 hours ago" model_1.size_on_disk = 3.6 * 1024**3 # 3.6 GiB model_1_revision_1 = Mock() model_1_revision_1.commit_hash = "abcdef123456789" model_1_revision_1.refs = {"main", "refs/pr/1"} model_1_revision_1.last_modified = 123456789000 # 2 years ago model_1_revision_1.last_modified_str = "2 years ago" model_1.revisions = {model_1_revision_1} # Second model with 2 revisions model_2 = Mock() model_2.repo_type = "model" model_2.repo_id = "dummy_model" model_2.size_on_disk_str = "1.4K" model_2.last_accessed = 1550000000 model_2.last_accessed_str = "2 years ago" model_2.size_on_disk = 1.4 * 1024 # 1.4K model_2_revision_1 = Mock() model_2_revision_1.commit_hash = "recent_hash_id" model_2_revision_1.refs = {"main"} model_2_revision_1.last_modified = 123456789 # 2 years ago model_2_revision_1.last_modified_str = "2 years ago" model_2_revision_2 = Mock() model_2_revision_2.commit_hash = "older_hash_id" model_2_revision_2.refs = {} model_2_revision_2.last_modified = 12345678000 # 3 years ago model_2_revision_2.last_modified_str = "3 years ago" model_2.revisions = {model_2_revision_1, model_2_revision_2} # And a dataset with 1 revision dataset_1 = Mock() dataset_1.repo_type = "dataset" dataset_1.repo_id = "dummy_dataset" dataset_1.size_on_disk_str = "8M" dataset_1.last_accessed = 1659000000 dataset_1.last_accessed_str = "2 weeks ago" dataset_1.size_on_disk = 8 * 1024**2 # 8 MiB dataset_1_revision_1 = Mock() dataset_1_revision_1.commit_hash = "dataset_revision_hash_id" dataset_1_revision_1.refs = {} dataset_1_revision_1.last_modified = 1234567890000 # 1 day ago (newest) dataset_1_revision_1.last_modified_str = "1 day ago" dataset_1.revisions = {dataset_1_revision_1} # Fake cache strategy_mock = Mock() strategy_mock.repos = [] strategy_mock.snapshots = [] 
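    # Fake strategy returned by `cache_mock.delete_revisions`: nothing listed for
    # deletion but an expected freed size of 5.1M is reported.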
strategy_mock.expected_freed_size_str = "5.1M" cache_mock = Mock() cache_mock.repos = {model_1, model_2, dataset_1} cache_mock.delete_revisions.return_value = strategy_mock return cache_mock huggingface_hub-0.31.1/tests/test_commit_api.py000066400000000000000000000141351500667546600216200ustar00rootroot00000000000000import unittest from huggingface_hub._commit_api import ( CommitOperationAdd, CommitOperationDelete, _warn_on_overwriting_operations, ) class TestCommitOperationDelete(unittest.TestCase): def test_implicit_file(self): self.assertFalse(CommitOperationDelete(path_in_repo="path/to/file").is_folder) self.assertFalse(CommitOperationDelete(path_in_repo="path/to/file.md").is_folder) def test_implicit_folder(self): self.assertTrue(CommitOperationDelete(path_in_repo="path/to/folder/").is_folder) self.assertTrue(CommitOperationDelete(path_in_repo="path/to/folder.md/").is_folder) def test_explicit_file(self): # Weird case: if user explicitly set as file (`is_folder`=False) but path has a # trailing "/" => user input has priority self.assertFalse(CommitOperationDelete(path_in_repo="path/to/folder/", is_folder=False).is_folder) self.assertFalse(CommitOperationDelete(path_in_repo="path/to/folder.md/", is_folder=False).is_folder) def test_explicit_folder(self): # No need for the trailing "/" is `is_folder` explicitly passed self.assertTrue(CommitOperationDelete(path_in_repo="path/to/folder", is_folder=True).is_folder) self.assertTrue(CommitOperationDelete(path_in_repo="path/to/folder.md", is_folder=True).is_folder) def test_is_folder_wrong_value(self): with self.assertRaises(ValueError): CommitOperationDelete(path_in_repo="path/to/folder", is_folder="any value") class TestCommitOperationPathInRepo(unittest.TestCase): valid_values = { # key is input, value is expected validated output "file.txt": "file.txt", ".file.txt": ".file.txt", "/file.txt": "file.txt", "./file.txt": "file.txt", } invalid_values = [".", "..", "../file.txt"] def test_path_in_repo_valid(self) -> None: for input, expected in self.valid_values.items(): with self.subTest(f"Testing with valid input: '{input}'"): self.assertEqual(CommitOperationAdd(path_in_repo=input, path_or_fileobj=b"").path_in_repo, expected) self.assertEqual(CommitOperationDelete(path_in_repo=input).path_in_repo, expected) def test_path_in_repo_invalid(self) -> None: for input in self.invalid_values: with self.subTest(f"Testing with invalid input: '{input}'"): with self.assertRaises(ValueError): CommitOperationAdd(path_in_repo=input, path_or_fileobj=b"") with self.assertRaises(ValueError): CommitOperationDelete(path_in_repo=input) class TestCommitOperationForbiddenPathInRepo(unittest.TestCase): """Commit operations must throw an error on files in the .git/ or .cache/huggingface/ folders. Server would error anyway so it's best to prevent early. 
""" INVALID_PATHS_IN_REPO = { ".git", ".git/path/to/file", "./.git/path/to/file", "subfolder/path/.git/to/file", "./subfolder/path/.git/to/file", ".cache/huggingface", "./.cache/huggingface/path/to/file", "./subfolder/path/.cache/huggingface/to/file", } VALID_PATHS_IN_REPO = { ".gitignore", "path/to/.gitignore", "path/to/something.git", "path/to/something.git/more", "path/to/something.huggingface/more", "huggingface", ".huggingface", "./.huggingface/path/to/file", "./subfolder/path/huggingface/to/file", "./subfolder/path/.huggingface/to/file", } def test_cannot_update_file_in_git_folder(self): for path in self.INVALID_PATHS_IN_REPO: with self.subTest(msg=f"Add: '{path}'"): with self.assertRaises(ValueError): CommitOperationAdd(path_in_repo=path, path_or_fileobj=b"content") with self.subTest(msg=f"Delete: '{path}'"): with self.assertRaises(ValueError): CommitOperationDelete(path_in_repo=path) def test_valid_path_in_repo_containing_git(self): for path in self.VALID_PATHS_IN_REPO: with self.subTest(msg=f"Add: '{path}'"): CommitOperationAdd(path_in_repo=path, path_or_fileobj=b"content") with self.subTest(msg=f"Delete: '{path}'"): CommitOperationDelete(path_in_repo=path) class TestWarnOnOverwritingOperations(unittest.TestCase): add_file_ab = CommitOperationAdd(path_in_repo="a/b.txt", path_or_fileobj=b"data") add_file_abc = CommitOperationAdd(path_in_repo="a/b/c.md", path_or_fileobj=b"data") add_file_abd = CommitOperationAdd(path_in_repo="a/b/d.md", path_or_fileobj=b"data") update_file_abc = CommitOperationAdd(path_in_repo="a/b/c.md", path_or_fileobj=b"updated data") delete_file_abc = CommitOperationDelete(path_in_repo="a/b/c.md") delete_folder_a = CommitOperationDelete(path_in_repo="a/") delete_folder_e = CommitOperationDelete(path_in_repo="e/") def test_no_overwrite(self) -> None: _warn_on_overwriting_operations( [ self.add_file_ab, self.add_file_abc, self.add_file_abd, self.delete_folder_e, ] ) def test_add_then_update_file(self) -> None: with self.assertWarns(UserWarning): _warn_on_overwriting_operations([self.add_file_abc, self.update_file_abc]) def test_add_then_delete_file(self) -> None: with self.assertWarns(UserWarning): _warn_on_overwriting_operations([self.add_file_abc, self.delete_file_abc]) def test_add_then_delete_folder(self) -> None: with self.assertWarns(UserWarning): _warn_on_overwriting_operations([self.add_file_abc, self.delete_folder_a]) with self.assertWarns(UserWarning): _warn_on_overwriting_operations([self.add_file_ab, self.delete_folder_a]) def test_delete_file_then_add(self) -> None: _warn_on_overwriting_operations([self.delete_file_abc, self.add_file_abc]) def test_delete_folder_then_add(self) -> None: _warn_on_overwriting_operations([self.delete_folder_a, self.add_file_ab, self.add_file_abc]) huggingface_hub-0.31.1/tests/test_commit_scheduler.py000066400000000000000000000244511500667546600230270ustar00rootroot00000000000000import time import unittest from io import SEEK_END from pathlib import Path from unittest.mock import MagicMock, patch import pytest from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download from huggingface_hub._commit_scheduler import CommitScheduler, PartialFileIO from .testing_constants import ENDPOINT_STAGING, TOKEN from .testing_utils import repo_name @pytest.mark.usefixtures("fx_cache_dir") class TestCommitScheduler(unittest.TestCase): cache_dir: Path def setUp(self) -> None: self.api = HfApi(token=TOKEN, endpoint=ENDPOINT_STAGING) self.repo_name = repo_name() def tearDown(self) -> None: try: # try stopping scheduler (if 
exists) self.scheduler.stop() except AttributeError: pass try: # try delete temporary repo self.api.delete_repo(self.repo_name) except Exception: pass @patch("huggingface_hub._commit_scheduler.CommitScheduler.push_to_hub") def test_mocked_push_to_hub(self, push_to_hub_mock: MagicMock) -> None: self.scheduler = CommitScheduler( folder_path=self.cache_dir, repo_id=self.repo_name, every=1 / 60 / 10, # every 0.1s hf_api=self.api, ) time.sleep(0.5) # Triggered at least twice (at 0.0s and then 0.1s, 0.2s,...) self.assertGreater(len(push_to_hub_mock.call_args_list), 2) # Can get the last upload result self.assertEqual(self.scheduler.last_future.result(), push_to_hub_mock.return_value) def test_invalid_folder_path_is_a_file(self) -> None: """Test cannot scheduler upload of a single file.""" file_path = self.cache_dir / "file.txt" file_path.write_text("something") with self.assertRaises(ValueError): CommitScheduler(folder_path=file_path, repo_id=self.repo_name, hf_api=self.api) def test_missing_folder_is_created(self) -> None: folder_path = self.cache_dir / "folder" / "subfolder" self.scheduler = CommitScheduler(folder_path=folder_path, repo_id=self.repo_name, hf_api=self.api) self.assertTrue(folder_path.is_dir()) def test_sync_local_folder(self) -> None: """Test sync local folder to remote repo.""" watched_folder = self.cache_dir / "watched_folder" hub_cache = self.cache_dir / "hub" # to download hub files file_path = watched_folder / "file.txt" lfs_path = watched_folder / "lfs.bin" self.scheduler = CommitScheduler( folder_path=watched_folder, repo_id=self.repo_name, every=1 / 60, # every 1s hf_api=self.api, ) # 1 push to hub triggered (empty commit not pushed) time.sleep(0.5) # write content to files with file_path.open("a") as f: f.write("first line\n") with lfs_path.open("a") as f: f.write("binary content") # 2 push to hub triggered (1 commit + 1 ignored) time.sleep(2) self.scheduler.last_future.result() # new content in file with file_path.open("a") as f: f.write("second line\n") # 1 push to hub triggered (1 commit) time.sleep(1) self.scheduler.last_future.result() with lfs_path.open("a") as f: f.write(" updated") # 5 push to hub triggered (1 commit) time.sleep(5) # wait for every threads/uploads to complete self.scheduler.stop() self.scheduler.last_future.result() # 4 commits expected (initial commit + 3 push to hub) repo_id = self.scheduler.repo_id commits = self.api.list_repo_commits(repo_id) self.assertEqual(len(commits), 4) push_1 = commits[2].commit_id # sorted by last first push_2 = commits[1].commit_id push_3 = commits[0].commit_id def _download(filename: str, revision: str) -> Path: return Path(hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=hub_cache, revision=revision)) # Check file.txt consistency file_push1 = _download(filename="file.txt", revision=push_1) file_push2 = _download(filename="file.txt", revision=push_2) file_push3 = _download(filename="file.txt", revision=push_3) self.assertEqual(file_push1.read_text(), "first line\n") self.assertEqual(file_push2.read_text(), "first line\nsecond line\n") self.assertEqual(file_push3.read_text(), "first line\nsecond line\n") # Check lfs.bin consistency lfs_push1 = _download(filename="lfs.bin", revision=push_1) lfs_push2 = _download(filename="lfs.bin", revision=push_2) lfs_push3 = _download(filename="lfs.bin", revision=push_3) self.assertEqual(lfs_push1.read_text(), "binary content") self.assertEqual(lfs_push2.read_text(), "binary content") self.assertEqual(lfs_push3.read_text(), "binary content updated") def 
test_sync_and_squash_history(self) -> None: """Test squash history when pushing to the Hub.""" watched_folder = self.cache_dir / "watched_folder" watched_folder.mkdir(exist_ok=True, parents=True) file_path = watched_folder / "file.txt" with file_path.open("a") as f: f.write("first line\n") self.scheduler = CommitScheduler( folder_path=watched_folder, repo_id=self.repo_name, every=1 / 60 / 10, # every 0.1s hf_api=self.api, squash_history=True, ) # At least 1 push to hub triggered time.sleep(0.5) self.scheduler.stop() self.scheduler.last_future.result() # Branch history has been squashed commits = self.api.list_repo_commits(repo_id=self.scheduler.repo_id) self.assertEqual(len(commits), 1) self.assertEqual(commits[0].title, "Super-squash branch 'main' using huggingface_hub") def test_context_manager(self) -> None: watched_folder = self.cache_dir / "watched_folder" watched_folder.mkdir(exist_ok=True, parents=True) file_path = watched_folder / "file.txt" with CommitScheduler( folder_path=watched_folder, repo_id=self.repo_name, every=5, # every 5min hf_api=self.api, ) as scheduler: with file_path.open("w") as f: f.write("first line\n") assert "file.txt" in self.api.list_repo_files(scheduler.repo_id) assert scheduler._CommitScheduler__stopped # means the scheduler has been stopped when exiting the context @pytest.mark.usefixtures("fx_cache_dir") class TestPartialFileIO(unittest.TestCase): """Test PartialFileIO object.""" cache_dir: Path def setUp(self) -> None: """Set up a test file.""" self.file_path = self.cache_dir / "file.txt" self.file_path.write_text("123456789") # file size: 9 bytes def test_read_partial_file_twice(self) -> None: file = PartialFileIO(self.file_path, size_limit=5) self.assertEqual(file.read(), b"12345") self.assertEqual(file.read(), b"") # End of file def test_read_partial_file_by_chunks(self) -> None: file = PartialFileIO(self.file_path, size_limit=5) self.assertEqual(file.read(2), b"12") self.assertEqual(file.read(2), b"34") self.assertEqual(file.read(2), b"5") self.assertEqual(file.read(2), b"") def test_read_partial_file_too_much(self) -> None: file = PartialFileIO(self.file_path, size_limit=5) self.assertEqual(file.read(20), b"12345") def test_partial_file_len(self) -> None: """Useful for `requests` internally.""" file = PartialFileIO(self.file_path, size_limit=5) self.assertEqual(len(file), 5) file = PartialFileIO(self.file_path, size_limit=50) self.assertEqual(len(file), 9) def test_partial_file_seek_and_tell(self) -> None: file = PartialFileIO(self.file_path, size_limit=5) self.assertEqual(file.tell(), 0) file.read(2) self.assertEqual(file.tell(), 2) file.seek(0) self.assertEqual(file.tell(), 0) file.seek(2) self.assertEqual(file.tell(), 2) file.seek(50) self.assertEqual(file.tell(), 5) file.seek(-3, SEEK_END) self.assertEqual(file.tell(), 2) # 5-3 def test_methods_not_implemented(self) -> None: """Test `PartialFileIO` only implements a subset of the `io` interface. 
This is on-purpose to avoid misuse.""" file = PartialFileIO(self.file_path, size_limit=5) with self.assertRaises(NotImplementedError): file.readline() with self.assertRaises(NotImplementedError): file.write(b"123") def test_append_to_file_then_read(self) -> None: file = PartialFileIO(self.file_path, size_limit=9) with self.file_path.open("ab") as f: f.write(b"abcdef") # Output is truncated even if new content appended to the wrapped file self.assertEqual(file.read(), b"123456789") def test_high_size_limit(self) -> None: file = PartialFileIO(self.file_path, size_limit=20) with self.file_path.open("ab") as f: f.write(b"abcdef") # File size limit is truncated to the actual file size at instance creation (not on the fly) self.assertEqual(len(file), 9) self.assertEqual(file._size_limit, 9) def test_with_commit_operation_add(self) -> None: # Truncated file op_truncated = CommitOperationAdd( path_or_fileobj=PartialFileIO(self.file_path, size_limit=5), path_in_repo="file.txt" ) self.assertEqual(op_truncated.upload_info.size, 5) self.assertEqual(op_truncated.upload_info.sample, b"12345") with op_truncated.as_file() as f: self.assertEqual(f.read(), b"12345") # Full file op_full = CommitOperationAdd( path_or_fileobj=PartialFileIO(self.file_path, size_limit=9), path_in_repo="file.txt" ) self.assertEqual(op_full.upload_info.size, 9) self.assertEqual(op_full.upload_info.sample, b"123456789") with op_full.as_file() as f: self.assertEqual(f.read(), b"123456789") # Truncated file has a different hash than the full file self.assertNotEqual(op_truncated.upload_info.sha256, op_full.upload_info.sha256) huggingface_hub-0.31.1/tests/test_dduf.py000066400000000000000000000244421500667546600204230ustar00rootroot00000000000000import json import zipfile from pathlib import Path from typing import Iterable, Tuple, Union import pytest from pytest_mock import MockerFixture from huggingface_hub.errors import DDUFCorruptedFileError, DDUFExportError, DDUFInvalidEntryNameError from huggingface_hub.serialization._dduf import ( DDUFEntry, _load_content, _validate_dduf_entry_name, _validate_dduf_structure, export_entries_as_dduf, export_folder_as_dduf, read_dduf_file, ) class TestDDUFEntry: @pytest.fixture def dummy_entry(self, tmp_path: Path) -> DDUFEntry: dummy_dduf = tmp_path / "dummy_dduf.dduf" dummy_dduf.write_bytes(b"somethingCONTENTsomething") return DDUFEntry(filename="dummy.json", length=7, offset=9, dduf_path=dummy_dduf) def test_dataclass(self, dummy_entry: DDUFEntry): assert dummy_entry.filename == "dummy.json" assert dummy_entry.length == 7 assert dummy_entry.offset == 9 assert str(dummy_entry.dduf_path).endswith("dummy_dduf.dduf") def test_read_text(self, dummy_entry: DDUFEntry): assert dummy_entry.read_text() == "CONTENT" def test_as_mmap(self, dummy_entry: DDUFEntry): with dummy_entry.as_mmap() as mmap: assert mmap == b"CONTENT" class TestUtils: @pytest.mark.parametrize("filename", ["dummy.txt", "dummy.json", "dummy.safetensors"]) def test_entry_name_valid_extension(self, filename: str): assert _validate_dduf_entry_name(filename) == filename @pytest.mark.parametrize("filename", ["dummy", "dummy.bin", "dummy.dduf", "dummy.gguf"]) def test_entry_name_invalid_extension(self, filename: str): with pytest.raises(DDUFInvalidEntryNameError): _validate_dduf_entry_name(filename) @pytest.mark.parametrize("filename", ["encoder\\dummy.json", "C:\\dummy.json"]) def test_entry_name_no_windows_path(self, filename: str): with pytest.raises(DDUFInvalidEntryNameError): _validate_dduf_entry_name(filename) def 
test_entry_name_stripped( self, ): assert _validate_dduf_entry_name("/dummy.json") == "dummy.json" def test_entry_name_no_nested_directory(self): _validate_dduf_entry_name("bar/dummy.json") # 1 level is ok with pytest.raises(DDUFInvalidEntryNameError): _validate_dduf_entry_name("foo/bar/dummy.json") # not more def test_load_content(self, tmp_path: Path): content = b"hello world" path = tmp_path / "hello.txt" path.write_bytes(content) assert _load_content(content) == content # from bytes assert _load_content(path) == content # from Path assert _load_content(str(path)) == content # from str def test_validate_dduf_structure_valid(self): _validate_dduf_structure( { # model_index.json content "_some_key": "some_value", "encoder": { "config.json": {}, "model.safetensors": {}, }, }, { # entries in DDUF archive "model_index.json", "something.txt", "encoder/config.json", "encoder/model.safetensors", }, ) def test_validate_dduf_structure_not_a_dict(self): with pytest.raises(DDUFCorruptedFileError, match="Must be a dictionary."): _validate_dduf_structure(["not a dict"], {}) # content from 'model_index.json' def test_validate_dduf_structure_missing_folder(self): with pytest.raises(DDUFCorruptedFileError, match="Missing required entry 'encoder' in 'model_index.json'."): _validate_dduf_structure({}, {"encoder/config.json", "encoder/model.safetensors"}) def test_validate_dduf_structure_missing_config_file(self): with pytest.raises(DDUFCorruptedFileError, match="Missing required file in folder 'encoder'."): _validate_dduf_structure( {"encoder": {}}, { "encoder/not_a_config.json", # expecting a config.json / tokenizer_config.json / preprocessor_config.json / scheduler_config.json "encoder/model.safetensors", }, ) class TestExportFolder: @pytest.fixture def dummy_folder(self, tmp_path: Path): folder_path = tmp_path / "dummy_folder" folder_path.mkdir() encoder_path = folder_path / "encoder" encoder_path.mkdir() subdir_path = encoder_path / "subdir" subdir_path.mkdir() (folder_path / "config.json").touch() (folder_path / "model.safetensors").touch() (folder_path / "model.bin").touch() # won't be included (encoder_path / "config.json").touch() (encoder_path / "model.safetensors").touch() (encoder_path / "model.bin").touch() # won't be included (subdir_path / "config.json").touch() # won't be included return folder_path def test_export_folder(self, dummy_folder: Path, mocker: MockerFixture): mock = mocker.patch("huggingface_hub.serialization._dduf.export_entries_as_dduf") export_folder_as_dduf("dummy.dduf", dummy_folder) mock.assert_called_once() args = mock.call_args_list[0].args assert args[0] == "dummy.dduf" assert sorted(list(args[1])) == [ # args[1] is a generator of tuples (path_in_archive, path_on_disk) ("config.json", dummy_folder / "config.json"), ("encoder/config.json", dummy_folder / "encoder/config.json"), ("encoder/model.safetensors", dummy_folder / "encoder/model.safetensors"), ("model.safetensors", dummy_folder / "model.safetensors"), ] class TestExportEntries: @pytest.fixture def dummy_entries(self, tmp_path: Path) -> Iterable[Tuple[str, Union[str, Path, bytes]]]: (tmp_path / "model_index.json").write_text(json.dumps({"foo": "bar"})) (tmp_path / "doesnt_have_to_be_same_name.safetensors").write_bytes(b"this is safetensors content") return [ ("model_index.json", str(tmp_path / "model_index.json")), # string path ("model.safetensors", tmp_path / "doesnt_have_to_be_same_name.safetensors"), # pathlib path ("hello.txt", b"hello world"), # raw bytes ] def test_export_entries( self, tmp_path: Path, 
dummy_entries: Iterable[Tuple[str, Union[str, Path, bytes]]], mocker: MockerFixture ): mock = mocker.patch("huggingface_hub.serialization._dduf._validate_dduf_structure") export_entries_as_dduf(tmp_path / "dummy.dduf", dummy_entries) mock.assert_called_once_with({"foo": "bar"}, {"model_index.json", "model.safetensors", "hello.txt"}) with zipfile.ZipFile(tmp_path / "dummy.dduf", "r") as archive: assert archive.compression == zipfile.ZIP_STORED # uncompressed! assert archive.namelist() == ["model_index.json", "model.safetensors", "hello.txt"] assert archive.read("model_index.json") == b'{"foo": "bar"}' assert archive.read("model.safetensors") == b"this is safetensors content" assert archive.read("hello.txt") == b"hello world" def test_export_entries_invalid_name(self, tmp_path: Path): with pytest.raises(DDUFExportError, match="Invalid entry name") as e: export_entries_as_dduf(tmp_path / "dummy.dduf", [("config", "model_index.json")]) assert isinstance(e.value.__cause__, DDUFInvalidEntryNameError) def test_export_entries_no_duplicate(self, tmp_path: Path): with pytest.raises(DDUFExportError, match="Can't add duplicate entry"): export_entries_as_dduf( tmp_path / "dummy.dduf", [ ("model_index.json", b'{"key": "content1"}'), ("model_index.json", b'{"key": "content2"}'), ], ) def test_export_entries_model_index_required(self, tmp_path: Path): with pytest.raises(DDUFExportError, match="Missing required 'model_index.json' entry"): export_entries_as_dduf(tmp_path / "dummy.dduf", [("model.safetensors", b"content")]) class TestReadDDUFFile: @pytest.fixture def dummy_dduf_file(self, tmp_path: Path) -> Path: with zipfile.ZipFile(tmp_path / "dummy.dduf", "w") as archive: archive.writestr("model_index.json", b'{"foo": "bar"}') archive.writestr("model.safetensors", b"this is safetensors content") archive.writestr("hello.txt", b"hello world") return tmp_path / "dummy.dduf" def test_read_dduf_file(self, dummy_dduf_file: Path, mocker: MockerFixture): mock = mocker.patch("huggingface_hub.serialization._dduf._validate_dduf_structure") entries = read_dduf_file(dummy_dduf_file) assert len(entries) == 3 index_entry = entries["model_index.json"] model_entry = entries["model.safetensors"] hello_entry = entries["hello.txt"] mock.assert_called_once_with({"foo": "bar"}, {"model_index.json", "model.safetensors", "hello.txt"}) assert index_entry.filename == "model_index.json" assert index_entry.dduf_path == dummy_dduf_file assert index_entry.read_text() == '{"foo": "bar"}' with dummy_dduf_file.open("rb") as f: f.seek(index_entry.offset) assert f.read(index_entry.length) == b'{"foo": "bar"}' assert model_entry.filename == "model.safetensors" assert model_entry.dduf_path == dummy_dduf_file assert model_entry.read_text() == "this is safetensors content" with dummy_dduf_file.open("rb") as f: f.seek(model_entry.offset) assert f.read(model_entry.length) == b"this is safetensors content" assert hello_entry.filename == "hello.txt" assert hello_entry.dduf_path == dummy_dduf_file assert hello_entry.read_text() == "hello world" with dummy_dduf_file.open("rb") as f: f.seek(hello_entry.offset) assert f.read(hello_entry.length) == b"hello world" def test_model_index_required(self, tmp_path: Path): with zipfile.ZipFile(tmp_path / "dummy.dduf", "w") as archive: archive.writestr("model.safetensors", b"this is safetensors content") with pytest.raises(DDUFCorruptedFileError, match="Missing required 'model_index.json' entry"): read_dduf_file(tmp_path / "dummy.dduf") 
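# Illustrative sketch (not a test): the DDUF round trip exercised by the classes above,
# i.e. exporting in-memory entries to an archive and lazily reading them back.
# The archive name "example.dduf" and this helper function are hypothetical.
def _example_dduf_round_trip(tmp_path: Path) -> None:
    export_entries_as_dduf(
        tmp_path / "example.dduf",
        [
            ("model_index.json", b'{"foo": "bar"}'),  # required top-level entry
            ("model.safetensors", b"raw safetensors bytes"),
        ],
    )
    entries = read_dduf_file(tmp_path / "example.dduf")  # dict of filename -> DDUFEntry
    assert entries["model_index.json"].read_text() == '{"foo": "bar"}'
    with entries["model.safetensors"].as_mmap() as data:  # memory-mapped view into the archive
        assert data == b"raw safetensors bytes"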
huggingface_hub-0.31.1/tests/test_endpoint_helpers.py000066400000000000000000000011361500667546600230360ustar00rootroot00000000000000# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
huggingface_hub-0.31.1/tests/test_fastai_integration.py000066400000000000000000000055301500667546600233500ustar00rootroot00000000000000import os
from unittest import TestCase, skip

from huggingface_hub import HfApi
from huggingface_hub.fastai_utils import (
    _save_pretrained_fastai,
    from_pretrained_fastai,
    push_to_hub_fastai,
)
from huggingface_hub.utils import (
    SoftTemporaryDirectory,
    is_fastai_available,
    is_fastcore_available,
    is_torch_available,
)

from .testing_constants import ENDPOINT_STAGING, TOKEN, USER
from .testing_utils import repo_name

WORKING_REPO_SUBDIR = f"fixtures/working_repo_{__name__.split('.')[-1]}"
WORKING_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR)

if is_fastai_available():
    from fastai.data.block import DataBlock
    from fastai.test_utils import synth_learner

if is_torch_available():
    import torch


def require_fastai_fastcore(test_case):
    """
    Decorator marking a test that requires fastai and fastcore.

    These tests are skipped when fastai and fastcore are not installed.
    """
    if not is_fastai_available():
        return skip("Test requires fastai")(test_case)
    elif not is_fastcore_available():
        return skip("Test requires fastcore")(test_case)
    else:
        return test_case


def fake_dataloaders(a=2, b=3, bs=16, n=10):
    def get_data(n):
        x = torch.randn(bs * n, 1)
        return torch.cat((x, a * x + b + 0.1 * torch.randn(bs * n, 1)), 1)

    ds = get_data(n)
    dblock = DataBlock()
    return dblock.dataloaders(ds)


if is_fastai_available():
    dummy_model = synth_learner(data=fake_dataloaders())
    dummy_config = dict(test="test_0")
else:
    dummy_model = None
    dummy_config = None


@require_fastai_fastcore
class TestFastaiUtils(TestCase):
    def test_save_pretrained_without_config(self):
        with SoftTemporaryDirectory() as tmpdir:
            _save_pretrained_fastai(dummy_model, tmpdir)
            files = os.listdir(tmpdir)
            self.assertTrue("model.pkl" in files)
            self.assertTrue("pyproject.toml" in files)
            self.assertTrue("README.md" in files)
            self.assertEqual(len(files), 3)

    def test_save_pretrained_with_config(self):
        with SoftTemporaryDirectory() as tmpdir:
            _save_pretrained_fastai(dummy_model, tmpdir, config=dummy_config)
            files = os.listdir(tmpdir)
            self.assertTrue("config.json" in files)
            self.assertEqual(len(files), 4)

    def test_push_to_hub_and_from_pretrained_fastai(self):
        api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)
        repo_id = f"{USER}/{repo_name()}"
        push_to_hub_fastai(learner=dummy_model, repo_id=repo_id, token=TOKEN, config=dummy_config)
        model_info = api.model_info(repo_id)
        assert model_info.id == repo_id

        loaded_model = from_pretrained_fastai(repo_id)
        assert dummy_model.show_training_loop() == loaded_model.show_training_loop()

        api.delete_repo(repo_id=repo_id)
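# Illustrative sketch (not a test): the user-facing flow covered by
# test_push_to_hub_and_from_pretrained_fastai above. "username/my-fastai-model" is a
# placeholder repo id and this helper is hypothetical; it assumes a cached login token.
def _example_push_and_reload(learner):
    push_to_hub_fastai(learner=learner, repo_id="username/my-fastai-model", config={"purpose": "demo"})
    reloaded = from_pretrained_fastai("username/my-fastai-model")
    return reloaded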
huggingface_hub-0.31.1/tests/test_file_download.py000066400000000000000000001506561500667546600223160ustar00rootroot00000000000000# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import io import os import shutil import stat import unittest import warnings from contextlib import contextmanager from pathlib import Path from typing import Iterable, List from unittest.mock import Mock, patch import pytest import requests from requests import Response import huggingface_hub.file_download from huggingface_hub import HfApi, RepoUrl, constants from huggingface_hub._local_folder import write_download_metadata from huggingface_hub.errors import EntryNotFoundError, GatedRepoError, LocalEntryNotFoundError from huggingface_hub.file_download import ( _CACHED_NO_EXIST, HfFileMetadata, _check_disk_space, _create_symlink, _get_pointer_path, _normalize_etag, _request_wrapper, get_hf_file_metadata, hf_hub_download, hf_hub_url, http_get, try_to_load_from_cache, ) from huggingface_hub.utils import SoftTemporaryDirectory, get_session, hf_raise_for_status, is_hf_transfer_available from .testing_constants import ENDPOINT_STAGING, OTHER_TOKEN, TOKEN from .testing_utils import ( DUMMY_EXTRA_LARGE_FILE_MODEL_ID, DUMMY_EXTRA_LARGE_FILE_NAME, DUMMY_MODEL_ID, DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT, DUMMY_RENAMED_NEW_MODEL_ID, DUMMY_RENAMED_OLD_MODEL_ID, DUMMY_TINY_FILE_NAME, SAMPLE_DATASET_IDENTIFIER, repo_name, use_tmp_repo, with_production_testing, xfail_on_windows, ) REVISION_ID_DEFAULT = "main" # Default branch name DATASET_ID = SAMPLE_DATASET_IDENTIFIER # An actual dataset hosted on huggingface.co DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT = "e25d55a1c4933f987c46cc75d8ffadd67f257c61" # One particular commit for DATASET_ID DATASET_SAMPLE_PY_FILE = "custom_squad.py" class TestDiskUsageWarning(unittest.TestCase): @classmethod def setUpClass(cls): # Test with 100MB expected file size cls.expected_size = 100 * 1024 * 1024 @patch("huggingface_hub.file_download.shutil.disk_usage") def test_disk_usage_warning(self, disk_usage_mock: Mock) -> None: # Test with only 1MB free disk space / not enough disk space, with UserWarning expected disk_usage_mock.return_value.free = 1024 * 1024 with warnings.catch_warnings(record=True) as w: # Cause all warnings to always be triggered. warnings.simplefilter("always") _check_disk_space(expected_size=self.expected_size, target_dir=disk_usage_mock) assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) # Test with 200MB free disk space / enough disk space, with no warning expected disk_usage_mock.return_value.free = 200 * 1024 * 1024 with warnings.catch_warnings(record=True) as w: # Cause all warnings to always be triggered. 
warnings.simplefilter("always") _check_disk_space(expected_size=self.expected_size, target_dir=disk_usage_mock) assert len(w) == 0 def test_disk_usage_warning_with_non_existent_path(self) -> None: # Test for not existent (absolute) path with warnings.catch_warnings(record=True) as w: # Cause all warnings to always be triggered. warnings.simplefilter("always") _check_disk_space(expected_size=self.expected_size, target_dir="path/to/not_existent_path") assert len(w) == 0 # Test for not existent (relative) path with warnings.catch_warnings(record=True) as w: # Cause all warnings to always be triggered. warnings.simplefilter("always") _check_disk_space(expected_size=self.expected_size, target_dir="/path/to/not_existent_path") assert len(w) == 0 class StagingDownloadTests(unittest.TestCase): _api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) @use_tmp_repo() def test_download_from_a_gated_repo_with_hf_hub_download(self, repo_url: RepoUrl) -> None: """Checks `hf_hub_download` outputs error on gated repo. Regression test for #1121. https://github.com/huggingface/huggingface_hub/pull/1121 Cannot test on staging as dynamically setting a gated repo doesn't work there. """ # Set repo as gated response = get_session().put( f"{self._api.endpoint}/api/models/{repo_url.repo_id}/settings", json={"gated": "auto"}, headers=self._api._build_hf_headers(), ) hf_raise_for_status(response) # Cannot download file as repo is gated with SoftTemporaryDirectory() as tmpdir: with self.assertRaisesRegex( GatedRepoError, "Access to model .* is restricted and you are not in the authorized list" ): hf_hub_download( repo_id=repo_url.repo_id, filename=".gitattributes", token=OTHER_TOKEN, cache_dir=tmpdir ) @use_tmp_repo() def test_download_regular_file_from_private_renamed_repo(self, repo_url: RepoUrl) -> None: """Regression test for #1999. See https://github.com/huggingface/huggingface_hub/pull/1999. """ repo_id_before = repo_url.repo_id repo_id_after = repo_url.repo_id + "_renamed" # Make private + rename + upload regular file self._api.update_repo_settings(repo_id_before, private=True) self._api.upload_file(repo_id=repo_id_before, path_in_repo="file.txt", path_or_fileobj=b"content") self._api.move_repo(repo_id_before, repo_id_after) # Download from private renamed repo path = self._api.hf_hub_download(repo_id_before, filename="file.txt") with open(path) as f: self.assertEqual(f.read(), "content") # Move back (so that auto-cleanup works) self._api.move_repo(repo_id_after, repo_id_before) @with_production_testing class CachedDownloadTests(unittest.TestCase): def test_file_not_found_locally_and_network_disabled(self): # Valid file but missing locally and network is disabled. 
with SoftTemporaryDirectory() as tmpdir: # Download a first time to get the refs ok filepath = hf_hub_download( DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=tmpdir, local_files_only=False, ) # Remove local file os.remove(filepath) # Get without network must fail with pytest.raises(LocalEntryNotFoundError): hf_hub_download( DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=tmpdir, local_files_only=True, ) def test_private_repo_and_file_cached_locally(self): api = HfApi(endpoint=ENDPOINT_STAGING) repo_id = api.create_repo(repo_id=repo_name(), private=True, token=TOKEN).repo_id api.upload_file(path_or_fileobj=b"content", path_in_repo="config.json", repo_id=repo_id, token=TOKEN) with SoftTemporaryDirectory() as tmpdir: # Download a first time with token => file is cached filepath_1 = api.hf_hub_download(repo_id, filename="config.json", cache_dir=tmpdir, token=TOKEN) # Download without token => return cached file filepath_2 = api.hf_hub_download(repo_id, filename="config.json", cache_dir=tmpdir, token=False) assert filepath_1 == filepath_2 def test_file_cached_and_read_only_access(self): """Should works if file is already cached and user has read-only permission. Regression test for https://github.com/huggingface/huggingface_hub/issues/1216. """ # Valid file but missing locally and network is disabled. with SoftTemporaryDirectory() as tmpdir: # Download a first time to get the refs ok hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=tmpdir) # Set read-only permission recursively _recursive_chmod(tmpdir, 0o555) # Get without write-access must succeed hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=tmpdir) # Set permission back for cleanup _recursive_chmod(tmpdir, 0o777) @xfail_on_windows(reason="umask is UNIX-specific") def test_hf_hub_download_custom_cache_permission(self): """Checks `hf_hub_download` respect the cache dir permission. Regression test for #1141 #1215. https://github.com/huggingface/huggingface_hub/issues/1141 https://github.com/huggingface/huggingface_hub/issues/1215 """ with SoftTemporaryDirectory() as tmpdir: # Equivalent to umask u=rwx,g=r,o= previous_umask = os.umask(0o037) try: filepath = hf_hub_download(DUMMY_RENAMED_OLD_MODEL_ID, "config.json", cache_dir=tmpdir) # Permissions are honored (640: u=rw,g=r,o=) self.assertEqual(stat.S_IMODE(os.stat(filepath).st_mode), 0o640) finally: os.umask(previous_umask) def test_download_from_a_renamed_repo_with_hf_hub_download(self): """Checks `hf_hub_download` works also on a renamed repo. Regression test for #981. https://github.com/huggingface/huggingface_hub/issues/981 """ with SoftTemporaryDirectory() as tmpdir: filepath = hf_hub_download(DUMMY_RENAMED_OLD_MODEL_ID, "config.json", cache_dir=tmpdir) self.assertTrue(os.path.exists(filepath)) def test_hf_hub_download_with_empty_subfolder(self): """ Check subfolder arg is processed correctly when empty string is passed to `hf_hub_download`. See https://github.com/huggingface/huggingface_hub/issues/1016. """ filepath = Path( hf_hub_download( DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, subfolder="", # Subfolder should be processed as `None` ) ) # Check file exists and is not in a subfolder in cache # e.g: "(...)/snapshots//config.json" self.assertTrue(filepath.is_file()) self.assertEqual(filepath.name, constants.CONFIG_NAME) self.assertEqual(Path(filepath).parent.parent.name, "snapshots") def test_hf_hub_download_offline_no_refs(self): """Regression test for #1305. 
If "refs/" dir did not exists on "local_files_only" (or connection broken), a non-explicit `FileNotFoundError` was raised (for the "/refs/revision" file) instead of the documented `LocalEntryNotFoundError` (for the actual searched file). See https://github.com/huggingface/huggingface_hub/issues/1305. """ with SoftTemporaryDirectory() as cache_dir: with self.assertRaises(LocalEntryNotFoundError): hf_hub_download( DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, local_files_only=True, cache_dir=cache_dir, ) def test_hf_hub_download_with_user_agent(self): """ Check that user agent is correctly sent to the HEAD call when downloading a file. Regression test for #1854. See https://github.com/huggingface/huggingface_hub/pull/1854. """ def _check_user_agent(headers: dict): assert "user-agent" in headers assert "test/1.0.0" in headers["user-agent"] assert "foo/bar" in headers["user-agent"] with SoftTemporaryDirectory() as cache_dir: with patch("huggingface_hub.file_download._request_wrapper", wraps=_request_wrapper) as mock_request: # First download hf_hub_download( DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=cache_dir, library_name="test", library_version="1.0.0", user_agent="foo/bar", ) calls = mock_request.call_args_list assert len(calls) == 3 # HEAD, HEAD, GET for call in calls: _check_user_agent(call.kwargs["headers"]) with patch("huggingface_hub.file_download._request_wrapper", wraps=_request_wrapper) as mock_request: # Second download: no GET call hf_hub_download( DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=cache_dir, library_name="test", library_version="1.0.0", user_agent="foo/bar", ) calls = mock_request.call_args_list assert len(calls) == 2 # HEAD, HEAD for call in calls: _check_user_agent(call.kwargs["headers"]) def test_hf_hub_url_with_empty_subfolder(self): """ Check subfolder arg is processed correctly when empty string is passed to `hf_hub_url`. See https://github.com/huggingface/huggingface_hub/issues/1016. 
""" url = hf_hub_url( DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, subfolder="", # Subfolder should be processed as `None` ) self.assertTrue( url.endswith( # "./resolve/main/config.json" and not "./resolve/main//config.json" f"{DUMMY_MODEL_ID}/resolve/main/config.json", ) ) @patch("huggingface_hub.file_download.constants.ENDPOINT", "https://huggingface.co") @patch( "huggingface_hub.file_download.HUGGINGFACE_CO_URL_TEMPLATE", "https://huggingface.co/{repo_id}/resolve/{revision}/{filename}", ) def test_hf_hub_url_with_endpoint(self): self.assertEqual( hf_hub_url( DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, endpoint="https://hf-ci.co", ), "https://hf-ci.co/julien-c/dummy-unknown/resolve/main/config.json", ) def test_try_to_load_from_cache_exist(self): # Make sure the file is cached filepath = hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME) new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME) self.assertEqual(filepath, new_file_path) new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision="main") self.assertEqual(filepath, new_file_path) # If file is not cached, returns None self.assertIsNone(try_to_load_from_cache(DUMMY_MODEL_ID, filename="conf.json")) # Same for uncached revisions self.assertIsNone( try_to_load_from_cache( DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision="aaa", ) ) # Same for uncached models self.assertIsNone(try_to_load_from_cache("bert-base", filename=constants.CONFIG_NAME)) def test_try_to_load_from_cache_specific_pr_revision_exists(self): # Make sure the file is cached file_path = hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision="refs/pr/1") new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision="refs/pr/1") self.assertEqual(file_path, new_file_path) # If file is not cached, returns None self.assertIsNone(try_to_load_from_cache(DUMMY_MODEL_ID, filename="conf.json", revision="refs/pr/1")) # If revision does not exist, returns None self.assertIsNone( try_to_load_from_cache(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision="does-not-exist") ) def test_try_to_load_from_cache_no_exist(self): # Make sure the file is cached with self.assertRaises(EntryNotFoundError): _ = hf_hub_download(DUMMY_MODEL_ID, filename="dummy") new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename="dummy") self.assertEqual(new_file_path, _CACHED_NO_EXIST) new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename="dummy", revision="main") self.assertEqual(new_file_path, _CACHED_NO_EXIST) # If file non-existence is not cached, returns None self.assertIsNone(try_to_load_from_cache(DUMMY_MODEL_ID, filename="dummy2")) def test_try_to_load_from_cache_specific_commit_id_exist(self): """Regression test for #1306. See https://github.com/huggingface/huggingface_hub/pull/1306.""" with SoftTemporaryDirectory() as cache_dir: # Cache file from specific commit id (no "refs/"" folder) commit_id = HfApi().model_info(DUMMY_MODEL_ID).sha filepath = hf_hub_download( DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision=commit_id, cache_dir=cache_dir, ) # Must be able to retrieve it "offline" attempt = try_to_load_from_cache( DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision=commit_id, cache_dir=cache_dir, ) self.assertEqual(filepath, attempt) def test_try_to_load_from_cache_specific_commit_id_no_exist(self): """Regression test for #1306. 
See https://github.com/huggingface/huggingface_hub/pull/1306.""" with SoftTemporaryDirectory() as cache_dir: # Cache file from specific commit id (no "refs/"" folder) commit_id = HfApi().model_info(DUMMY_MODEL_ID).sha with self.assertRaises(EntryNotFoundError): hf_hub_download( DUMMY_MODEL_ID, filename="missing_file", revision=commit_id, cache_dir=cache_dir, ) # Must be able to retrieve it "offline" attempt = try_to_load_from_cache( DUMMY_MODEL_ID, filename="missing_file", revision=commit_id, cache_dir=cache_dir, ) self.assertEqual(attempt, _CACHED_NO_EXIST) def test_get_hf_file_metadata_basic(self) -> None: """Test getting metadata from a file on the Hub.""" url = hf_hub_url( DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT, ) metadata = get_hf_file_metadata(url) # Metadata self.assertEqual(metadata.commit_hash, DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT) self.assertIsNotNone(metadata.etag) # example: "85c2fc2dcdd86563aaa85ef4911..." self.assertEqual(metadata.location, url) # no redirect self.assertEqual(metadata.size, 851) def test_get_hf_file_metadata_from_a_renamed_repo(self) -> None: """Test getting metadata from a file in a renamed repo on the Hub.""" url = hf_hub_url( DUMMY_RENAMED_OLD_MODEL_ID, filename=constants.CONFIG_NAME, subfolder="", # Subfolder should be processed as `None` ) metadata = get_hf_file_metadata(url) # Got redirected to renamed repo self.assertEqual( metadata.location, url.replace(DUMMY_RENAMED_OLD_MODEL_ID, DUMMY_RENAMED_NEW_MODEL_ID), ) def test_get_hf_file_metadata_from_a_lfs_file(self) -> None: """Test getting metadata from an LFS file. Must get size of the LFS file, not size of the pointer file """ url = hf_hub_url("gpt2", filename="tf_model.h5") metadata = get_hf_file_metadata(url) self.assertIn("xethub.hf.co", metadata.location) # Redirection self.assertEqual(metadata.size, 497933648) # Size of LFS file, not pointer def test_file_consistency_check_fails_regular_file(self): """Regression test for #1396 (regular file). Download fails if file size is different than the expected one (from headers metadata). See https://github.com/huggingface/huggingface_hub/pull/1396.""" with SoftTemporaryDirectory() as cache_dir: def _mocked_hf_file_metadata(*args, **kwargs): metadata = get_hf_file_metadata(*args, **kwargs) return HfFileMetadata( commit_hash=metadata.commit_hash, etag=metadata.etag, location=metadata.location, size=450, # will expect 450 bytes but will download 496 bytes xet_file_data=None, ) with patch("huggingface_hub.file_download.get_hf_file_metadata", _mocked_hf_file_metadata): with self.assertRaises(EnvironmentError): hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=cache_dir) def test_file_consistency_check_fails_LFS_file(self): """Regression test for #1396 (LFS file). Download fails if file size is different than the expected one (from headers metadata). 
See https://github.com/huggingface/huggingface_hub/pull/1396.""" with SoftTemporaryDirectory() as cache_dir: def _mocked_hf_file_metadata(*args, **kwargs): metadata = get_hf_file_metadata(*args, **kwargs) return HfFileMetadata( commit_hash=metadata.commit_hash, etag=metadata.etag, location=metadata.location, size=65000, # will expect 65000 bytes but will download 65074 bytes xet_file_data=None, ) with patch("huggingface_hub.file_download.get_hf_file_metadata", _mocked_hf_file_metadata): with self.assertRaises(EnvironmentError): hf_hub_download(DUMMY_MODEL_ID, filename="pytorch_model.bin", cache_dir=cache_dir) def test_hf_hub_download_when_tmp_file_is_complete(self): """Regression test for #2511. See https://github.com/huggingface/huggingface_hub/issues/2511. When downloading a file, we first download to a temporary file and then move it to the final location. If the temporary file is already partially downloaded, we resume from where we left off. However, if the temporary file is already fully downloaded, we should try to make a GET call with an empty range. This was causing a "416 Range Not Satisfiable" error. """ with SoftTemporaryDirectory() as tmpdir: # Download the file once filepath = Path(hf_hub_download(DUMMY_MODEL_ID, filename="pytorch_model.bin", cache_dir=tmpdir)) # Fake tmp file incomplete_filepath = Path(str(filepath.resolve()) + ".incomplete") incomplete_filepath.write_bytes(filepath.read_bytes()) # fake a partial download filepath.resolve().unlink() # delete snapshot folder to re-trigger a download shutil.rmtree(filepath.parents[2] / "snapshots") # Download must not fail hf_hub_download(DUMMY_MODEL_ID, filename="pytorch_model.bin", cache_dir=tmpdir) @unittest.skipIf(os.name == "nt", "Lock files are always deleted on Windows.") def test_keep_lock_file(self): """Lock files should not be deleted on Linux.""" with SoftTemporaryDirectory() as tmpdir: hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=tmpdir) lock_file_exist = False locks_dir = os.path.join(tmpdir, ".locks") for subdir, dirs, files in os.walk(locks_dir): for file in files: if file.endswith(".lock"): lock_file_exist = True break self.assertTrue(lock_file_exist, "no lock file can be found") @pytest.mark.usefixtures("fx_cache_dir") class HfHubDownloadToLocalDir(unittest.TestCase): # `cache_dir` is a temporary directory # `local_dir` is a subdirectory in which files will be downloaded # `hub_cache_dir` is a subdirectory in which files will be cached ("HF cache") cache_dir: Path file_name: str = "file.txt" lfs_name: str = "lfs.bin" @property def local_dir(self) -> Path: path = Path(self.cache_dir) / "local" path.mkdir(exist_ok=True, parents=True) return path @property def hub_cache_dir(self) -> Path: path = Path(self.cache_dir) / "cache" path.mkdir(exist_ok=True, parents=True) return path @property def file_path(self) -> Path: return self.local_dir / self.file_name @property def lfs_path(self) -> Path: return self.local_dir / self.lfs_name @classmethod def setUpClass(cls): cls.api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) cls.repo_id = cls.api.create_repo(repo_id=repo_name()).repo_id commit_1 = cls.api.upload_file(path_or_fileobj=b"content", path_in_repo=cls.file_name, repo_id=cls.repo_id) commit_2 = cls.api.upload_file(path_or_fileobj=b"content", path_in_repo=cls.lfs_name, repo_id=cls.repo_id) info = cls.api.get_paths_info(repo_id=cls.repo_id, paths=[cls.file_name, cls.lfs_name]) info = {item.path: item for item in info} cls.commit_hash_1 = commit_1.oid cls.commit_hash_2 = commit_2.oid 
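        # Note: for a regular file the "etag" is its git blob id, while for an LFS file it is the
        # sha256 of the payload; both are compared against the local-dir download metadata below.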
cls.file_etag = info[cls.file_name].blob_id cls.lfs_etag = info[cls.lfs_name].lfs.sha256 @classmethod def tearDownClass(cls) -> None: cls.api.delete_repo(repo_id=cls.repo_id) @contextmanager def with_patch_head(self): with patch("huggingface_hub.file_download._get_metadata_or_catch_error") as mock: yield mock @contextmanager def with_patch_download(self): with patch("huggingface_hub.file_download._download_to_tmp_and_move") as mock: yield mock def test_empty_local_dir(self): # Download to local dir returned_path = self.api.hf_hub_download( self.repo_id, filename=self.file_name, cache_dir=self.hub_cache_dir, local_dir=self.local_dir ) assert self.local_dir in Path(returned_path).parents # Cache directory not used (no blobs, no symlinks in it) for path in self.hub_cache_dir.glob("**/blobs/**"): assert not path.is_file() for path in self.hub_cache_dir.glob("**/snapshots/**"): assert not path.is_file() def test_metadata_ok_and_revision_is_a_commit_hash_and_match(self): # File already exists + commit_hash matches (and etag not even required) self.file_path.write_text("content") write_download_metadata(self.local_dir, self.file_name, self.commit_hash_1, etag="...") # Download to local dir => no HEAD call needed with self.with_patch_head() as mock: self.api.hf_hub_download( self.repo_id, filename=self.file_name, revision=self.commit_hash_1, local_dir=self.local_dir ) mock.assert_not_called() def test_metadata_ok_and_revision_is_a_commit_hash_and_mismatch(self): # 1 HEAD call + 1 download # File already exists + commit_hash mismatch self.file_path.write_text("content") write_download_metadata(self.local_dir, self.file_name, self.commit_hash_1, etag="...") # Mismatch => download with self.with_patch_download() as mock: self.api.hf_hub_download( self.repo_id, filename=self.file_name, revision=self.commit_hash_2, local_dir=self.local_dir ) mock.assert_called_once() def test_metadata_not_ok_and_revision_is_a_commit_hash(self): # 1 HEAD call + 1 download # File already exists but no metadata self.file_path.write_text("content") # Mismatch => download with self.with_patch_download() as mock: self.api.hf_hub_download( self.repo_id, filename=self.file_name, revision=self.commit_hash_1, local_dir=self.local_dir ) mock.assert_called_once() def test_local_files_only_and_file_exists(self): # must return without error self.file_path.write_text("content2") path = self.api.hf_hub_download( self.repo_id, filename=self.file_name, local_dir=self.local_dir, local_files_only=True ) assert Path(path) == self.file_path assert self.file_path.read_text() == "content2" # not overwritten even if wrong content def test_local_files_only_and_file_missing(self): # must raise with self.assertRaises(LocalEntryNotFoundError): self.api.hf_hub_download( self.repo_id, filename=self.file_name, local_dir=self.local_dir, local_files_only=True ) def test_metadata_ok_and_etag_match(self): # 1 HEAD call + return early self.file_path.write_text("something") write_download_metadata(self.local_dir, self.file_name, self.commit_hash_1, etag=self.file_etag) with self.with_patch_download() as mock: # Download from main => commit_hash mismatch but etag match => return early self.api.hf_hub_download(self.repo_id, filename=self.file_name, local_dir=self.local_dir) mock.assert_not_called() def test_metadata_ok_and_etag_mismatch(self): # 1 HEAD call + 1 download self.file_path.write_text("something") write_download_metadata(self.local_dir, self.file_name, self.commit_hash_1, etag="some_other_etag") with self.with_patch_download() as mock: # Download 
from main => commit_hash mismatch but etag match => return early self.api.hf_hub_download(self.repo_id, filename=self.file_name, local_dir=self.local_dir) mock.assert_called_once() def test_metadata_ok_and_etag_match_and_force_download(self): # force_download=True takes precedence on any other rule self.file_path.write_text("something") write_download_metadata(self.local_dir, self.file_name, self.commit_hash_1, etag=self.file_etag) with self.with_patch_download() as mock: self.api.hf_hub_download( self.repo_id, filename=self.file_name, local_dir=self.local_dir, force_download=True ) mock.assert_called_once() def test_metadata_not_ok_and_lfs_file_and_sha256_match(self): # 1 HEAD call + 1 hash compute + return early self.lfs_path.write_text("content") with self.with_patch_download() as mock: # Download from main # => no metadata but it's an LFS file # => compute local hash => matches => return early self.api.hf_hub_download(self.repo_id, filename=self.lfs_name, local_dir=self.local_dir) mock.assert_not_called() def test_metadata_not_ok_and_lfs_file_and_sha256_mismatch(self): # 1 HEAD call + 1 file hash + 1 download self.lfs_path.write_text("wrong_content") # Download from main # => no metadata but it's an LFS file # => compute local hash => mismatches => download path = self.api.hf_hub_download(self.repo_id, filename=self.lfs_name, local_dir=self.local_dir) # existing file overwritten assert Path(path).read_text() == "content" def test_file_exists_in_cache(self): # 1 HEAD call + return early self.api.hf_hub_download(self.repo_id, filename=self.file_name, cache_dir=self.hub_cache_dir) with self.with_patch_download() as mock: # Download to local dir # => file is already in Hub cache # => we assume it's faster to make a local copy rather than re-downloading # => duplicate file locally path = self.api.hf_hub_download( self.repo_id, filename=self.file_name, cache_dir=self.hub_cache_dir, local_dir=self.local_dir ) mock.assert_not_called() assert Path(path) == self.file_path def test_file_exists_and_overwrites(self): # 1 HEAD call + 1 download self.file_path.write_text("another content") self.api.hf_hub_download(self.repo_id, filename=self.file_name, local_dir=self.local_dir) assert self.file_path.read_text() == "content" def test_resume_from_incomplete(self): # An incomplete file already exists => use it incomplete_path = self.local_dir / ".cache" / "huggingface" / "download" / (self.file_name + ".incomplete") incomplete_path.parent.mkdir(parents=True, exist_ok=True) incomplete_path.write_text("XXXX") # Here we put fake data to test the resume self.api.hf_hub_download(self.repo_id, filename=self.file_name, local_dir=self.local_dir) self.file_path.read_text() == "XXXXent" def test_do_not_resume_on_force_download(self): # An incomplete file already exists but force_download=True incomplete_path = self.local_dir / ".cache" / "huggingface" / "download" / (self.file_name + ".incomplete") incomplete_path.parent.mkdir(parents=True, exist_ok=True) incomplete_path.write_text("XXXX") self.api.hf_hub_download(self.repo_id, filename=self.file_name, local_dir=self.local_dir, force_download=True) self.file_path.read_text() == "content" @patch("huggingface_hub.file_download.build_hf_headers") def test_passing_token_false_is_respected(self, mock: Mock): """Regression test for #2385. A bug introduced in 0.23.0 was causing the `token` parameter to be ignored when set to `False`. See https://github.com/huggingface/huggingface_hub/issues/2385. 
""" # Download to local dir mock.reset_mock(return_value={}) self.api.hf_hub_download(self.repo_id, filename=self.file_name, local_dir=self.local_dir, token=False) mock.assert_called() for call in mock.call_args_list: assert call.kwargs["token"] is False # Download to cache dir mock.reset_mock(return_value={}) self.api.hf_hub_download(self.repo_id, filename=self.file_name, cache_dir=self.local_dir, token=False) mock.assert_called() for call in mock.call_args_list: assert call.kwargs["token"] is False @pytest.mark.usefixtures("fx_cache_dir") class StagingCachedDownloadOnAwfulFilenamesTest(unittest.TestCase): """Implement regression tests for #1161. Issue was on filename not url encoded by `hf_hub_download` and `hf_hub_url`. See https://github.com/huggingface/huggingface_hub/issues/1161 """ cache_dir: Path subfolder = "subfolder/to?" filename = "awful?filename%you:should,never.give" filepath = f"subfolder/to?/{filename}" @classmethod def setUpClass(cls): cls.api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) cls.repo_url = cls.api.create_repo(repo_id=repo_name("awful_filename")) cls.expected_resolve_url = ( f"{cls.repo_url}/resolve/main/subfolder/to%3F/awful%3Ffilename%25you%3Ashould%2Cnever.give" ) cls.api.upload_file( path_or_fileobj=b"content", path_in_repo=cls.filepath, repo_id=cls.repo_url.repo_id, ) @classmethod def tearDownClass(cls) -> None: cls.api.delete_repo(repo_id=cls.repo_url.repo_id) def test_hf_hub_url_on_awful_filepath(self): self.assertEqual(hf_hub_url(self.repo_url.repo_id, self.filepath), self.expected_resolve_url) def test_hf_hub_url_on_awful_subfolder_and_filename(self): self.assertEqual( hf_hub_url(self.repo_url.repo_id, self.filename, subfolder=self.subfolder), self.expected_resolve_url, ) @xfail_on_windows(reason="Windows paths cannot contain a '?'.") def test_hf_hub_download_on_awful_filepath(self): local_path = hf_hub_download(self.repo_url.repo_id, self.filepath, cache_dir=self.cache_dir) # Local path is not url-encoded self.assertTrue(local_path.endswith(self.filepath)) @xfail_on_windows(reason="Windows paths cannot contain a '?'.") def test_hf_hub_download_on_awful_subfolder_and_filename(self): local_path = hf_hub_download( self.repo_url.repo_id, self.filename, subfolder=self.subfolder, cache_dir=self.cache_dir, ) # Local path is not url-encoded self.assertTrue(local_path.endswith(self.filepath)) @pytest.mark.usefixtures("fx_cache_dir") class TestHfHubDownloadRelativePaths(unittest.TestCase): """Regression test for HackerOne report 1928845. Issue was that any file outside of the local dir could be overwritten (Windows only). In the end, multiple protections have been added to prevent this (..\\ in filename forbidden on Windows, always check the filepath is in local_dir/snapshot_dir). 
""" cache_dir: Path @classmethod def setUpClass(cls): cls.api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) cls.repo_id = cls.api.create_repo(repo_id=repo_name()).repo_id cls.api.upload_file(path_or_fileobj=b"content", path_in_repo="folder/..\\..\\..\\file", repo_id=cls.repo_id) @classmethod def tearDownClass(cls) -> None: cls.api.delete_repo(repo_id=cls.repo_id) @xfail_on_windows(reason="Windows paths cannot contain '\\..\\'.", raises=ValueError) def test_download_folder_file_in_cache_dir(self) -> None: hf_hub_download(self.repo_id, "folder/..\\..\\..\\file", cache_dir=self.cache_dir) @xfail_on_windows(reason="Windows paths cannot contain '\\..\\'.", raises=ValueError) def test_download_folder_file_to_local_dir(self) -> None: with SoftTemporaryDirectory() as local_dir: hf_hub_download(self.repo_id, "folder/..\\..\\..\\file", cache_dir=self.cache_dir, local_dir=local_dir) def test_get_pointer_path_and_valid_relative_filename(self) -> None: # Cannot happen because of other protections, but just in case. self.assertEqual( _get_pointer_path("path/to/storage", "abcdef", "path/to/file.txt"), os.path.join("path/to/storage", "snapshots", "abcdef", "path/to/file.txt"), ) def test_get_pointer_path_but_invalid_relative_filename(self) -> None: # Cannot happen because of other protections, but just in case. relative_filename = "folder\\..\\..\\..\\file.txt" if os.name == "nt" else "folder/../../../file.txt" with self.assertRaises(ValueError): _get_pointer_path("path/to/storage", "abcdef", relative_filename) class TestHttpGet: def test_http_get_with_ssl_and_timeout_error(self, caplog): def _iter_content_1() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 raise requests.exceptions.SSLError("Fake SSLError") def _iter_content_2() -> Iterable[bytes]: yield b"0" * 10 raise requests.ReadTimeout("Fake ReadTimeout") def _iter_content_3() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 yield b"0" * 10 raise requests.ConnectionError("Fake ConnectionError") def _iter_content_4() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 yield b"0" * 10 yield b"0" * 10 with patch("huggingface_hub.file_download._request_wrapper") as mock: mock.return_value.headers = {"Content-Length": 100} mock.return_value.iter_content.side_effect = [ _iter_content_1(), _iter_content_2(), _iter_content_3(), _iter_content_4(), ] temp_file = io.BytesIO() http_get("fake_url", temp_file=temp_file) assert len([r for r in caplog.records if r.levelname == "WARNING"]) == 3 # Check final value assert temp_file.tell() == 100 assert temp_file.getvalue() == b"0" * 100 # Check number of calls + correct range headers assert len(mock.call_args_list) == 4 assert mock.call_args_list[0].kwargs["headers"] == {} assert mock.call_args_list[1].kwargs["headers"] == {"Range": "bytes=20-"} assert mock.call_args_list[2].kwargs["headers"] == {"Range": "bytes=30-"} assert mock.call_args_list[3].kwargs["headers"] == {"Range": "bytes=60-"} @pytest.mark.parametrize( "initial_range,expected_ranges", [ # Test suffix ranges (bytes=-100) ( "bytes=-100", [ "bytes=-100", "bytes=-80", "bytes=-70", "bytes=-40", ], ), # Test prefix ranges (bytes=15-) ( "bytes=15-", [ "bytes=15-", "bytes=35-", "bytes=45-", "bytes=75-", ], ), # Test double closed ranges (bytes=15-114) ( "bytes=15-114", [ "bytes=15-114", "bytes=35-114", "bytes=45-114", "bytes=75-114", ], ), ], ) def test_http_get_with_range_headers(self, caplog, initial_range: str, expected_ranges: List[str]): def _iter_content_1() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 raise 
requests.exceptions.SSLError("Fake SSLError") def _iter_content_2() -> Iterable[bytes]: yield b"0" * 10 raise requests.ReadTimeout("Fake ReadTimeout") def _iter_content_3() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 yield b"0" * 10 raise requests.ConnectionError("Fake ConnectionError") def _iter_content_4() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 yield b"0" * 10 yield b"0" * 10 with patch("huggingface_hub.file_download._request_wrapper") as mock: mock.return_value.headers = {"Content-Length": 100} mock.return_value.iter_content.side_effect = [ _iter_content_1(), _iter_content_2(), _iter_content_3(), _iter_content_4(), ] temp_file = io.BytesIO() http_get("fake_url", temp_file=temp_file, headers={"Range": initial_range}) assert len([r for r in caplog.records if r.levelname == "WARNING"]) == 3 assert temp_file.tell() == 100 assert temp_file.getvalue() == b"0" * 100 assert len(mock.call_args_list) == 4 for i, expected_range in enumerate(expected_ranges): assert mock.call_args_list[i].kwargs["headers"] == {"Range": expected_range} class CreateSymlinkTest(unittest.TestCase): @unittest.skipIf(os.name == "nt", "No symlinks on Windows") @patch("huggingface_hub.file_download.are_symlinks_supported") def test_create_symlink_concurrent_access(self, mock_are_symlinks_supported: Mock) -> None: with SoftTemporaryDirectory() as tmpdir: src = os.path.join(tmpdir, "source") other = os.path.join(tmpdir, "other") dst = os.path.join(tmpdir, "destination") # Normal case: symlink does not exist mock_are_symlinks_supported.return_value = True _create_symlink(src, dst) self.assertEqual(os.path.realpath(dst), os.path.realpath(src)) # Symlink already exists when it tries to create it (most probably from a # concurrent access) but do not raise exception def _are_symlinks_supported(cache_dir: str) -> bool: os.symlink(src, dst) return True mock_are_symlinks_supported.side_effect = _are_symlinks_supported _create_symlink(src, dst) # Symlink already exists but pointing to a different source file. This should # never happen in the context of HF cache system -> raise exception def _are_symlinks_supported(cache_dir: str) -> bool: os.symlink(other, dst) return True mock_are_symlinks_supported.side_effect = _are_symlinks_supported with self.assertRaises(FileExistsError): _create_symlink(src, dst) def test_create_symlink_relative_src(self) -> None: """Regression test for #1388. See https://github.com/huggingface/huggingface_hub/issues/1388. """ # Test dir has to be relative test_dir = Path(".") / "dir_for_create_symlink_test" test_dir.mkdir(parents=True, exist_ok=True) src = Path(test_dir) / "source" src.touch() dst = Path(test_dir) / "destination" _create_symlink(str(src), str(dst)) self.assertTrue(dst.resolve().is_file()) if os.name != "nt": self.assertEqual(dst.resolve(), src.resolve()) shutil.rmtree(test_dir) class TestNormalizeEtag(unittest.TestCase): """Unit tests implemented after a server-side change broke the ETag normalization once (see #1428). TL;DR: _normalize_etag was expecting only strong references, but the server started to return weak references after a config update. Problem was quickly fixed server-side but we prefer to make sure this doesn't happen again by supporting weak etags. For context, etags are used to build the cache-system structure. For more details, see https://github.com/huggingface/huggingface_hub/pull/1428 and related issues. 
""" def test_strong_reference(self): self.assertEqual( _normalize_etag('"a16a55fda99d2f2e7b69cce5cf93ff4ad3049930"'), "a16a55fda99d2f2e7b69cce5cf93ff4ad3049930" ) def test_weak_reference(self): self.assertEqual( _normalize_etag('W/"a16a55fda99d2f2e7b69cce5cf93ff4ad3049930"'), "a16a55fda99d2f2e7b69cce5cf93ff4ad3049930" ) @with_production_testing def test_resolve_endpoint_on_regular_file(self): url = "https://huggingface.co/gpt2/resolve/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/README.md" response = requests.head(url) self.assertEqual(self._get_etag_and_normalize(response), "a16a55fda99d2f2e7b69cce5cf93ff4ad3049930") @with_production_testing def test_resolve_endpoint_on_lfs_file(self): url = "https://huggingface.co/gpt2/resolve/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/pytorch_model.bin" response = requests.head(url) self.assertEqual( self._get_etag_and_normalize(response), "7c5d3f4b8b76583b422fcb9189ad6c89d5d97a094541ce8932dce3ecabde1421" ) @staticmethod def _get_etag_and_normalize(response: Response) -> str: response.raise_for_status() return _normalize_etag( response.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_ETAG) or response.headers.get("ETag") ) @with_production_testing class TestEtagTimeoutConfig(unittest.TestCase): @patch("huggingface_hub.file_download.constants.DEFAULT_ETAG_TIMEOUT", 10) @patch("huggingface_hub.file_download.constants.HF_HUB_ETAG_TIMEOUT", 10) def test_etag_timeout_default_value(self): with SoftTemporaryDirectory() as cache_dir: with patch.object( huggingface_hub.file_download, "get_hf_file_metadata", wraps=huggingface_hub.file_download.get_hf_file_metadata, ) as mock_etag_call: hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=cache_dir) kwargs = mock_etag_call.call_args.kwargs self.assertIn("timeout", kwargs) self.assertEqual(kwargs["timeout"], 10) @patch("huggingface_hub.file_download.constants.DEFAULT_ETAG_TIMEOUT", 10) @patch("huggingface_hub.file_download.constants.HF_HUB_ETAG_TIMEOUT", 10) def test_etag_timeout_parameter_value(self): with SoftTemporaryDirectory() as cache_dir: with patch.object( huggingface_hub.file_download, "get_hf_file_metadata", wraps=huggingface_hub.file_download.get_hf_file_metadata, ) as mock_etag_call: hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=cache_dir, etag_timeout=12) kwargs = mock_etag_call.call_args.kwargs self.assertIn("timeout", kwargs) self.assertEqual(kwargs["timeout"], 12) # passed as parameter, takes priority @patch("huggingface_hub.file_download.constants.DEFAULT_ETAG_TIMEOUT", 10) @patch("huggingface_hub.file_download.constants.HF_HUB_ETAG_TIMEOUT", 15) # takes priority def test_etag_timeout_set_as_env_variable(self): with SoftTemporaryDirectory() as cache_dir: with patch.object( huggingface_hub.file_download, "get_hf_file_metadata", wraps=huggingface_hub.file_download.get_hf_file_metadata, ) as mock_etag_call: hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=cache_dir) kwargs = mock_etag_call.call_args.kwargs self.assertIn("timeout", kwargs) self.assertEqual(kwargs["timeout"], 15) @patch("huggingface_hub.file_download.constants.DEFAULT_ETAG_TIMEOUT", 10) @patch("huggingface_hub.file_download.constants.HF_HUB_ETAG_TIMEOUT", 12) # takes priority def test_etag_timeout_set_as_env_variable_parameter_ignored(self): with SoftTemporaryDirectory() as cache_dir: with patch.object( huggingface_hub.file_download, "get_hf_file_metadata", wraps=huggingface_hub.file_download.get_hf_file_metadata, ) as mock_etag_call: hf_hub_download(DUMMY_MODEL_ID, 
filename=constants.CONFIG_NAME, cache_dir=cache_dir, etag_timeout=12) kwargs = mock_etag_call.call_args.kwargs self.assertIn("timeout", kwargs) self.assertEqual(kwargs["timeout"], 12) # passed value ignored, HF_HUB_ETAG_TIMEOUT takes priority @with_production_testing class TestExtraLargeFileDownloadPaths(unittest.TestCase): @patch("huggingface_hub.file_download.constants.HF_HUB_ENABLE_HF_TRANSFER", False) def test_large_file_http_path_error(self): with SoftTemporaryDirectory() as cache_dir: with self.assertRaises( ValueError, msg="The file is too large to be downloaded using the regular download method. Use `hf_transfer` or `xet_get` instead. Try `pip install hf_transfer` or `pip install hf_xet`.", ): hf_hub_download( DUMMY_EXTRA_LARGE_FILE_MODEL_ID, filename=DUMMY_EXTRA_LARGE_FILE_NAME, cache_dir=cache_dir, revision="main", etag_timeout=10, ) # Test "large" file download with hf_transfer. Use a tiny file to keep the tests fast and avoid # internal gateway transfer quotas. @unittest.skipIf( not is_hf_transfer_available(), "hf_transfer not installed, so skipping large file download with hf_transfer check.", ) @patch("huggingface_hub.file_download.constants.HF_HUB_ENABLE_HF_TRANSFER", True) @patch("huggingface_hub.file_download.constants.MAX_HTTP_DOWNLOAD_SIZE", 44) @patch("huggingface_hub.file_download.constants.DOWNLOAD_CHUNK_SIZE", 2) # make sure hf_download is used def test_large_file_download_with_hf_transfer(self): with SoftTemporaryDirectory() as cache_dir: path = hf_hub_download( DUMMY_EXTRA_LARGE_FILE_MODEL_ID, filename=DUMMY_TINY_FILE_NAME, cache_dir=cache_dir, revision="main", etag_timeout=10, ) with open(path, "rb") as f: content = f.read() self.assertEqual(content, b"test\n" * 9) # the file is 9 lines of "test" def _recursive_chmod(path: str, mode: int) -> None: # Taken from https://stackoverflow.com/a/2853934 for root, dirs, files in os.walk(path): for d in dirs: os.chmod(os.path.join(root, d), mode) for f in files: os.chmod(os.path.join(root, f), mode) huggingface_hub-0.31.1/tests/test_hf_api.py000066400000000000000000005775111500667546600207410ustar00rootroot00000000000000# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
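# The tests below run against the staging Hub endpoint: each HfApi test class shares an
# authenticated client and creates throwaway repos (via repo_name()) that are deleted on teardown.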
import datetime
import os
import re
import subprocess
import tempfile
import time
import types
import unittest
import uuid
from collections.abc import Iterable
from concurrent.futures import Future
from dataclasses import fields
from io import BytesIO
from pathlib import Path
from typing import List, Optional, Set, Union, get_args
from unittest.mock import Mock, patch
from urllib.parse import quote, urlparse

import pytest
import requests
from requests.exceptions import HTTPError

import huggingface_hub.lfs
from huggingface_hub import HfApi, SpaceHardware, SpaceStage, SpaceStorage, constants
from huggingface_hub._commit_api import (
    CommitOperationAdd,
    CommitOperationCopy,
    CommitOperationDelete,
    _fetch_upload_modes,
)
from huggingface_hub.community import DiscussionComment, DiscussionWithDetails
from huggingface_hub.errors import (
    BadRequestError,
    EntryNotFoundError,
    GatedRepoError,
    HfHubHTTPError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import (
    AccessRequest,
    Collection,
    CommitInfo,
    DatasetInfo,
    ExpandDatasetProperty_T,
    ExpandModelProperty_T,
    ExpandSpaceProperty_T,
    InferenceEndpoint,
    ModelInfo,
    RepoSibling,
    RepoUrl,
    SpaceInfo,
    SpaceRuntime,
    User,
    WebhookInfo,
    WebhookWatchedItem,
    repo_type_and_id_from_hf_id,
)
from huggingface_hub.repocard_data import DatasetCardData, ModelCardData
from huggingface_hub.utils import (
    NotASafetensorsRepoError,
    SafetensorsFileMetadata,
    SafetensorsParsingError,
    SafetensorsRepoMetadata,
    SoftTemporaryDirectory,
    TensorInfo,
    get_session,
    hf_raise_for_status,
    logging,
)
from huggingface_hub.utils.endpoint_helpers import _is_emission_within_threshold

from .testing_constants import (
    ENDPOINT_STAGING,
    ENTERPRISE_ORG,
    ENTERPRISE_TOKEN,
    FULL_NAME,
    OTHER_TOKEN,
    OTHER_USER,
    TOKEN,
    USER,
)
from .testing_utils import (
    DUMMY_DATASET_ID,
    DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT,
    DUMMY_MODEL_ID,
    DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT,
    ENDPOINT_PRODUCTION,
    SAMPLE_DATASET_IDENTIFIER,
    expect_deprecation,
    repo_name,
    require_git_lfs,
    rmtree_with_retry,
    use_tmp_repo,
    with_production_testing,
)


logger = logging.get_logger(__name__)

WORKING_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/working_repo")

LARGE_FILE_14MB = "https://cdn-media.huggingface.co/lfs-largefiles/progit.epub"
LARGE_FILE_18MB = "https://cdn-media.huggingface.co/lfs-largefiles/progit.pdf"

INVALID_MODELCARD = """
---
model-index: foo
---

This is a modelcard with an invalid metadata section.
"""


class HfApiCommonTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Share the valid token in all tests below."""
        cls._token = TOKEN
        cls._api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)


class HfApiRepoFileExistsTest(HfApiCommonTest):
    def setUp(self) -> None:
        super().setUp()
        self.repo_id = self._api.create_repo(repo_name(), private=True).repo_id
        self.upload = self._api.upload_file(repo_id=self.repo_id, path_in_repo="file.txt", path_or_fileobj=b"content")

    def tearDown(self) -> None:
        self._api.delete_repo(self.repo_id)
        return super().tearDown()

    def test_repo_exists(self):
        assert self._api.repo_exists(self.repo_id)
        self.assertFalse(self._api.repo_exists(self.repo_id, token=False))  # private repo
        self.assertFalse(self._api.repo_exists("repo-that-does-not-exist"))  # missing repo

    def test_revision_exists(self):
        assert self._api.revision_exists(self.repo_id, "main")
        assert not self._api.revision_exists(self.repo_id, "revision-that-does-not-exist")  # missing revision
        assert not self._api.revision_exists(self.repo_id, "main", token=False)  # private repo
        assert not self._api.revision_exists("repo-that-does-not-exist", "main")  # missing repo

    @patch("huggingface_hub.constants.ENDPOINT", "https://hub-ci.huggingface.co")
    @patch(
        "huggingface_hub.constants.HUGGINGFACE_CO_URL_TEMPLATE",
        "https://hub-ci.huggingface.co/{repo_id}/resolve/{revision}/{filename}",
    )
    def test_file_exists(self):
        assert self._api.file_exists(self.repo_id, "file.txt")
        self.assertFalse(self._api.file_exists("repo-that-does-not-exist", "file.txt"))  # missing repo
        self.assertFalse(self._api.file_exists(self.repo_id, "file-does-not-exist"))  # missing file
        self.assertFalse(
            self._api.file_exists(self.repo_id, "file.txt", revision="revision-that-does-not-exist")
        )  # missing revision
        self.assertFalse(self._api.file_exists(self.repo_id, "file.txt", token=False))  # private repo


class HfApiEndpointsTest(HfApiCommonTest):
    def test_whoami_with_passing_token(self):
        info = self._api.whoami(token=self._token)
        self.assertEqual(info["name"], USER)
        self.assertEqual(info["fullname"], FULL_NAME)
        self.assertIsInstance(info["orgs"], list)
        valid_org = [org for org in info["orgs"] if org["name"] == "valid_org"][0]
        self.assertEqual(valid_org["fullname"], "Dummy Org")

    @patch("huggingface_hub.utils._headers.get_token", return_value=TOKEN)
    def test_whoami_with_implicit_token_from_login(self, mock_get_token: Mock) -> None:
        """Test using `whoami` after a `huggingface-cli login`."""
        with patch.object(self._api, "token", None):  # no default token
            info = self._api.whoami()
        self.assertEqual(info["name"], USER)

    @patch("huggingface_hub.utils._headers.get_token")
    def test_whoami_with_implicit_token_from_hf_api(self, mock_get_token: Mock) -> None:
        """Test using `whoami` with token from the HfApi client."""
        info = self._api.whoami()
        self.assertEqual(info["name"], USER)
        mock_get_token.assert_not_called()

    def test_delete_repo_error_message(self):
        # test for #751
        # See https://github.com/huggingface/huggingface_hub/issues/751
        with self.assertRaisesRegex(
            requests.exceptions.HTTPError,
            re.compile(
                r"404 Client Error(.+)\(Request ID: .+\)(.*)Repository Not Found",
                flags=re.DOTALL,
            ),
        ):
            self._api.delete_repo("repo-that-does-not-exist")

    def test_delete_repo_missing_ok(self) -> None:
        self._api.delete_repo("repo-that-does-not-exist", missing_ok=True)

    def test_update_repo_visibility(self):
        repo_id = self._api.create_repo(repo_id=repo_name()).repo_id

        self._api.update_repo_settings(repo_id=repo_id, private=True)
        assert self._api.model_info(repo_id).private
self._api.update_repo_settings(repo_id=repo_id, private=False) assert not self._api.model_info(repo_id).private self._api.delete_repo(repo_id=repo_id) def test_move_repo_normal_usage(self): repo_id = f"{USER}/{repo_name()}" new_repo_id = f"{USER}/{repo_name()}" # Spaces not tested on staging (error 500) for repo_type in [None, constants.REPO_TYPE_MODEL, constants.REPO_TYPE_DATASET]: self._api.create_repo(repo_id=repo_id, repo_type=repo_type) self._api.move_repo(from_id=repo_id, to_id=new_repo_id, repo_type=repo_type) self._api.delete_repo(repo_id=new_repo_id, repo_type=repo_type) def test_move_repo_target_already_exists(self) -> None: repo_id_1 = f"{USER}/{repo_name()}" repo_id_2 = f"{USER}/{repo_name()}" self._api.create_repo(repo_id=repo_id_1) self._api.create_repo(repo_id=repo_id_2) with pytest.raises(HfHubHTTPError, match=r"A model repository called .* already exists"): self._api.move_repo(from_id=repo_id_1, to_id=repo_id_2, repo_type=constants.REPO_TYPE_MODEL) self._api.delete_repo(repo_id=repo_id_1) self._api.delete_repo(repo_id=repo_id_2) def test_move_repo_invalid_repo_id(self) -> None: """Test from_id and to_id must be in the form `"namespace/repo_name"`.""" with pytest.raises(ValueError, match=r"Invalid repo_id*"): self._api.move_repo(from_id="namespace/repo_name", to_id="invalid_repo_id") with pytest.raises(ValueError, match=r"Invalid repo_id*"): self._api.move_repo(from_id="invalid_repo_id", to_id="namespace/repo_name") @use_tmp_repo(repo_type="model") def test_update_repo_settings(self, repo_url: RepoUrl): repo_id = repo_url.repo_id for gated_value in ["auto", "manual", False]: for private_value in [True, False]: # Test both private and public settings self._api.update_repo_settings(repo_id=repo_id, gated=gated_value, private=private_value) info = self._api.model_info(repo_id) assert info.gated == gated_value assert info.private == private_value # Verify the private setting @use_tmp_repo(repo_type="dataset") def test_update_dataset_repo_settings(self, repo_url: RepoUrl): repo_id = repo_url.repo_id repo_type = repo_url.repo_type for gated_value in ["auto", "manual", False]: for private_value in [True, False]: self._api.update_repo_settings( repo_id=repo_id, repo_type=repo_type, gated=gated_value, private=private_value ) info = self._api.dataset_info(repo_id) assert info.gated == gated_value assert info.private == private_value @use_tmp_repo(repo_type="model") def test_update_repo_settings_xet_enabled(self, repo_url: RepoUrl): repo_id = repo_url.repo_id self._api.update_repo_settings(repo_id=repo_id, xet_enabled=True) info = self._api.model_info(repo_id, expand="xetEnabled") assert info.xet_enabled @expect_deprecation("get_token_permission") def test_get_token_permission_on_oauth_token(self): whoami = { "type": "user", "auth": {"type": "oauth", "expiresAt": "2024-10-24T19:43:43.000Z"}, # ... 
# other values are ignored as we only need to check the "auth" value } with patch.object(self._api, "whoami", return_value=whoami): assert self._api.get_token_permission() is None class CommitApiTest(HfApiCommonTest): def setUp(self) -> None: super().setUp() self.tmp_dir = tempfile.mkdtemp() self.tmp_file = os.path.join(self.tmp_dir, "temp") self.tmp_file_content = "Content of the file" with open(self.tmp_file, "w+") as f: f.write(self.tmp_file_content) os.makedirs(os.path.join(self.tmp_dir, "nested")) self.nested_tmp_file = os.path.join(self.tmp_dir, "nested", "file.bin") with open(self.nested_tmp_file, "wb+") as f: f.truncate(1024 * 1024) self.addCleanup(rmtree_with_retry, self.tmp_dir) def test_upload_file_validation(self) -> None: with self.assertRaises(ValueError, msg="Wrong repo type"): self._api.upload_file( path_or_fileobj=self.tmp_file, path_in_repo="README.md", repo_id="something", repo_type="this type does not exist", ) def test_commit_operation_validation(self): with open(self.tmp_file, "rt") as ftext: with self.assertRaises( ValueError, msg="If you passed a file-like object, make sure it is in binary mode", ): CommitOperationAdd(path_or_fileobj=ftext, path_in_repo="README.md") # type: ignore with self.assertRaises(ValueError, msg="path_or_fileobj is str but does not point to a file"): CommitOperationAdd( path_or_fileobj=os.path.join(self.tmp_dir, "nofile.pth"), path_in_repo="README.md", ) @use_tmp_repo() def test_upload_file_str_path(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id return_val = self._api.upload_file( path_or_fileobj=self.tmp_file, path_in_repo="temp/new_file.md", repo_id=repo_id, ) self.assertEqual(return_val, f"{repo_url}/blob/main/temp/new_file.md") self.assertIsInstance(return_val, CommitInfo) with SoftTemporaryDirectory() as cache_dir: with open(hf_hub_download(repo_id=repo_id, filename="temp/new_file.md", cache_dir=cache_dir)) as f: self.assertEqual(f.read(), self.tmp_file_content) @use_tmp_repo() def test_upload_file_pathlib_path(self, repo_url: RepoUrl) -> None: """Regression test for https://github.com/huggingface/huggingface_hub/issues/1246.""" self._api.upload_file(path_or_fileobj=Path(self.tmp_file), path_in_repo="file.txt", repo_id=repo_url.repo_id) self.assertIn("file.txt", self._api.list_repo_files(repo_id=repo_url.repo_id)) @use_tmp_repo() def test_upload_file_fileobj(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id with open(self.tmp_file, "rb") as filestream: return_val = self._api.upload_file( path_or_fileobj=filestream, path_in_repo="temp/new_file.md", repo_id=repo_id, ) self.assertEqual(return_val, f"{repo_url}/blob/main/temp/new_file.md") with SoftTemporaryDirectory() as cache_dir: with open(hf_hub_download(repo_id=repo_id, filename="temp/new_file.md", cache_dir=cache_dir)) as f: self.assertEqual(f.read(), self.tmp_file_content) @use_tmp_repo() def test_upload_file_bytesio(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id content = BytesIO(b"File content, but in bytes IO") return_val = self._api.upload_file( path_or_fileobj=content, path_in_repo="temp/new_file.md", repo_id=repo_id, ) self.assertEqual(return_val, f"{repo_url}/blob/main/temp/new_file.md") with SoftTemporaryDirectory() as cache_dir: with open(hf_hub_download(repo_id=repo_id, filename="temp/new_file.md", cache_dir=cache_dir)) as f: self.assertEqual(f.read(), content.getvalue().decode()) @use_tmp_repo() def test_upload_data_files_to_model_repo(self, repo_url: RepoUrl) -> None: # If a .parquet file is uploaded to a model repo, it should be 
uploaded correctly but a warning is raised. with self.assertWarns(UserWarning) as cm: self._api.upload_file( path_or_fileobj=b"content", path_in_repo="data.parquet", repo_id=repo_url.repo_id, ) assert ( cm.warnings[0].message.args[0] == "It seems that you are about to commit a data file (data.parquet) to a model repository. You are sure this is intended? If you are trying to upload a dataset, please set `repo_type='dataset'` or `--repo-type=dataset` in a CLI." ) # Same for arrow file with self.assertWarns(UserWarning) as cm: self._api.upload_file( path_or_fileobj=b"content", path_in_repo="data.arrow", repo_id=repo_url.repo_id, ) # Still correctly uploaded files = self._api.list_repo_files(repo_url.repo_id) assert "data.parquet" in files assert "data.arrow" in files def test_create_repo_return_value(self) -> None: REPO_NAME = repo_name("org") url = self._api.create_repo(repo_id=REPO_NAME) self.assertIsInstance(url, str) self.assertIsInstance(url, RepoUrl) self.assertEqual(url.repo_id, f"{USER}/{REPO_NAME}") self._api.delete_repo(repo_id=url.repo_id) def test_create_repo_already_exists_but_no_write_permission(self): # Create under other user namespace repo_id = self._api.create_repo(repo_id=repo_name(), token=OTHER_TOKEN).repo_id # Try to create with our namespace -> should not fail as the repo already exists self._api.create_repo(repo_id=repo_id, token=TOKEN, exist_ok=True) # Clean up self._api.delete_repo(repo_id=repo_id, token=OTHER_TOKEN) def test_create_repo_private_by_default(self): """Enterprise Hub allows creating private repos by default. Let's test that.""" repo_id = f"{ENTERPRISE_ORG}/{repo_name()}" self._api.create_repo(repo_id, token=ENTERPRISE_TOKEN) info = self._api.model_info(repo_id, token=ENTERPRISE_TOKEN, expand="private") assert info.private self._api.delete_repo(repo_id, token=ENTERPRISE_TOKEN) @use_tmp_repo() def test_upload_file_create_pr(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id return_val = self._api.upload_file( path_or_fileobj=self.tmp_file_content.encode(), path_in_repo="temp/new_file.md", repo_id=repo_id, create_pr=True, ) self.assertEqual(return_val, f"{repo_url}/blob/{quote('refs/pr/1', safe='')}/temp/new_file.md") self.assertIsInstance(return_val, CommitInfo) with SoftTemporaryDirectory() as cache_dir: with open( hf_hub_download( repo_id=repo_id, filename="temp/new_file.md", revision="refs/pr/1", cache_dir=cache_dir ) ) as f: self.assertEqual(f.read(), self.tmp_file_content) @use_tmp_repo() def test_delete_file(self, repo_url: RepoUrl) -> None: self._api.upload_file( path_or_fileobj=self.tmp_file, path_in_repo="temp/new_file.md", repo_id=repo_url.repo_id, ) return_val = self._api.delete_file(path_in_repo="temp/new_file.md", repo_id=repo_url.repo_id) self.assertIsInstance(return_val, CommitInfo) with self.assertRaises(EntryNotFoundError): # Should raise a 404 hf_hub_download(repo_url.repo_id, "temp/new_file.md") def test_get_full_repo_name(self): repo_name_with_no_org = self._api.get_full_repo_name("model") self.assertEqual(repo_name_with_no_org, f"{USER}/model") repo_name_with_no_org = self._api.get_full_repo_name("model", organization="org") self.assertEqual(repo_name_with_no_org, "org/model") @use_tmp_repo() def test_upload_folder(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id # Upload folder url = self._api.upload_folder(folder_path=self.tmp_dir, path_in_repo="temp/dir", repo_id=repo_id) self.assertEqual( url, f"{self._api.endpoint}/{repo_id}/tree/main/temp/dir", ) self.assertIsInstance(url, CommitInfo) # Check files are 
uploaded for rpath in ["temp", "nested/file.bin"]: local_path = os.path.join(self.tmp_dir, rpath) remote_path = f"temp/dir/{rpath}" filepath = hf_hub_download( repo_id=repo_id, filename=remote_path, revision="main", use_auth_token=self._token ) assert filepath is not None with open(filepath, "rb") as downloaded_file: content = downloaded_file.read() with open(local_path, "rb") as local_file: expected_content = local_file.read() self.assertEqual(content, expected_content) # Re-uploading the same folder twice should be fine return_val = self._api.upload_folder(folder_path=self.tmp_dir, path_in_repo="temp/dir", repo_id=repo_id) self.assertIsInstance(return_val, CommitInfo) @use_tmp_repo() def test_upload_folder_create_pr(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id # Upload folder as a new PR return_val = self._api.upload_folder( folder_path=self.tmp_dir, path_in_repo="temp/dir", repo_id=repo_id, create_pr=True ) self.assertEqual(return_val, f"{self._api.endpoint}/{repo_id}/tree/refs%2Fpr%2F1/temp/dir") # Check files are uploaded for rpath in ["temp", "nested/file.bin"]: local_path = os.path.join(self.tmp_dir, rpath) filepath = hf_hub_download(repo_id=repo_id, filename=f"temp/dir/{rpath}", revision="refs/pr/1") assert Path(local_path).read_bytes() == Path(filepath).read_bytes() def test_upload_folder_default_path_in_repo(self): REPO_NAME = repo_name("upload_folder_to_root") self._api.create_repo(repo_id=REPO_NAME, exist_ok=False) url = self._api.upload_folder(folder_path=self.tmp_dir, repo_id=f"{USER}/{REPO_NAME}") # URL to root of repository self.assertEqual(url, f"{self._api.endpoint}/{USER}/{REPO_NAME}/tree/main/") @use_tmp_repo() def test_upload_folder_git_folder_excluded(self, repo_url: RepoUrl) -> None: # Simulate a folder with a .git folder def _create_file(*parts) -> None: path = Path(self.tmp_dir, *parts) path.parent.mkdir(parents=True, exist_ok=True) path.write_text("content") _create_file(".git", "file.txt") _create_file(".cache", "huggingface", "file.txt") _create_file(".git", "folder", "file.txt") _create_file("folder", ".git", "file.txt") _create_file("folder", ".cache", "huggingface", "file.txt") _create_file("folder", ".git", "folder", "file.txt") _create_file(".git_something", "file.txt") _create_file("file.git") # Upload folder and check that .git folder is excluded self._api.upload_folder(folder_path=self.tmp_dir, repo_id=repo_url.repo_id) self.assertEqual( set(self._api.list_repo_files(repo_id=repo_url.repo_id)), {".gitattributes", ".git_something/file.txt", "file.git", "temp", "nested/file.bin"}, ) @use_tmp_repo() def test_upload_folder_gitignore_already_exists(self, repo_url: RepoUrl) -> None: # Ignore nested folder self._api.upload_file(path_or_fileobj=b"nested/*\n", path_in_repo=".gitignore", repo_id=repo_url.repo_id) # Upload folder self._api.upload_folder(folder_path=self.tmp_dir, repo_id=repo_url.repo_id) # Check nested file not uploaded assert not self._api.file_exists(repo_url.repo_id, "nested/file.bin") @use_tmp_repo() def test_upload_folder_gitignore_in_commit(self, repo_url: RepoUrl) -> None: # Create .gitignore file locally (Path(self.tmp_dir) / ".gitignore").write_text("nested/*\n") # Upload folder self._api.upload_folder(folder_path=self.tmp_dir, repo_id=repo_url.repo_id) # Check nested file not uploaded assert not self._api.file_exists(repo_url.repo_id, "nested/file.bin") @use_tmp_repo() def test_create_commit_create_pr(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id # Upload a first file 
self._api.upload_file(path_or_fileobj=self.tmp_file, path_in_repo="temp/new_file.md", repo_id=repo_id) # Create a commit with a PR operations = [ CommitOperationDelete(path_in_repo="temp/new_file.md"), CommitOperationAdd(path_in_repo="buffer", path_or_fileobj=b"Buffer data"), ] resp = self._api.create_commit( operations=operations, commit_message="Test create_commit", repo_id=repo_id, create_pr=True ) # Check commit info self.assertIsInstance(resp, CommitInfo) commit_id = resp.oid self.assertIn("pr_revision='refs/pr/1'", repr(resp)) self.assertIsInstance(commit_id, str) self.assertGreater(len(commit_id), 0) self.assertEqual(resp.commit_url, f"{self._api.endpoint}/{repo_id}/commit/{commit_id}") self.assertEqual(resp.commit_message, "Test create_commit") self.assertEqual(resp.commit_description, "") self.assertEqual(resp.pr_url, f"{self._api.endpoint}/{repo_id}/discussions/1") self.assertEqual(resp.pr_num, 1) self.assertEqual(resp.pr_revision, "refs/pr/1") # File doesn't exist on main... with self.assertRaises(HTTPError) as ctx: # Should raise a 404 self._api.hf_hub_download(repo_id, "buffer") self.assertEqual(ctx.exception.response.status_code, 404) # ...but exists on PR filepath = self._api.hf_hub_download(filename="buffer", repo_id=repo_id, revision="refs/pr/1") with open(filepath, "rb") as downloaded_file: content = downloaded_file.read() self.assertEqual(content, b"Buffer data") def test_create_commit_create_pr_against_branch(self): repo_id = f"{USER}/{repo_name()}" # Create repo and create a non-main branch self._api.create_repo(repo_id=repo_id, exist_ok=False) self._api.create_branch(repo_id=repo_id, branch="test_branch") head = self._api.list_repo_refs(repo_id=repo_id).branches[0].target_commit # Create PR against non-main branch works resp = self._api.create_commit( operations=[ CommitOperationAdd(path_in_repo="file.txt", path_or_fileobj=b"content"), ], commit_message="PR against existing branch", repo_id=repo_id, revision="test_branch", create_pr=True, ) self.assertIsInstance(resp, CommitInfo) # Create PR against a oid fails with self.assertRaises(RevisionNotFoundError): self._api.create_commit( operations=[ CommitOperationAdd(path_in_repo="file.txt", path_or_fileobj=b"content"), ], commit_message="PR against a oid", repo_id=repo_id, revision=head, create_pr=True, ) # Create PR against a non-existing branch fails with self.assertRaises(RevisionNotFoundError): self._api.create_commit( operations=[ CommitOperationAdd(path_in_repo="file.txt", path_or_fileobj=b"content"), ], commit_message="PR against missing branch", repo_id=repo_id, revision="missing_branch", create_pr=True, ) # Cleanup self._api.delete_repo(repo_id=repo_id) def test_create_commit_create_pr_on_foreign_repo(self): # Create a repo with another user. The normal CI user don't have rights on it. 
# We must be able to create a PR on it foreign_api = HfApi(token=OTHER_TOKEN) foreign_repo_url = foreign_api.create_repo(repo_id=repo_name("repo-for-hfh-ci")) self._api.create_commit( operations=[ CommitOperationAdd(path_in_repo="regular.txt", path_or_fileobj=b"File content"), CommitOperationAdd(path_in_repo="lfs.pkl", path_or_fileobj=b"File content"), ], commit_message="PR on foreign repo", repo_id=foreign_repo_url.repo_id, create_pr=True, ) foreign_api.delete_repo(repo_id=foreign_repo_url.repo_id) @use_tmp_repo() def test_create_commit(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id self._api.upload_file(path_or_fileobj=self.tmp_file, path_in_repo="temp/new_file.md", repo_id=repo_id) with open(self.tmp_file, "rb") as fileobj: operations = [ CommitOperationDelete(path_in_repo="temp/new_file.md"), CommitOperationAdd(path_in_repo="buffer", path_or_fileobj=b"Buffer data"), CommitOperationAdd( path_in_repo="bytesio", path_or_fileobj=BytesIO(b"BytesIO data"), ), CommitOperationAdd(path_in_repo="fileobj", path_or_fileobj=fileobj), CommitOperationAdd( path_in_repo="nested/path", path_or_fileobj=self.tmp_file, ), ] resp = self._api.create_commit(operations=operations, commit_message="Test create_commit", repo_id=repo_id) # Check commit info self.assertIsInstance(resp, CommitInfo) self.assertIsNone(resp.pr_url) # No pr created self.assertIsNone(resp.pr_num) self.assertIsNone(resp.pr_revision) with self.assertRaises(HTTPError): # Should raise a 404 hf_hub_download(repo_id, "temp/new_file.md") for path, expected_content in [ ("buffer", b"Buffer data"), ("bytesio", b"BytesIO data"), ("fileobj", self.tmp_file_content.encode()), ("nested/path", self.tmp_file_content.encode()), ]: filepath = hf_hub_download(repo_id=repo_id, filename=path, revision="main") assert filepath is not None with open(filepath, "rb") as downloaded_file: content = downloaded_file.read() self.assertEqual(content, expected_content) @use_tmp_repo() def test_create_commit_conflict(self, repo_url: RepoUrl) -> None: # Get commit on main repo_id = repo_url.repo_id parent_commit = self._api.model_info(repo_id).sha # Upload new file self._api.upload_file(path_or_fileobj=self.tmp_file, path_in_repo="temp/new_file.md", repo_id=repo_id) # Creating a commit with a parent commit that is not the current main should fail operations = [ CommitOperationAdd(path_in_repo="buffer", path_or_fileobj=b"Buffer data"), ] with self.assertRaises(HTTPError) as exc_ctx: self._api.create_commit( operations=operations, commit_message="Test create_commit", repo_id=repo_id, parent_commit=parent_commit, ) self.assertEqual(exc_ctx.exception.response.status_code, 412) self.assertIn( # Check the server message is added to the exception "A commit has happened since. Please refresh and try again.", str(exc_ctx.exception), ) def test_create_commit_repo_does_not_exist(self) -> None: """Test error message is detailed when creating a commit on a missing repo.""" with self.assertRaises(RepositoryNotFoundError) as context: self._api.create_commit( repo_id=f"{USER}/repo_that_do_not_exist", operations=[CommitOperationAdd("config.json", b"content")], commit_message="fake_message", ) request_id = context.exception.response.headers.get("X-Request-Id") expected_message = ( f"404 Client Error. 
(Request ID: {request_id})\n\nRepository Not" " Found for url:" f" {self._api.endpoint}/api/models/{USER}/repo_that_do_not_exist/preupload/main.\nPlease" " make sure you specified the correct `repo_id` and" " `repo_type`.\nIf you are trying to access a private or gated" " repo, make sure you are authenticated." " For more details, see https://huggingface.co/docs/huggingface_hub/authentication" "\nNote: Creating a commit assumes that the repo already exists on the Huggingface Hub." " Please use `create_repo` if it's not the case." ) assert str(context.exception) == expected_message @patch("huggingface_hub.utils._headers.get_token", return_value=TOKEN) def test_create_commit_lfs_file_implicit_token(self, get_token_mock: Mock) -> None: """Test that uploading a file as LFS works with cached token. Regression test for https://github.com/huggingface/huggingface_hub/pull/1084. """ REPO_NAME = repo_name("create_commit_with_lfs") repo_id = f"{USER}/{REPO_NAME}" with patch.object(self._api, "token", None): # no default token # Create repo self._api.create_repo(repo_id=REPO_NAME, exist_ok=False) # Set repo to track png files as LFS self._api.create_commit( operations=[ CommitOperationAdd( path_in_repo=".gitattributes", path_or_fileobj=b"*.png filter=lfs diff=lfs merge=lfs -text", ), ], commit_message="Update .gitattributes", repo_id=repo_id, ) # Upload a PNG file self._api.create_commit( operations=[ CommitOperationAdd(path_in_repo="image.png", path_or_fileobj=b"image data"), ], commit_message="Test upload lfs file", repo_id=repo_id, ) # Check uploaded as LFS info = self._api.model_info(repo_id=repo_id, files_metadata=True) siblings = {file.rfilename: file for file in info.siblings} self.assertIsInstance(siblings["image.png"].lfs, dict) # LFS file # Delete repo self._api.delete_repo(repo_id=REPO_NAME) @use_tmp_repo() def test_create_commit_huge_regular_files(self, repo_url: RepoUrl) -> None: """Test committing 12 text files (>100MB in total) at once. This was not possible when using `json` format instead of `ndjson` on the `/create-commit` endpoint. See https://github.com/huggingface/huggingface_hub/pull/1117. """ operations = [ CommitOperationAdd( path_in_repo=f"file-{num}.text", path_or_fileobj=b"Hello regular " + b"a" * 1024 * 1024 * 9, ) for num in range(12) ] self._api.create_commit( operations=operations, # 12*9MB regular => too much for "old" method commit_message="Test create_commit with huge regular files", repo_id=repo_url.repo_id, ) @use_tmp_repo() def test_commit_preflight_on_lots_of_lfs_files(self, repo_url: RepoUrl): """Test committing 1300 LFS files at once. This was not possible when `_fetch_upload_modes` was not fetching metadata by chunks. We are not testing the full upload as it would require to upload 1300 files which is unnecessary for the test. Having an overall large payload (for `/create-commit` endpoint) is tested in `test_create_commit_huge_regular_files`. There is also a 25k LFS files limit on the Hub but this is not tested. See https://github.com/huggingface/huggingface_hub/pull/1117. 
""" operations = [ CommitOperationAdd( path_in_repo=f"file-{num}.bin", # considered as LFS path_or_fileobj=b"Hello LFS" + b"a" * 2048, # big enough sample ) for num in range(1300) ] # Test `_fetch_upload_modes` preflight ("are they regular or LFS files?") _fetch_upload_modes( additions=operations, repo_type="model", repo_id=repo_url.repo_id, headers=self._api._build_hf_headers(), revision="main", endpoint=ENDPOINT_STAGING, ) for operation in operations: self.assertEqual(operation._upload_mode, "lfs") self.assertFalse(operation._is_committed) self.assertFalse(operation._is_uploaded) def test_create_commit_repo_id_case_insensitive(self): """Test create commit but repo_id is lowercased. Regression test for #1371. Hub API is already case insensitive. Somehow the issue was with the `requests` streaming implementation when generating the ndjson payload "on the fly". It seems that the server was receiving only the first line which causes a confusing "400 Bad Request - Add a line with the key `lfsFile`, `file` or `deletedFile`". Passing raw bytes instead of a generator fixes the problem. See https://github.com/huggingface/huggingface_hub/issues/1371. """ REPO_NAME = repo_name("CaSe_Is_ImPoRtAnT") repo_id = self._api.create_repo(repo_id=REPO_NAME, exist_ok=False).repo_id self._api.create_commit( repo_id=repo_id.lower(), # API is case-insensitive! commit_message="Add 1 regular and 1 LFs files.", operations=[ CommitOperationAdd(path_in_repo="file.txt", path_or_fileobj=b"content"), CommitOperationAdd(path_in_repo="lfs.bin", path_or_fileobj=b"LFS content"), ], ) repo_files = self._api.list_repo_files(repo_id=repo_id) self.assertIn("file.txt", repo_files) self.assertIn("lfs.bin", repo_files) @use_tmp_repo() def test_commit_copy_file(self, repo_url: RepoUrl) -> None: """Test CommitOperationCopy. Works only when copying an LFS file. 
""" repo_id = repo_url.repo_id self._api.upload_file(path_or_fileobj=b"content", repo_id=repo_id, path_in_repo="file.txt") self._api.upload_file(path_or_fileobj=b"LFS content", repo_id=repo_id, path_in_repo="lfs.bin") self._api.create_commit( repo_id=repo_id, commit_message="Copy LFS file.", operations=[ CommitOperationCopy(src_path_in_repo="lfs.bin", path_in_repo="lfs Copy.bin"), CommitOperationCopy(src_path_in_repo="lfs.bin", path_in_repo="lfs Copy (1).bin"), ], ) self._api.create_commit( repo_id=repo_id, commit_message="Copy regular file.", operations=[CommitOperationCopy(src_path_in_repo="file.txt", path_in_repo="file Copy.txt")], ) with self.assertRaises(EntryNotFoundError): self._api.create_commit( repo_id=repo_id, commit_message="Copy a file that doesn't exist.", operations=[ CommitOperationCopy(src_path_in_repo="doesnt-exist.txt", path_in_repo="doesnt-exist Copy.txt") ], ) # Check repo files repo_files = self._api.list_repo_files(repo_id=repo_id) self.assertIn("file.txt", repo_files) self.assertIn("file Copy.txt", repo_files) self.assertIn("lfs.bin", repo_files) self.assertIn("lfs Copy.bin", repo_files) self.assertIn("lfs Copy (1).bin", repo_files) # Check same LFS file repo_file1, repo_file2 = self._api.get_paths_info(repo_id=repo_id, paths=["lfs.bin", "lfs Copy.bin"]) self.assertEqual(repo_file1.lfs["sha256"], repo_file2.lfs["sha256"]) @use_tmp_repo() def test_create_commit_mutates_operations(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id operations = [ CommitOperationAdd(path_in_repo="lfs.bin", path_or_fileobj=b"content"), CommitOperationAdd(path_in_repo="file.txt", path_or_fileobj=b"content"), ] self._api.create_commit( repo_id=repo_id, commit_message="Copy LFS file.", operations=operations, ) assert operations[0]._is_committed assert operations[0]._is_uploaded # LFS file self.assertEqual(operations[0].path_or_fileobj, b"content") # not removed by default assert operations[1]._is_committed self.assertEqual(operations[1].path_or_fileobj, b"content") @use_tmp_repo() def test_pre_upload_before_commit(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id operations = [ CommitOperationAdd(path_in_repo="lfs.bin", path_or_fileobj=b"content1"), CommitOperationAdd(path_in_repo="file.txt", path_or_fileobj=b"content"), CommitOperationAdd(path_in_repo="lfs2.bin", path_or_fileobj=b"content2"), CommitOperationAdd(path_in_repo="file2.txt", path_or_fileobj=b"content"), ] # First: preupload 1 by 1 for operation in operations: self._api.preupload_lfs_files(repo_id, [operation]) assert operations[0]._is_uploaded self.assertEqual(operations[0].path_or_fileobj, b"") # Freed memory assert operations[2]._is_uploaded self.assertEqual(operations[2].path_or_fileobj, b"") # Freed memory # create commit and capture debug logs with self.assertLogs("huggingface_hub", level="DEBUG") as debug_logs: self._api.create_commit( repo_id=repo_id, commit_message="Copy LFS file.", operations=operations, ) # No LFS files uploaded during commit assert any("No LFS files to upload." 
in log for log in debug_logs.output) @use_tmp_repo() def test_commit_modelcard_invalid_metadata(self, repo_url: RepoUrl) -> None: with patch.object(self._api, "preupload_lfs_files") as mock: with self.assertRaisesRegex(ValueError, "Invalid metadata in README.md"): self._api.create_commit( repo_id=repo_url.repo_id, operations=[ CommitOperationAdd(path_in_repo="README.md", path_or_fileobj=INVALID_MODELCARD.encode()), CommitOperationAdd(path_in_repo="file.txt", path_or_fileobj=b"content"), CommitOperationAdd(path_in_repo="lfs.bin", path_or_fileobj=b"content"), ], commit_message="Test commit", ) # Failed early => no LFS files uploaded mock.assert_not_called() @use_tmp_repo() def test_commit_modelcard_empty_metadata(self, repo_url: RepoUrl) -> None: modelcard = "This is a modelcard without metadata" with self.assertWarnsRegex(UserWarning, "Warnings while validating metadata in README.md"): commit = self._api.create_commit( repo_id=repo_url.repo_id, operations=[ CommitOperationAdd(path_in_repo="README.md", path_or_fileobj=modelcard.encode()), CommitOperationAdd(path_in_repo="file.txt", path_or_fileobj=b"content"), CommitOperationAdd(path_in_repo="lfs.bin", path_or_fileobj=b"content"), ], commit_message="Test commit", ) # Commit still happened correctly assert isinstance(commit, CommitInfo) def test_create_file_with_relative_path(self): """Creating a file with a relative path_in_repo is forbidden. Previously taken from a regression test for HackerOne report 1928845. The bug enabled attackers to create files outside of the local dir if users downloaded a file with a relative path_in_repo on Windows. This is not relevant anymore as the API now forbids such paths. """ repo_id = self._api.create_repo(repo_id=repo_name()).repo_id with self.assertRaises(HfHubHTTPError) as cm: self._api.upload_file(path_or_fileobj=b"content", path_in_repo="..\\ddd", repo_id=repo_id) assert cm.exception.response.status_code == 422 @use_tmp_repo() def test_prevent_empty_commit_if_no_op(self, repo_url: RepoUrl) -> None: with self.assertLogs("huggingface_hub", level="INFO") as logs: self._api.create_commit(repo_id=repo_url.repo_id, commit_message="Empty commit", operations=[]) assert ( logs.records[0].message == "No files have been modified since last commit. Skipping to prevent empty commit." ) assert logs.records[0].levelname == "WARNING" @use_tmp_repo() def test_prevent_empty_commit_if_no_new_addition(self, repo_url: RepoUrl) -> None: self._api.create_commit( repo_id=repo_url.repo_id, commit_message="initial commit", operations=[ CommitOperationAdd(path_or_fileobj=b"Regular file content", path_in_repo="file.txt"), CommitOperationAdd(path_or_fileobj=b"LFS content", path_in_repo="lfs.bin"), ], ) with self.assertLogs("huggingface_hub", level="INFO") as logs: self._api.create_commit( repo_id=repo_url.repo_id, commit_message="Empty commit", operations=[ CommitOperationAdd(path_or_fileobj=b"Regular file content", path_in_repo="file.txt"), CommitOperationAdd(path_or_fileobj=b"LFS content", path_in_repo="lfs.bin"), ], ) assert logs.records[0].message == "Removing 2 file(s) from commit that have not changed." assert logs.records[0].levelname == "INFO" assert ( logs.records[1].message == "No files have been modified since last commit. Skipping to prevent empty commit." 
) assert logs.records[1].levelname == "WARNING" @use_tmp_repo() def test_prevent_empty_commit_if_no_new_copy(self, repo_url: RepoUrl) -> None: # Add 2 regular identical files and 2 LFS identical files self._api.create_commit( repo_id=repo_url.repo_id, commit_message="initial commit", operations=[ CommitOperationAdd(path_or_fileobj=b"Regular file content", path_in_repo="file.txt"), CommitOperationAdd(path_or_fileobj=b"Regular file content", path_in_repo="file_copy.txt"), CommitOperationAdd(path_or_fileobj=b"LFS content", path_in_repo="lfs.bin"), CommitOperationAdd(path_or_fileobj=b"LFS content", path_in_repo="lfs_copy.bin"), ], ) with self.assertLogs("huggingface_hub", level="INFO") as logs: self._api.create_commit( repo_id=repo_url.repo_id, commit_message="Empty commit", operations=[ CommitOperationCopy(src_path_in_repo="file.txt", path_in_repo="file_copy.txt"), CommitOperationCopy(src_path_in_repo="lfs.bin", path_in_repo="lfs_copy.bin"), ], ) assert logs.records[0].message == "Removing 2 file(s) from commit that have not changed." assert logs.records[0].levelname == "INFO" assert ( logs.records[1].message == "No files have been modified since last commit. Skipping to prevent empty commit." ) assert logs.records[1].levelname == "WARNING" @use_tmp_repo() def test_empty_commit_on_pr(self, repo_url: RepoUrl) -> None: """ Regression test for #2411. Revision was quoted twice, leading to a HTTP 404. See https://github.com/huggingface/huggingface_hub/issues/2411. """ pr = self._api.create_pull_request(repo_id=repo_url.repo_id, title="Test PR") with self.assertLogs("huggingface_hub", level="WARNING"): url = self._api.create_commit( repo_id=repo_url.repo_id, operations=[], commit_message="Empty commit", revision=pr.git_reference, ) commits = self._api.list_repo_commits(repo_id=repo_url.repo_id, revision=pr.git_reference) assert len(commits) == 1 # no 2nd commit assert url.oid == commits[0].commit_id @use_tmp_repo() def test_continue_commit_without_existing_files(self, repo_url: RepoUrl) -> None: self._api.create_commit( repo_id=repo_url.repo_id, commit_message="initial commit", operations=[ CommitOperationAdd(path_or_fileobj=b"content 1.0", path_in_repo="file.txt"), CommitOperationAdd(path_or_fileobj=b"content 2.0", path_in_repo="file2.txt"), CommitOperationAdd(path_or_fileobj=b"LFS content 1.0", path_in_repo="lfs.bin"), CommitOperationAdd(path_or_fileobj=b"LFS content 2.0", path_in_repo="lfs2.bin"), ], ) with self.assertLogs("huggingface_hub", level="DEBUG") as logs: self._api.create_commit( repo_id=repo_url.repo_id, commit_message="second commit", operations=[ # Did not change => will be removed from commit CommitOperationAdd(path_or_fileobj=b"content 1.0", path_in_repo="file.txt"), # Change => will be kept CommitOperationAdd(path_or_fileobj=b"content 2.1", path_in_repo="file2.txt"), # New file => will be kept CommitOperationAdd(path_or_fileobj=b"content 3.0", path_in_repo="file3.txt"), # Did not change => will be removed from commit CommitOperationAdd(path_or_fileobj=b"LFS content 1.0", path_in_repo="lfs.bin"), # Change => will be kept CommitOperationAdd(path_or_fileobj=b"LFS content 2.1", path_in_repo="lfs2.bin"), # New file => will be kept CommitOperationAdd(path_or_fileobj=b"LFS content 3.0", path_in_repo="lfs3.bin"), ], ) debug_logs = [log.message for log in logs.records if log.levelname == "DEBUG"] info_logs = [log.message for log in logs.records if log.levelname == "INFO"] warning_logs = [log.message for log in logs.records if log.levelname == "WARNING"] assert "Skipping upload for 
'file.txt' as the file has not changed." in debug_logs assert "Skipping upload for 'lfs.bin' as the file has not changed." in debug_logs assert "Removing 2 file(s) from commit that have not changed." in info_logs assert len(warning_logs) == 0 # no warnings since the commit is not empty paths_info = { item.path: item.last_commit for item in self._api.get_paths_info( repo_id=repo_url.repo_id, paths=["file.txt", "file2.txt", "file3.txt", "lfs.bin", "lfs2.bin", "lfs3.bin"], expand=True, ) } # Check which files are in the last commit assert paths_info["file.txt"].title == "initial commit" assert paths_info["file2.txt"].title == "second commit" assert paths_info["file3.txt"].title == "second commit" assert paths_info["lfs.bin"].title == "initial commit" assert paths_info["lfs2.bin"].title == "second commit" assert paths_info["lfs3.bin"].title == "second commit" @use_tmp_repo() def test_continue_commit_if_copy_is_identical(self, repo_url: RepoUrl) -> None: self._api.create_commit( repo_id=repo_url.repo_id, commit_message="initial commit", operations=[ CommitOperationAdd(path_or_fileobj=b"content 1.0", path_in_repo="file.txt"), CommitOperationAdd(path_or_fileobj=b"content 1.0", path_in_repo="file_copy.txt"), CommitOperationAdd(path_or_fileobj=b"content 2.0", path_in_repo="file2.txt"), CommitOperationAdd(path_or_fileobj=b"LFS content 1.0", path_in_repo="lfs.bin"), CommitOperationAdd(path_or_fileobj=b"LFS content 1.0", path_in_repo="lfs_copy.bin"), CommitOperationAdd(path_or_fileobj=b"LFS content 2.0", path_in_repo="lfs2.bin"), ], ) with self.assertLogs("huggingface_hub", level="DEBUG") as logs: self._api.create_commit( repo_id=repo_url.repo_id, commit_message="second commit", operations=[ # Did not change => will be removed from commit CommitOperationCopy(src_path_in_repo="file.txt", path_in_repo="file_copy.txt"), # Change => will be kept CommitOperationCopy(src_path_in_repo="file2.txt", path_in_repo="file.txt"), # New file => will be kept CommitOperationCopy(src_path_in_repo="file2.txt", path_in_repo="file3.txt"), # Did not change => will be removed from commit CommitOperationCopy(src_path_in_repo="lfs.bin", path_in_repo="lfs_copy.bin"), # Change => will be kept CommitOperationCopy(src_path_in_repo="lfs2.bin", path_in_repo="lfs.bin"), # New file => will be kept CommitOperationCopy(src_path_in_repo="lfs2.bin", path_in_repo="lfs3.bin"), ], ) debug_logs = [log.message for log in logs.records if log.levelname == "DEBUG"] info_logs = [log.message for log in logs.records if log.levelname == "INFO"] warning_logs = [log.message for log in logs.records if log.levelname == "WARNING"] assert ( "Skipping copy for 'file.txt' -> 'file_copy.txt' as the content of the source file is the same as the destination file." in debug_logs ) assert ( "Skipping copy for 'lfs.bin' -> 'lfs_copy.bin' as the content of the source file is the same as the destination file." in debug_logs ) assert "Removing 2 file(s) from commit that have not changed." 
in info_logs assert len(warning_logs) == 0 # no warnings since the commit is not empty paths_info = { item.path: item.last_commit for item in self._api.get_paths_info( repo_id=repo_url.repo_id, paths=[ "file.txt", "file_copy.txt", "file3.txt", "lfs.bin", "lfs_copy.bin", "lfs3.bin", ], expand=True, ) } # Check which files are in the last commit assert paths_info["file.txt"].title == "second commit" assert paths_info["file_copy.txt"].title == "initial commit" assert paths_info["file3.txt"].title == "second commit" assert paths_info["lfs.bin"].title == "second commit" assert paths_info["lfs_copy.bin"].title == "initial commit" assert paths_info["lfs3.bin"].title == "second commit" @use_tmp_repo() def test_continue_commit_if_only_deletion(self, repo_url: RepoUrl) -> None: self._api.create_commit( repo_id=repo_url.repo_id, commit_message="initial commit", operations=[ CommitOperationAdd(path_or_fileobj=b"content 1.0", path_in_repo="file.txt"), CommitOperationAdd(path_or_fileobj=b"content 1.0", path_in_repo="file_copy.txt"), CommitOperationAdd(path_or_fileobj=b"content 2.0", path_in_repo="file2.txt"), ], ) with self.assertLogs("huggingface_hub", level="DEBUG") as logs: self._api.create_commit( repo_id=repo_url.repo_id, commit_message="second commit", operations=[ # Did not change => will be removed from commit CommitOperationAdd(path_or_fileobj=b"content 1.0", path_in_repo="file.txt"), # identical to file.txt => will be removed from commit CommitOperationCopy(src_path_in_repo="file.txt", path_in_repo="file_copy.txt"), # Delete operation => kept in any case CommitOperationDelete(path_in_repo="file2.txt"), ], ) debug_logs = [log.message for log in logs.records if log.levelname == "DEBUG"] info_logs = [log.message for log in logs.records if log.levelname == "INFO"] warning_logs = [log.message for log in logs.records if log.levelname == "WARNING"] assert "Skipping upload for 'file.txt' as the file has not changed." in debug_logs assert ( "Skipping copy for 'file.txt' -> 'file_copy.txt' as the content of the source file is the same as the destination file." in debug_logs ) assert "Removing 2 file(s) from commit that have not changed." in info_logs assert len(warning_logs) == 0 # no warnings since the commit is not empty remote_files = self._api.list_repo_files(repo_id=repo_url.repo_id) assert "file.txt" in remote_files assert "file2.txt" not in remote_files class HfApiUploadEmptyFileTest(HfApiCommonTest): @classmethod def setUpClass(cls): super().setUpClass() # Create repo for all tests as they are not dependent on each other. 
cls.repo_id = f"{USER}/{repo_name('upload_empty_file')}" cls._api.create_repo(repo_id=cls.repo_id, exist_ok=False) @classmethod def tearDownClass(cls): cls._api.delete_repo(repo_id=cls.repo_id) super().tearDownClass() def test_upload_empty_lfs_file(self) -> None: # Should have been an LFS file, but uploaded as regular (would fail otherwise) self._api.upload_file(repo_id=self.repo_id, path_in_repo="empty.pkl", path_or_fileobj=b"") info = self._api.repo_info(repo_id=self.repo_id, files_metadata=True) repo_file = {file.rfilename: file for file in info.siblings}["empty.pkl"] self.assertEqual(repo_file.size, 0) self.assertIsNone(repo_file.lfs) # As regular class HfApiDeleteFolderTest(HfApiCommonTest): def setUp(self): self.repo_id = f"{USER}/{repo_name('create_commit_delete_folder')}" self._api.create_repo(repo_id=self.repo_id, exist_ok=False) self._api.create_commit( repo_id=self.repo_id, commit_message="Init repo", operations=[ CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="1/file_1.md"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="1/file_2.md"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="2/file_3.md"), ], ) def tearDown(self): self._api.delete_repo(repo_id=self.repo_id) def test_create_commit_delete_folder_implicit(self): self._api.create_commit( operations=[CommitOperationDelete(path_in_repo="1/")], commit_message="Test delete folder implicit", repo_id=self.repo_id, ) with self.assertRaises(EntryNotFoundError): hf_hub_download(self.repo_id, "1/file_1.md", use_auth_token=self._token) with self.assertRaises(EntryNotFoundError): hf_hub_download(self.repo_id, "1/file_2.md", use_auth_token=self._token) # Still exists hf_hub_download(self.repo_id, "2/file_3.md", use_auth_token=self._token) def test_create_commit_delete_folder_explicit(self): self._api.delete_folder(path_in_repo="1", repo_id=self.repo_id) with self.assertRaises(EntryNotFoundError): hf_hub_download(self.repo_id, "1/file_1.md", use_auth_token=self._token) def test_create_commit_implicit_delete_folder_is_ok(self): self._api.create_commit( operations=[CommitOperationDelete(path_in_repo="1")], commit_message="Failing delete folder", repo_id=self.repo_id, ) class HfApiListFilesInfoTest(HfApiCommonTest): @classmethod def setUpClass(cls): super().setUpClass() cls.repo_id = cls._api.create_repo(repo_id=repo_name()).repo_id cls._api.create_commit( repo_id=cls.repo_id, commit_message="A first repo", operations=[ CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="file.md"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="lfs.bin"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="1/file_1.md"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="1/2/file_1_2.md"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="2/file_2.md"), ], ) cls._api.create_commit( repo_id=cls.repo_id, commit_message="Another commit", operations=[ CommitOperationAdd(path_or_fileobj=b"data2", path_in_repo="3/file_3.md"), ], ) @classmethod def tearDownClass(cls): cls._api.delete_repo(repo_id=cls.repo_id) class HfApiListRepoTreeTest(HfApiCommonTest): @classmethod def setUpClass(cls): super().setUpClass() cls.repo_id = cls._api.create_repo(repo_id=repo_name()).repo_id cls._api.create_commit( repo_id=cls.repo_id, commit_message="A first repo", operations=[ CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="file.md"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="lfs.bin"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="1/file_1.md"), 
CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="1/2/file_1_2.md"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="2/file_2.md"), ], ) cls._api.create_commit( repo_id=cls.repo_id, commit_message="Another commit", operations=[ CommitOperationAdd(path_or_fileobj=b"data2", path_in_repo="3/file_3.md"), ], ) @classmethod def tearDownClass(cls): cls._api.delete_repo(repo_id=cls.repo_id) def test_list_tree(self): tree = list(self._api.list_repo_tree(repo_id=self.repo_id)) self.assertEqual(len(tree), 6) self.assertEqual({tree_obj.path for tree_obj in tree}, {"file.md", "lfs.bin", "1", "2", "3", ".gitattributes"}) tree = list(self._api.list_repo_tree(repo_id=self.repo_id, path_in_repo="1")) self.assertEqual(len(tree), 2) self.assertEqual({tree_obj.path for tree_obj in tree}, {"1/file_1.md", "1/2"}) def test_list_tree_recursively(self): tree = list(self._api.list_repo_tree(repo_id=self.repo_id, recursive=True)) self.assertEqual(len(tree), 11) self.assertEqual( {tree_obj.path for tree_obj in tree}, { "file.md", "lfs.bin", "1/file_1.md", "1/2/file_1_2.md", "2/file_2.md", "3/file_3.md", "1", "2", "1/2", "3", ".gitattributes", }, ) def test_list_unknown_tree(self): with self.assertRaises(EntryNotFoundError): list(self._api.list_repo_tree(repo_id=self.repo_id, path_in_repo="unknown")) def test_list_with_empty_path(self): self.assertEqual( set(tree_obj.path for tree_obj in self._api.list_repo_tree(repo_id=self.repo_id, path_in_repo="")), set(tree_obj.path for tree_obj in self._api.list_repo_tree(repo_id=self.repo_id)), ) @with_production_testing def test_list_tree_with_expand(self): tree = list( HfApi().list_repo_tree( repo_id="prompthero/openjourney-v4", expand=True, revision="c9211c53404dd6f4cfac5f04f33535892260668e", ) ) assert len(tree) == 11 # check last_commit and security are present for a file model_ckpt = next(tree_obj for tree_obj in tree if tree_obj.path == "openjourney-v4.ckpt") assert model_ckpt.last_commit is not None assert model_ckpt.last_commit["oid"] == "bda967fdb79a50844e4a02cccae3217a8ecc86cd" assert model_ckpt.security is not None assert model_ckpt.security["safe"] assert isinstance(model_ckpt.security["av_scan"], dict) # all details in here # check last_commit is present for a folder feature_extractor = next(tree_obj for tree_obj in tree if tree_obj.path == "feature_extractor") self.assertIsNotNone(feature_extractor.last_commit) self.assertEqual(feature_extractor.last_commit["oid"], "47b62b20b20e06b9de610e840282b7e6c3d51190") @with_production_testing def test_list_files_without_expand(self): tree = list( HfApi().list_repo_tree( repo_id="prompthero/openjourney-v4", revision="c9211c53404dd6f4cfac5f04f33535892260668e", ) ) self.assertEqual(len(tree), 11) # check last_commit and security are missing for a file model_ckpt = next(tree_obj for tree_obj in tree if tree_obj.path == "openjourney-v4.ckpt") self.assertIsNone(model_ckpt.last_commit) self.assertIsNone(model_ckpt.security) # check last_commit is missing for a folder feature_extractor = next(tree_obj for tree_obj in tree if tree_obj.path == "feature_extractor") self.assertIsNone(feature_extractor.last_commit) class HfApiTagEndpointTest(HfApiCommonTest): @use_tmp_repo("model") def test_create_tag_on_main(self, repo_url: RepoUrl) -> None: """Check `create_tag` on default main branch works.""" self._api.create_tag(repo_url.repo_id, tag="v0", tag_message="This is a tag message.") # Check tag is on `main` tag_info = self._api.model_info(repo_url.repo_id, revision="v0") main_info = 
self._api.model_info(repo_url.repo_id, revision="main") self.assertEqual(tag_info.sha, main_info.sha) @use_tmp_repo("model") def test_create_tag_on_pr(self, repo_url: RepoUrl) -> None: """Check `create_tag` on a PR ref works.""" # Create a PR with a readme commit_info: CommitInfo = self._api.create_commit( repo_id=repo_url.repo_id, create_pr=True, commit_message="upload readme", operations=[CommitOperationAdd(path_or_fileobj=b"this is a file content", path_in_repo="readme.md")], ) # Tag the PR self._api.create_tag(repo_url.repo_id, tag="v0", revision=commit_info.pr_revision) # Check tag is on `refs/pr/1` tag_info = self._api.model_info(repo_url.repo_id, revision="v0") pr_info = self._api.model_info(repo_url.repo_id, revision=commit_info.pr_revision) main_info = self._api.model_info(repo_url.repo_id) self.assertEqual(tag_info.sha, pr_info.sha) self.assertNotEqual(tag_info.sha, main_info.sha) @use_tmp_repo("dataset") def test_create_tag_on_commit_oid(self, repo_url: RepoUrl) -> None: """Check `create_tag` on specific commit oid works (both long and shorthands). Test it on a `dataset` repo. """ # Create a PR with a readme commit_info_1: CommitInfo = self._api.create_commit( repo_id=repo_url.repo_id, repo_type="dataset", commit_message="upload readme", operations=[CommitOperationAdd(path_or_fileobj=b"this is a file content", path_in_repo="readme.md")], ) commit_info_2: CommitInfo = self._api.create_commit( repo_id=repo_url.repo_id, repo_type="dataset", commit_message="upload config", operations=[CommitOperationAdd(path_or_fileobj=b"{'hello': 'world'}", path_in_repo="config.json")], ) # Tag commits self._api.create_tag( repo_url.repo_id, tag="commit_1", repo_type="dataset", revision=commit_info_1.oid, # long version ) self._api.create_tag( repo_url.repo_id, tag="commit_2", repo_type="dataset", revision=commit_info_2.oid[:7], # use shorthand ! 
) # Check tags tag_1_info = self._api.dataset_info(repo_url.repo_id, revision="commit_1") tag_2_info = self._api.dataset_info(repo_url.repo_id, revision="commit_2") self.assertEqual(tag_1_info.sha, commit_info_1.oid) self.assertEqual(tag_2_info.sha, commit_info_2.oid) @use_tmp_repo("model") def test_invalid_tag_name(self, repo_url: RepoUrl) -> None: """Check `create_tag` with an invalid tag name.""" with self.assertRaises(HTTPError): self._api.create_tag(repo_url.repo_id, tag="invalid tag") @use_tmp_repo("model") def test_create_tag_on_missing_revision(self, repo_url: RepoUrl) -> None: """Check `create_tag` on a missing revision.""" with self.assertRaises(RevisionNotFoundError): self._api.create_tag(repo_url.repo_id, tag="invalid tag", revision="foobar") @use_tmp_repo("model") def test_create_tag_twice(self, repo_url: RepoUrl) -> None: """Check `create_tag` called twice on same tag should fail with HTTP 409.""" self._api.create_tag(repo_url.repo_id, tag="tag_1") with self.assertRaises(HfHubHTTPError) as err: self._api.create_tag(repo_url.repo_id, tag="tag_1") self.assertEqual(err.exception.response.status_code, 409) # exist_ok=True => doesn't fail self._api.create_tag(repo_url.repo_id, tag="tag_1", exist_ok=True) @use_tmp_repo("model") def test_create_and_delete_tag(self, repo_url: RepoUrl) -> None: """Check `delete_tag` deletes the tag.""" self._api.create_tag(repo_url.repo_id, tag="v0") self._api.model_info(repo_url.repo_id, revision="v0") self._api.delete_tag(repo_url.repo_id, tag="v0") with self.assertRaises(RevisionNotFoundError): self._api.model_info(repo_url.repo_id, revision="v0") @use_tmp_repo("model") def test_delete_tag_missing_tag(self, repo_url: RepoUrl) -> None: """Check cannot `delete_tag` if tag doesn't exist.""" with self.assertRaises(RevisionNotFoundError): self._api.delete_tag(repo_url.repo_id, tag="v0") @use_tmp_repo("model") def test_delete_tag_with_branch_name(self, repo_url: RepoUrl) -> None: """Try to `delete_tag` if tag is a branch name. Currently getting a HTTP 500. See https://github.com/huggingface/moon-landing/issues/4223. 
""" with self.assertRaises(HfHubHTTPError): self._api.delete_tag(repo_url.repo_id, tag="main") class HfApiBranchEndpointTest(HfApiCommonTest): @use_tmp_repo() def test_create_and_delete_branch(self, repo_url: RepoUrl) -> None: """Test `create_branch` from main branch.""" self._api.create_branch(repo_url.repo_id, branch="cool-branch") # Check `cool-branch` branch exists self._api.model_info(repo_url.repo_id, revision="cool-branch") # Delete it self._api.delete_branch(repo_url.repo_id, branch="cool-branch") # Check doesn't exist anymore with self.assertRaises(RevisionNotFoundError): self._api.model_info(repo_url.repo_id, revision="cool-branch") @use_tmp_repo() def test_create_branch_existing_branch_fails(self, repo_url: RepoUrl) -> None: """Test `create_branch` on existing branch.""" self._api.create_branch(repo_url.repo_id, branch="cool-branch") with self.assertRaisesRegex(HfHubHTTPError, "Reference already exists"): self._api.create_branch(repo_url.repo_id, branch="cool-branch") with self.assertRaisesRegex(HfHubHTTPError, "Reference already exists"): self._api.create_branch(repo_url.repo_id, branch="main") # exist_ok=True => doesn't fail self._api.create_branch(repo_url.repo_id, branch="cool-branch", exist_ok=True) self._api.create_branch(repo_url.repo_id, branch="main", exist_ok=True) @use_tmp_repo() def test_create_branch_existing_tag_does_not_fail(self, repo_url: RepoUrl) -> None: """Test `create_branch` on existing tag.""" self._api.create_tag(repo_url.repo_id, tag="tag") self._api.create_branch(repo_url.repo_id, branch="tag") @unittest.skip( "Test user is flagged as isHF which gives permissions to create invalid references." "Not relevant to test it anyway (i.e. it's more a server-side test)." ) @use_tmp_repo() def test_create_branch_forbidden_ref_branch_fails(self, repo_url: RepoUrl) -> None: """Test `create_branch` on forbidden ref branch.""" with self.assertRaisesRegex(BadRequestError, "Invalid reference for a branch"): self._api.create_branch(repo_url.repo_id, branch="refs/pr/5") with self.assertRaisesRegex(BadRequestError, "Invalid reference for a branch"): self._api.create_branch(repo_url.repo_id, branch="refs/something/random") @use_tmp_repo() def test_delete_branch_on_protected_branch_fails(self, repo_url: RepoUrl) -> None: """Test `delete_branch` fails on protected branch.""" with self.assertRaisesRegex(HfHubHTTPError, "Cannot delete refs/heads/main"): self._api.delete_branch(repo_url.repo_id, branch="main") @use_tmp_repo() def test_delete_branch_on_missing_branch_fails(self, repo_url: RepoUrl) -> None: """Test `delete_branch` fails on missing branch.""" with self.assertRaisesRegex(HfHubHTTPError, "Invalid rev id"): self._api.delete_branch(repo_url.repo_id, branch="cool-branch") # Using a tag instead of branch -> fails self._api.create_tag(repo_url.repo_id, tag="cool-tag") with self.assertRaisesRegex(HfHubHTTPError, "Invalid rev id"): self._api.delete_branch(repo_url.repo_id, branch="cool-tag") @use_tmp_repo() def test_create_branch_from_revision(self, repo_url: RepoUrl) -> None: """Test `create_branch` from a different revision than main HEAD.""" # Create commit and remember initial/latest commit initial_commit = self._api.model_info(repo_url.repo_id).sha commit = self._api.create_commit( repo_url.repo_id, operations=[CommitOperationAdd(path_in_repo="app.py", path_or_fileobj=b"content")], commit_message="test commit", ) latest_commit = commit.oid # Create branches self._api.create_branch(repo_url.repo_id, branch="from-head") self._api.create_branch(repo_url.repo_id, 
branch="from-initial", revision=initial_commit) self._api.create_branch(repo_url.repo_id, branch="from-branch", revision="from-initial") time.sleep(0.2) # hack: wait for server to update cache? # Checks branches start from expected commits self.assertEqual( { "main": latest_commit, "from-head": latest_commit, "from-initial": initial_commit, "from-branch": initial_commit, }, {ref.name: ref.target_commit for ref in self._api.list_repo_refs(repo_id=repo_url.repo_id).branches}, ) class HfApiDeleteFilesTest(HfApiCommonTest): def setUp(self) -> None: super().setUp() self.repo_id = self._api.create_repo(repo_id=repo_name()).repo_id self._api.create_commit( repo_id=self.repo_id, operations=[ # Regular files CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="file.txt"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="nested/file.txt"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="nested/sub/file.txt"), # LFS files CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="lfs.bin"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="nested/lfs.bin"), CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="nested/sub/lfs.bin"), ], commit_message="Init repo structure", ) def tearDown(self) -> None: self._api.delete_repo(repo_id=self.repo_id) super().tearDown() def remote_files(self) -> Set[set]: return set(self._api.list_repo_files(repo_id=self.repo_id)) def test_delete_single_file(self): self._api.delete_files(repo_id=self.repo_id, delete_patterns=["file.txt"]) assert "file.txt" not in self.remote_files() def test_delete_multiple_files(self): self._api.delete_files(repo_id=self.repo_id, delete_patterns=["file.txt", "lfs.bin"]) files = self.remote_files() assert "file.txt" not in files assert "lfs.bin" not in files def test_delete_folder_with_pattern(self): self._api.delete_files(repo_id=self.repo_id, delete_patterns=["nested/*"]) assert self.remote_files() == {".gitattributes", "file.txt", "lfs.bin"} def test_delete_folder_without_pattern(self): self._api.delete_files(repo_id=self.repo_id, delete_patterns=["nested/"]) assert self.remote_files() == {".gitattributes", "file.txt", "lfs.bin"} def test_unknown_path_do_not_raise(self): self._api.delete_files(repo_id=self.repo_id, delete_patterns=["not_existing", "nested/*"]) assert self.remote_files() == {".gitattributes", "file.txt", "lfs.bin"} def test_delete_bin_files_with_patterns(self): self._api.delete_files(repo_id=self.repo_id, delete_patterns=["*.bin"]) files = self.remote_files() assert "lfs.bin" not in files assert "nested/lfs.bin" not in files assert "nested/sub/lfs.bin" not in files def test_delete_files_in_folders_with_patterns(self): self._api.delete_files(repo_id=self.repo_id, delete_patterns=["*/file.txt"]) files = self.remote_files() assert "file.txt" in files assert "nested/file.txt" not in files assert "nested/sub/file.txt" not in files def test_delete_all_files(self): self._api.delete_files(repo_id=self.repo_id, delete_patterns=["*"]) assert self.remote_files() == {".gitattributes"} class HfApiPublicStagingTest(unittest.TestCase): def setUp(self) -> None: self._api = HfApi() def test_staging_list_datasets(self): self._api.list_datasets() def test_staging_list_models(self): self._api.list_models() class HfApiPublicProductionTest(unittest.TestCase): @with_production_testing def setUp(self) -> None: self._api = HfApi() def test_list_models(self): models = list(self._api.list_models(limit=500)) assert len(models) > 100 assert isinstance(models[0], ModelInfo) def test_list_models_author(self): models = 
list(self._api.list_models(author="google")) assert len(models) > 10 assert isinstance(models[0], ModelInfo) for model in models: assert model.id.startswith("google/") def test_list_models_search(self): models = list(self._api.list_models(search="bert")) assert len(models) > 10 assert isinstance(models[0], ModelInfo) for model in models[:10]: # Rough rule: at least first 10 will have "bert" in the name # Not optimal since it is dependent on how the Hub implements the search # (and changes it in the future) but for now it should do the trick. assert "bert" in model.id.lower() def test_list_models_complex_query(self): # Let's list the 10 most recent models # with tags "bert" and "jax", # ordered by last modified date. models = list(self._api.list_models(filter=("bert", "jax"), sort="last_modified", direction=-1, limit=10)) # we have at least 1 models assert len(models) > 1 assert len(models) <= 10 model = models[0] assert isinstance(model, ModelInfo) assert all(tag in model.tags for tag in ["bert", "jax"]) def test_list_models_sort_trending_score(self): models = list(self._api.list_models(sort="trending_score", limit=10)) assert len(models) == 10 assert isinstance(models[0], ModelInfo) assert all(model.trending_score is not None for model in models) def test_list_models_sort_created_at(self): models = list(self._api.list_models(sort="created_at", limit=10)) assert len(models) == 10 assert isinstance(models[0], ModelInfo) assert all(model.created_at is not None for model in models) def test_list_models_sort_downloads(self): models = list(self._api.list_models(sort="downloads", limit=10)) assert len(models) == 10 assert isinstance(models[0], ModelInfo) assert all(model.downloads is not None for model in models) def test_list_models_sort_likes(self): models = list(self._api.list_models(sort="likes", limit=10)) assert len(models) == 10 assert isinstance(models[0], ModelInfo) assert all(model.likes is not None for model in models) def test_list_models_with_config(self): for model in self._api.list_models(filter=("adapter-transformers", "bert"), fetch_config=True, limit=20): self.assertIsNotNone(model.config) def test_list_models_without_config(self): for model in self._api.list_models(filter=("adapter-transformers", "bert"), fetch_config=False, limit=20): self.assertIsNone(model.config) def test_list_models_expand_author(self): # Only the selected field is returned models = list(self._api.list_models(expand=["author"], limit=5)) for model in models: assert model.author is not None assert model.id is not None assert model.downloads is None assert model.created_at is None assert model.last_modified is None def test_list_models_expand_multiple(self): # Only the selected fields are returned models = list(self._api.list_models(expand=["author", "downloadsAllTime"], limit=5)) for model in models: assert model.author is not None assert model.downloads_all_time is not None assert model.downloads is None def test_list_models_expand_unexpected_value(self): # Unexpected value => HTTP 400 with self.assertRaises(HfHubHTTPError) as cm: list(self._api.list_models(expand=["foo"])) assert cm.exception.response.status_code == 400 def test_list_models_expand_cannot_be_used_with_other_params(self): # `expand` cannot be used with other params with self.assertRaises(ValueError): next(self._api.list_models(expand=["author"], full=True)) with self.assertRaises(ValueError): next(self._api.list_models(expand=["author"], fetch_config=True)) with self.assertRaises(ValueError): next(self._api.list_models(expand=["author"], 
cardData=True)) def test_list_models_gated_only(self): for model in self._api.list_models(expand=["gated"], gated=True, limit=5): assert model.gated in ("auto", "manual") def test_list_models_non_gated_only(self): for model in self._api.list_models(expand=["gated"], gated=False, limit=5): assert model.gated is False @pytest.mark.skip("Inference parameter is being revamped") def test_list_models_inference_warm(self): for model in self._api.list_models(inference=["warm"], expand="inference", limit=5): assert model.inference == "warm" @pytest.mark.skip("Inference parameter is being revamped") def test_list_models_inference_cold(self): for model in self._api.list_models(inference=["cold"], expand="inference", limit=5): assert model.inference == "cold" def test_model_info(self): model = self._api.model_info(repo_id=DUMMY_MODEL_ID) self.assertIsInstance(model, ModelInfo) self.assertNotEqual(model.sha, DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT) self.assertEqual(model.created_at, datetime.datetime(2022, 3, 2, 23, 29, 5, tzinfo=datetime.timezone.utc)) # One particular commit (not the top of `main`) model = self._api.model_info(repo_id=DUMMY_MODEL_ID, revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT) self.assertIsInstance(model, ModelInfo) self.assertEqual(model.sha, DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT) def test_model_info_with_security(self): # Note: this test might break in the future if `security_repo_status` object structure gets updated server-side # (not yet fully stable) model = self._api.model_info( repo_id=DUMMY_MODEL_ID, revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT, securityStatus=True, ) self.assertIsNotNone(model.security_repo_status) self.assertEqual(model.security_repo_status, {"scansDone": True, "filesWithIssues": []}) def test_model_info_with_file_metadata(self): model = self._api.model_info( repo_id=DUMMY_MODEL_ID, revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT, files_metadata=True, ) files = model.siblings assert files is not None self._check_siblings_metadata(files) def test_model_info_corrupted_model_index(self) -> None: """Loading model info from a model with corrupted data card should still work. Here we use a model with a "model-index" that is not an array. Git hook should prevent this from happening on the server, but models uploaded before we implemented the check might have this issue. Example data from https://huggingface.co/Waynehillsdev/Waynehills-STT-doogie-server. 
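        The payload below is instantiated directly as `ModelInfo(**data)`: the invalid
        "model-index" should be discarded with an "Invalid model-index" warning instead of
        raising, leaving `card_data.eval_results` set to None.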
""" with self.assertLogs("huggingface_hub", level="WARNING") as warning_logs: model = ModelInfo( **{ "_id": "621ffdc036468d709f1751d8", "id": "Waynehillsdev/Waynehills-STT-doogie-server", "cardData": { "license": "apache-2.0", "tags": ["generated_from_trainer"], "model-index": {"name": "Waynehills-STT-doogie-server"}, }, "gitalyUid": "53c57f29a007fc728c968127061b7b740dcf2b1ad401d907f703b27658559413", "likes": 0, "private": False, "config": {"architectures": ["Wav2Vec2ForCTC"], "model_type": "wav2vec2"}, "downloads": 1, "tags": [ "transformers", "pytorch", "tensorboard", "wav2vec2", "automatic-speech-recognition", "generated_from_trainer", "license:apache-2.0", "endpoints_compatible", "region:us", ], "pipeline_tag": "automatic-speech-recognition", "createdAt": "2022-03-02T23:29:04.000Z", "siblings": None, } ) assert model.card_data.eval_results is None assert any("Invalid model-index" in log for log in warning_logs.output) def test_model_info_with_widget_data(self): info = self._api.model_info("HuggingFaceH4/zephyr-7b-beta") assert info.widget_data is not None def test_model_info_expand_author(self): # Only the selected field is returned model = self._api.model_info(repo_id="HuggingFaceH4/zephyr-7b-beta", expand=["author"]) assert model.author == "HuggingFaceH4" assert model.downloads is None assert model.created_at is None assert model.last_modified is None def test_model_info_expand_multiple(self): # Only the selected fields are returned model = self._api.model_info(repo_id="HuggingFaceH4/zephyr-7b-beta", expand=["author", "downloadsAllTime"]) assert model.author == "HuggingFaceH4" assert model.downloads is None assert model.downloads_all_time is not None assert model.created_at is None assert model.last_modified is None def test_model_info_expand_unexpected_value(self): # Unexpected value => HTTP 400 with self.assertRaises(HfHubHTTPError) as cm: self._api.model_info("HuggingFaceH4/zephyr-7b-beta", expand=["foo"]) assert cm.exception.response.status_code == 400 def test_model_info_expand_cannot_be_used_with_other_params(self): # `expand` cannot be used with other params with self.assertRaises(ValueError): self._api.model_info("HuggingFaceH4/zephyr-7b-beta", expand=["author"], securityStatus=True) with self.assertRaises(ValueError): self._api.model_info("HuggingFaceH4/zephyr-7b-beta", expand=["author"], files_metadata=True) def test_list_repo_files(self): files = self._api.list_repo_files(repo_id=DUMMY_MODEL_ID) expected_files = [ ".gitattributes", "README.md", "config.json", "flax_model.msgpack", "merges.txt", "pytorch_model.bin", "tf_model.h5", "vocab.json", ] self.assertListEqual(files, expected_files) def test_list_datasets_no_filter(self): datasets = list(self._api.list_datasets(limit=500)) self.assertGreater(len(datasets), 100) self.assertIsInstance(datasets[0], DatasetInfo) def test_filter_datasets_by_author_and_name(self): datasets = list(self._api.list_datasets(author="huggingface", dataset_name="DataMeasurementsFiles")) assert len(datasets) > 0 assert "huggingface" in datasets[0].author assert "DataMeasurementsFiles" in datasets[0].id def test_filter_datasets_by_benchmark(self): datasets = list(self._api.list_datasets(benchmark="raft")) assert len(datasets) > 0 assert "benchmark:raft" in datasets[0].tags def test_filter_datasets_by_language_creator(self): datasets = list(self._api.list_datasets(language_creators="crowdsourced")) assert len(datasets) > 0 assert "language_creators:crowdsourced" in datasets[0].tags def test_filter_datasets_by_language_only(self): datasets = 
list(self._api.list_datasets(language="en", limit=10)) assert len(datasets) > 0 assert "language:en" in datasets[0].tags datasets = list(self._api.list_datasets(language=("en", "fr"), limit=10)) assert len(datasets) > 0 assert "language:en" in datasets[0].tags assert "language:fr" in datasets[0].tags def test_filter_datasets_by_multilinguality(self): datasets = list(self._api.list_datasets(multilinguality="multilingual", limit=10)) assert len(datasets) > 0 assert "multilinguality:multilingual" in datasets[0].tags def test_filter_datasets_by_size_categories(self): datasets = list(self._api.list_datasets(size_categories="100K 0 assert "size_categories:100K 0 assert "task_categories:audio-classification" in datasets[0].tags def test_filter_datasets_by_task_ids(self): datasets = list(self._api.list_datasets(task_ids="natural-language-inference", limit=10)) assert len(datasets) > 0 assert "task_ids:natural-language-inference" in datasets[0].tags def test_list_datasets_full(self): datasets = list(self._api.list_datasets(full=True, limit=500)) assert len(datasets) > 100 assert isinstance(datasets[0], DatasetInfo) assert any(dataset.card_data for dataset in datasets) def test_list_datasets_author(self): datasets = list(self._api.list_datasets(author="huggingface", limit=10)) assert len(datasets) > 0 assert datasets[0].author == "huggingface" def test_list_datasets_search(self): datasets = list(self._api.list_datasets(search="wikipedia", limit=10)) assert len(datasets) > 5 for dataset in datasets: assert "wikipedia" in dataset.id.lower() def test_list_datasets_expand_author(self): # Only the selected field is returned datasets = list(self._api.list_datasets(expand=["author"], limit=5)) for dataset in datasets: assert dataset.author is not None assert dataset.id is not None assert dataset.downloads is None assert dataset.created_at is None assert dataset.last_modified is None def test_list_datasets_expand_multiple(self): # Only the selected fields are returned datasets = list(self._api.list_datasets(expand=["author", "downloadsAllTime"], limit=5)) for dataset in datasets: assert dataset.author is not None assert dataset.downloads_all_time is not None assert dataset.downloads is None def test_list_datasets_expand_unexpected_value(self): # Unexpected value => HTTP 400 with self.assertRaises(HfHubHTTPError) as cm: list(self._api.list_datasets(expand=["foo"])) assert cm.exception.response.status_code == 400 def test_list_datasets_expand_cannot_be_used_with_full(self): # `expand` cannot be used with `full` with self.assertRaises(ValueError): next(self._api.list_datasets(expand=["author"], full=True)) def test_list_datasets_gated_only(self): for dataset in self._api.list_datasets(expand=["gated"], gated=True, limit=5): assert dataset.gated in ("auto", "manual") def test_list_datasets_non_gated_only(self): for dataset in self._api.list_datasets(expand=["gated"], gated=False, limit=5): assert dataset.gated is False def test_filter_datasets_with_card_data(self): assert any(dataset.card_data is not None for dataset in self._api.list_datasets(full=True, limit=50)) assert all(dataset.card_data is None for dataset in self._api.list_datasets(full=False, limit=50)) def test_filter_datasets_by_tag(self): for dataset in self._api.list_datasets(tags="fiftyone", limit=5): assert "fiftyone" in dataset.tags def test_dataset_info(self): dataset = self._api.dataset_info(repo_id=DUMMY_DATASET_ID) assert isinstance(dataset.card_data, DatasetCardData) and len(dataset.card_data) > 0 assert isinstance(dataset.siblings, list) 
and len(dataset.siblings) > 0 assert isinstance(dataset, DatasetInfo) assert dataset.sha != DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT dataset = self._api.dataset_info( repo_id=DUMMY_DATASET_ID, revision=DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT, ) assert isinstance(dataset, DatasetInfo) assert dataset.sha == DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT def test_dataset_info_with_file_metadata(self): dataset = self._api.dataset_info(repo_id=SAMPLE_DATASET_IDENTIFIER, files_metadata=True) files = dataset.siblings assert files is not None self._check_siblings_metadata(files) def _check_siblings_metadata(self, files: List[RepoSibling]): """Check requested metadata has been received from the server.""" at_least_one_lfs = False for file in files: assert isinstance(file.blob_id, str) assert isinstance(file.size, int) if file.lfs is not None: at_least_one_lfs = True assert isinstance(file.lfs, dict) assert "sha256" in file.lfs assert at_least_one_lfs def test_dataset_info_expand_author(self): # Only the selected field is returned dataset = self._api.dataset_info(repo_id="HuggingFaceH4/no_robots", expand=["author"]) assert dataset.author == "HuggingFaceH4" assert dataset.downloads is None assert dataset.created_at is None assert dataset.last_modified is None def test_dataset_info_expand_multiple(self): # Only the selected fields are returned dataset = self._api.dataset_info(repo_id="HuggingFaceH4/no_robots", expand=["author", "downloadsAllTime"]) assert dataset.author == "HuggingFaceH4" assert dataset.downloads is None assert dataset.downloads_all_time is not None assert dataset.created_at is None assert dataset.last_modified is None def test_dataset_info_expand_unexpected_value(self): # Unexpected value => HTTP 400 with self.assertRaises(HfHubHTTPError) as cm: self._api.dataset_info("HuggingFaceH4/no_robots", expand=["foo"]) assert cm.exception.response.status_code == 400 def test_dataset_info_expand_cannot_be_used_with_files_metadata(self): # `expand` cannot be used with other `files_metadata` with self.assertRaises(ValueError): self._api.dataset_info("HuggingFaceH4/no_robots", expand=["author"], files_metadata=True) def test_space_info(self) -> None: space = self._api.space_info(repo_id="HuggingFaceH4/zephyr-chat") assert space.id == "HuggingFaceH4/zephyr-chat" assert space.author == "HuggingFaceH4" assert isinstance(space.runtime, SpaceRuntime) def test_space_info_expand_author(self): # Only the selected field is returned space = self._api.space_info(repo_id="HuggingFaceH4/zephyr-chat", expand=["author"]) assert space.author == "HuggingFaceH4" assert space.created_at is None assert space.last_modified is None def test_space_info_expand_multiple(self): # Only the selected fields are returned space = self._api.space_info(repo_id="HuggingFaceH4/zephyr-chat", expand=["author", "likes"]) assert space.author == "HuggingFaceH4" assert space.created_at is None assert space.last_modified is None assert space.likes is not None def test_space_info_expand_unexpected_value(self): # Unexpected value => HTTP 400 with self.assertRaises(HfHubHTTPError) as cm: self._api.space_info("HuggingFaceH4/zephyr-chat", expand=["foo"]) assert cm.exception.response.status_code == 400 def test_space_info_expand_cannot_be_used_with_files_metadata(self): # `expand` cannot be used with other files_metadata with self.assertRaises(ValueError): self._api.space_info("HuggingFaceH4/zephyr-chat", expand=["author"], files_metadata=True) def test_filter_models_by_author(self): models = 
list(self._api.list_models(author="muellerzr")) assert len(models) > 0 assert "muellerzr" in models[0].id def test_filter_models_by_author_and_name(self): # Test we can search by an author and a name, but the model is not found models = list(self._api.list_models(author="facebook", model_name="bart-base")) assert "facebook/bart-base" in models[0].id def test_failing_filter_models_by_author_and_model_name(self): # Test we can search by an author and a name, but the model is not found models = list(self._api.list_models(author="muellerzr", model_name="testme")) assert len(models) == 0 def test_filter_models_with_library(self): models = list(self._api.list_models(author="microsoft", model_name="wavlm-base-sd", library="tensorflow")) assert len(models) == 0 models = list(self._api.list_models(author="microsoft", model_name="wavlm-base-sd", library="pytorch")) assert len(models) > 0 def test_filter_models_with_task(self): models = list(self._api.list_models(task="fill-mask", model_name="albert-base-v2")) assert models[0].pipeline_tag == "fill-mask" assert "albert" in models[0].id assert "base" in models[0].id assert "v2" in models[0].id models = list(self._api.list_models(task="dummytask")) assert len(models) == 0 def test_filter_models_by_language(self): for language in ["en", "fr", "zh"]: for model in self._api.list_models(language=language, limit=5): assert language in model.tags def test_filter_models_with_tag(self): models = list(self._api.list_models(author="HuggingFaceBR4", tags=["tensorboard"])) assert models[0].id.startswith("HuggingFaceBR4/") assert "tensorboard" in models[0].tags models = list(self._api.list_models(tags="dummytag")) assert len(models) == 0 def test_filter_models_with_card_data(self): models = self._api.list_models(filter="co2_eq_emissions", cardData=True) assert any(model.card_data is not None for model in models) models = self._api.list_models(filter="co2_eq_emissions") assert all(model.card_data is None for model in models) def test_is_emission_within_threshold(self): # tests that dictionary is handled correctly as "emissions" and that # 17g is accepted and parsed correctly as a value # regression test for #753 kwargs = {field.name: None for field in fields(ModelInfo) if field.init} kwargs = {**kwargs, "card_data": ModelCardData(co2_eq_emissions={"emissions": "17g"})} model = ModelInfo(**kwargs) assert _is_emission_within_threshold(model, -1, 100) def test_filter_emissions_with_max(self): assert all( model.card_data["co2_eq_emissions"] <= 100 for model in self._api.list_models(emissions_thresholds=(None, 100), cardData=True, limit=1000) if isinstance(model.card_data["co2_eq_emissions"], (float, int)) ) def test_filter_emissions_with_min(self): assert all( [ model.card_data["co2_eq_emissions"] >= 5 for model in self._api.list_models(emissions_thresholds=(5, None), cardData=True, limit=1000) if isinstance(model.card_data["co2_eq_emissions"], (float, int)) ] ) def test_filter_emissions_with_min_and_max(self): models = list(self._api.list_models(emissions_thresholds=(5, 100), cardData=True, limit=1000)) assert all( [ model.card_data["co2_eq_emissions"] >= 5 for model in models if isinstance(model.card_data["co2_eq_emissions"], (float, int)) ] ) assert all( [ model.card_data["co2_eq_emissions"] <= 100 for model in models if isinstance(model.card_data["co2_eq_emissions"], (float, int)) ] ) def test_list_spaces_full(self): spaces = list(self._api.list_spaces(full=True, limit=500)) assert len(spaces) > 100 space = spaces[0] assert isinstance(space, SpaceInfo) assert 
any(space.card_data for space in spaces) def test_list_spaces_author(self): spaces = list(self._api.list_spaces(author="julien-c")) assert len(spaces) > 10 for space in spaces: assert space.id.startswith("julien-c/") def test_list_spaces_search(self): spaces = list(self._api.list_spaces(search="wikipedia", limit=10)) assert "wikipedia" in spaces[0].id.lower() def test_list_spaces_sort_and_direction(self): # Descending order => first item has more likes than second spaces_descending_likes = list(self._api.list_spaces(sort="likes", direction=-1, limit=100)) assert spaces_descending_likes[0].likes > spaces_descending_likes[1].likes def test_list_spaces_limit(self): spaces = list(self._api.list_spaces(limit=5)) assert len(spaces) == 5 def test_list_spaces_with_models(self): spaces = list(self._api.list_spaces(models="bert-base-uncased")) assert "bert-base-uncased" in spaces[0].models def test_list_spaces_with_datasets(self): spaces = list(self._api.list_spaces(datasets="wikipedia")) assert "wikipedia" in spaces[0].datasets def test_list_spaces_linked(self): space_id = "stabilityai/stable-diffusion" spaces = [space for space in self._api.list_spaces(search=space_id) if space.id == space_id] assert spaces[0].models is None assert spaces[0].datasets is None spaces = [space for space in self._api.list_spaces(search=space_id, linked=True) if space.id == space_id] assert spaces[0].models is not None assert spaces[0].datasets is not None def test_list_spaces_expand_author(self): # Only the selected field is returned spaces = list(self._api.list_spaces(expand=["author"], limit=5)) for space in spaces: assert space.author is not None assert space.id is not None assert space.created_at is None assert space.last_modified is None def test_list_spaces_expand_multiple(self): # Only the selected fields are returned spaces = list(self._api.list_spaces(expand=["author", "likes"], limit=5)) for space in spaces: assert space.author is not None assert space.likes is not None def test_list_spaces_expand_unexpected_value(self): # Unexpected value => HTTP 400 with self.assertRaises(HfHubHTTPError) as cm: list(self._api.list_spaces(expand=["foo"])) assert cm.exception.response.status_code == 400 def test_list_spaces_expand_cannot_be_used_with_full(self): # `expand` cannot be used with full with self.assertRaises(ValueError): next(self._api.list_spaces(expand=["author"], full=True)) def test_get_paths_info(self): paths_info = self._api.get_paths_info( "allenai/c4", ["en", "en/c4-train.00001-of-01024.json.gz", "non_existing_path"], expand=True, revision="607bd4c8450a42878aa9ddc051a65a055450ef87", repo_type="dataset", ) assert len(paths_info) == 2 assert paths_info[0].path == "en" assert paths_info[0].tree_id is not None assert paths_info[0].last_commit is not None assert paths_info[1].path == "en/c4-train.00001-of-01024.json.gz" assert paths_info[1].blob_id is not None assert paths_info[1].last_commit is not None assert paths_info[1].lfs is not None assert paths_info[1].security is not None assert paths_info[1].size > 0 def test_get_safetensors_metadata_single_file(self) -> None: info = self._api.get_safetensors_metadata("bigscience/bloomz-560m") assert isinstance(info, SafetensorsRepoMetadata) assert not info.sharded assert info.metadata is None # Never populated on non-sharded models assert len(info.files_metadata) == 1 assert "model.safetensors" in info.files_metadata file_metadata = info.files_metadata["model.safetensors"] assert isinstance(file_metadata, SafetensorsFileMetadata) assert file_metadata.metadata == 
{"format": "pt"} assert len(file_metadata.tensors) == 293 assert isinstance(info.weight_map, dict) assert info.weight_map["h.0.input_layernorm.bias"] == "model.safetensors" assert info.parameter_count == {"F16": 559214592} def test_get_safetensors_metadata_sharded_model(self) -> None: info = self._api.get_safetensors_metadata("HuggingFaceH4/zephyr-7b-beta") assert isinstance(info, SafetensorsRepoMetadata) assert info.sharded assert isinstance(info.metadata, dict) # populated for sharded model assert len(info.files_metadata) == 8 for file_metadata in info.files_metadata.values(): assert isinstance(file_metadata, SafetensorsFileMetadata) assert info.parameter_count == {"BF16": 7241732096} def test_not_a_safetensors_repo(self) -> None: with self.assertRaises(NotASafetensorsRepoError): self._api.get_safetensors_metadata("huggingface-hub-ci/test_safetensors_metadata") def test_get_safetensors_metadata_from_revision(self) -> None: info = self._api.get_safetensors_metadata("huggingface-hub-ci/test_safetensors_metadata", revision="refs/pr/1") assert isinstance(info, SafetensorsRepoMetadata) def test_parse_safetensors_metadata(self) -> None: info = self._api.parse_safetensors_file_metadata( "HuggingFaceH4/zephyr-7b-beta", "model-00003-of-00008.safetensors" ) assert isinstance(info, SafetensorsFileMetadata) assert info.metadata == {"format": "pt"} assert isinstance(info.tensors, dict) tensor = info.tensors["model.layers.10.input_layernorm.weight"] assert tensor == TensorInfo(dtype="BF16", shape=[4096], data_offsets=(0, 8192)) assert tensor.parameter_count == 4096 assert info.parameter_count == {"BF16": 989888512} def test_not_a_safetensors_file(self) -> None: with self.assertRaises(SafetensorsParsingError): self._api.parse_safetensors_file_metadata( "HuggingFaceH4/zephyr-7b-beta", "pytorch_model-00001-of-00008.bin" ) class HfApiPrivateTest(HfApiCommonTest): def setUp(self) -> None: super().setUp() self.REPO_NAME = repo_name("private") self._api.create_repo(repo_id=self.REPO_NAME, private=True) self._api.create_repo(repo_id=self.REPO_NAME, private=True, repo_type="dataset") def tearDown(self) -> None: self._api.delete_repo(repo_id=self.REPO_NAME) self._api.delete_repo(repo_id=self.REPO_NAME, repo_type="dataset") @patch("huggingface_hub.utils._headers.get_token", return_value=None) def test_model_info(self, mock_get_token: Mock) -> None: with patch.object(self._api, "token", None): # no default token # Test we cannot access model info without a token with self.assertRaisesRegex( requests.exceptions.HTTPError, re.compile( r"401 Client Error(.+)\(Request ID: .+\)(.*)Repository Not Found", flags=re.DOTALL, ), ): _ = self._api.model_info(repo_id=f"{USER}/{self.REPO_NAME}") model_info = self._api.model_info(repo_id=f"{USER}/{self.REPO_NAME}", use_auth_token=self._token) self.assertIsInstance(model_info, ModelInfo) @patch("huggingface_hub.utils._headers.get_token", return_value=None) def test_dataset_info(self, mock_get_token: Mock) -> None: with patch.object(self._api, "token", None): # no default token # Test we cannot access model info without a token with self.assertRaisesRegex( requests.exceptions.HTTPError, re.compile( r"401 Client Error(.+)\(Request ID: .+\)(.*)Repository Not Found", flags=re.DOTALL, ), ): _ = self._api.dataset_info(repo_id=f"{USER}/{self.REPO_NAME}") dataset_info = self._api.dataset_info(repo_id=f"{USER}/{self.REPO_NAME}", use_auth_token=self._token) self.assertIsInstance(dataset_info, DatasetInfo) def test_list_private_datasets(self): orig = 
len(list(self._api.list_datasets(use_auth_token=False))) new = len(list(self._api.list_datasets(use_auth_token=self._token))) self.assertGreater(new, orig) def test_list_private_models(self): orig = len(list(self._api.list_models(use_auth_token=False))) new = len(list(self._api.list_models(use_auth_token=self._token))) self.assertGreater(new, orig) @with_production_testing def test_list_private_spaces(self): orig = len(list(self._api.list_spaces(use_auth_token=False))) new = len(list(self._api.list_spaces(use_auth_token=self._token))) self.assertGreaterEqual(new, orig) @pytest.mark.usefixtures("fx_cache_dir") class UploadFolderMockedTest(unittest.TestCase): api = HfApi() cache_dir: Path def setUp(self) -> None: (self.cache_dir / "file.txt").write_text("content") (self.cache_dir / "lfs.bin").write_text("content") (self.cache_dir / "sub").mkdir() (self.cache_dir / "sub" / "file.txt").write_text("content") (self.cache_dir / "sub" / "lfs_in_sub.bin").write_text("content") (self.cache_dir / "subdir").mkdir() (self.cache_dir / "subdir" / "file.txt").write_text("content") (self.cache_dir / "subdir" / "lfs_in_subdir.bin").write_text("content") self.all_local_files = { "lfs.bin", "file.txt", "sub/file.txt", "sub/lfs_in_sub.bin", "subdir/file.txt", "subdir/lfs_in_subdir.bin", } self.repo_files_mock = Mock() self.repo_files_mock.return_value = [ # all remote files ".gitattributes", "file.txt", "file1.txt", "sub/file.txt", "sub/file1.txt", "subdir/file.txt", "subdir/lfs_in_subdir.bin", ] self.api.list_repo_files = self.repo_files_mock self.create_commit_mock = Mock() self.create_commit_mock.return_value.commit_url = f"{ENDPOINT_STAGING}/username/repo_id/commit/dummy_sha" self.create_commit_mock.return_value.pr_url = None self.api.create_commit = self.create_commit_mock def _upload_folder_alias(self, **kwargs) -> List[Union[CommitOperationAdd, CommitOperationDelete]]: """Alias to call `upload_folder` + retrieve the CommitOperation list passed to `create_commit`.""" if "folder_path" not in kwargs: kwargs["folder_path"] = self.cache_dir self.api.upload_folder(repo_id="repo_id", **kwargs) return self.create_commit_mock.call_args_list[0][1]["operations"] def test_allow_everything(self): operations = self._upload_folder_alias() assert all(isinstance(op, CommitOperationAdd) for op in operations) assert {op.path_in_repo for op in operations} == self.all_local_files def test_allow_everything_in_subdir_no_trailing_slash(self): operations = self._upload_folder_alias(folder_path=self.cache_dir / "subdir", path_in_repo="subdir") assert all(isinstance(op, CommitOperationAdd) for op in operations) assert {op.path_in_repo for op in operations} == { # correct `path_in_repo` "subdir/file.txt", "subdir/lfs_in_subdir.bin", } def test_allow_everything_in_subdir_with_trailing_slash(self): operations = self._upload_folder_alias(folder_path=self.cache_dir / "subdir", path_in_repo="subdir/") assert all(isinstance(op, CommitOperationAdd) for op in operations) self.assertEqual( {op.path_in_repo for op in operations}, {"subdir/file.txt", "subdir/lfs_in_subdir.bin"}, # correct `path_in_repo` ) def test_allow_txt_ignore_subdir(self): operations = self._upload_folder_alias(allow_patterns="*.txt", ignore_patterns="subdir/*") assert all(isinstance(op, CommitOperationAdd) for op in operations) assert {op.path_in_repo for op in operations} == {"sub/file.txt", "file.txt"} # only .txt files, not in subdir def test_allow_txt_not_root_ignore_subdir(self): operations = self._upload_folder_alias(allow_patterns="**/*.txt", 
ignore_patterns="subdir/*") assert all(isinstance(op, CommitOperationAdd) for op in operations) assert {op.path_in_repo for op in operations} == { # only .txt files, not in subdir, not at root "sub/file.txt" } def test_path_in_repo_dot(self): """Regression test for #1382 when using `path_in_repo="."`. Using `path_in_repo="."` or `path_in_repo=None` should be equivalent. See https://github.com/huggingface/huggingface_hub/pull/1382. """ operation_with_dot = self._upload_folder_alias(path_in_repo=".", allow_patterns=["file.txt"])[0] operation_with_none = self._upload_folder_alias(path_in_repo=None, allow_patterns=["file.txt"])[0] assert operation_with_dot.path_in_repo == "file.txt" assert operation_with_none.path_in_repo == "file.txt" def test_delete_txt(self): operations = self._upload_folder_alias(delete_patterns="*.txt") added_files = {op.path_in_repo for op in operations if isinstance(op, CommitOperationAdd)} deleted_files = {op.path_in_repo for op in operations if isinstance(op, CommitOperationDelete)} assert added_files == self.all_local_files assert deleted_files == {"file1.txt", "sub/file1.txt"} # since "file.txt" and "sub/file.txt" are overwritten, no need to delete them first assert "file.txt" in added_files assert "sub/file.txt" in added_files def test_delete_txt_in_sub(self): operations = self._upload_folder_alias( path_in_repo="sub/", folder_path=self.cache_dir / "sub", delete_patterns="*.txt" ) added_files = {op.path_in_repo for op in operations if isinstance(op, CommitOperationAdd)} deleted_files = {op.path_in_repo for op in operations if isinstance(op, CommitOperationDelete)} assert added_files == {"sub/file.txt", "sub/lfs_in_sub.bin"} # added only in sub/ assert deleted_files == {"sub/file1.txt"} # delete only in sub/ def test_delete_txt_in_sub_ignore_sub_file_txt(self): operations = self._upload_folder_alias( path_in_repo="sub", folder_path=self.cache_dir / "sub", ignore_patterns="file.txt", delete_patterns="*.txt" ) added_files = {op.path_in_repo for op in operations if isinstance(op, CommitOperationAdd)} deleted_files = {op.path_in_repo for op in operations if isinstance(op, CommitOperationDelete)} # since "sub/file.txt" should be deleted and is not overwritten (ignore_patterns), we delete it explicitly assert added_files == {"sub/lfs_in_sub.bin"} # no "sub/file.txt" assert deleted_files == {"sub/file1.txt", "sub/file.txt"} def test_delete_if_path_in_repo(self): # Regression test for https://github.com/huggingface/huggingface_hub/pull/2129 operations = self._upload_folder_alias(path_in_repo=".", folder_path=self.cache_dir, delete_patterns="*") deleted_files = {op.path_in_repo for op in operations if isinstance(op, CommitOperationDelete)} assert deleted_files == {"file1.txt", "sub/file1.txt"} # all the 'old' files @pytest.mark.usefixtures("fx_cache_dir") class HfLargefilesTest(HfApiCommonTest): cache_dir: Path def tearDown(self): self._api.delete_repo(repo_id=self.repo_id) def setup_local_clone(self) -> None: scheme = urlparse(self.repo_url).scheme repo_url_auth = self.repo_url.replace(f"{scheme}://", f"{scheme}://user:{TOKEN}@") subprocess.run( ["git", "clone", repo_url_auth, str(self.cache_dir)], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) subprocess.run(["git", "lfs", "track", "*.pdf"], check=True, cwd=self.cache_dir) subprocess.run(["git", "lfs", "track", "*.epub"], check=True, cwd=self.cache_dir) @require_git_lfs def test_end_to_end_thresh_6M(self): # Little-hack: create repo with defined `_lfsmultipartthresh`. 
Only for tests purposes self._api._lfsmultipartthresh = 6 * 10**6 self.repo_url = self._api.create_repo(repo_id=repo_name()) self.repo_id = self.repo_url.repo_id self._api._lfsmultipartthresh = None self.setup_local_clone() subprocess.run( ["wget", LARGE_FILE_18MB], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.cache_dir ) subprocess.run(["git", "add", "*"], check=True, cwd=self.cache_dir) subprocess.run(["git", "commit", "-m", "commit message"], check=True, cwd=self.cache_dir) # This will fail as we haven't set up our custom transfer agent yet. failed_process = subprocess.run( ["git", "push"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.cache_dir, ) self.assertEqual(failed_process.returncode, 1) self.assertIn("cli lfs-enable-largefiles", failed_process.stderr.decode()) # ^ Instructions on how to fix this are included in the error message. subprocess.run(["huggingface-cli", "lfs-enable-largefiles", self.cache_dir], check=True) start_time = time.time() subprocess.run(["git", "push"], check=True, cwd=self.cache_dir) print("took", time.time() - start_time) # To be 100% sure, let's download the resolved file pdf_url = f"{self.repo_url}/resolve/main/progit.pdf" DEST_FILENAME = "uploaded.pdf" subprocess.run( ["wget", pdf_url, "-O", DEST_FILENAME], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.cache_dir, ) dest_filesize = (self.cache_dir / DEST_FILENAME).stat().st_size assert dest_filesize == 18685041 @require_git_lfs def test_end_to_end_thresh_16M(self): # Here we'll push one multipart and one non-multipart file in the same commit, and see what happens # Little-hack: create repo with defined `_lfsmultipartthresh`. Only for tests purposes self._api._lfsmultipartthresh = 16 * 10**6 self.repo_url = self._api.create_repo(repo_id=repo_name()) self.repo_id = self.repo_url.repo_id self._api._lfsmultipartthresh = None self.setup_local_clone() subprocess.run( ["wget", LARGE_FILE_18MB], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.cache_dir ) subprocess.run( ["wget", LARGE_FILE_14MB], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.cache_dir ) subprocess.run(["git", "add", "*"], check=True, cwd=self.cache_dir) subprocess.run(["git", "commit", "-m", "both files in same commit"], check=True, cwd=self.cache_dir) subprocess.run(["huggingface-cli", "lfs-enable-largefiles", self.cache_dir], check=True) start_time = time.time() subprocess.run(["git", "push"], check=True, cwd=self.cache_dir) print("took", time.time() - start_time) def test_upload_lfs_file_multipart(self): """End to end test to check upload an LFS file using multipart upload works.""" self._api._lfsmultipartthresh = 16 * 10**6 self.repo_id = self._api.create_repo(repo_id=repo_name()).repo_id self._api._lfsmultipartthresh = None with patch.object( huggingface_hub.lfs, "_upload_parts_iteratively", wraps=huggingface_hub.lfs._upload_parts_iteratively, ) as mock: self._api.upload_file(repo_id=self.repo_id, path_or_fileobj=b"0" * 18 * 10**6, path_in_repo="lfs.bin") mock.assert_called_once() # It used multipart upload class ParseHFUrlTest(unittest.TestCase): def test_repo_type_and_id_from_hf_id_on_correct_values(self): possible_values = { "https://huggingface.co/id": [None, None, "id"], "https://huggingface.co/user/id": [None, "user", "id"], "https://huggingface.co/datasets/user/id": ["dataset", "user", "id"], "https://huggingface.co/spaces/user/id": ["space", "user", "id"], "user/id": [None, "user", "id"], "dataset/user/id": ["dataset", "user", "id"], 
"space/user/id": ["space", "user", "id"], "id": [None, None, "id"], "hf://id": [None, None, "id"], "hf://user/id": [None, "user", "id"], "hf://model/user/name": ["model", "user", "name"], # 's' is optional "hf://models/user/name": ["model", "user", "name"], } for key, value in possible_values.items(): self.assertEqual( repo_type_and_id_from_hf_id(key, hub_url=ENDPOINT_PRODUCTION), tuple(value), ) def test_repo_type_and_id_from_hf_id_on_wrong_values(self): for hub_id in [ "https://unknown-endpoint.co/id", "https://huggingface.co/datasets/user/id@revision", # @ forbidden "datasets/user/id/subpath", "hffs://model/user/name", "spaeces/user/id", # with typo in repo type ]: with self.assertRaises(ValueError): repo_type_and_id_from_hf_id(hub_id, hub_url=ENDPOINT_PRODUCTION) class HfApiDiscussionsTest(HfApiCommonTest): def setUp(self): self.repo_id = self._api.create_repo(repo_id=repo_name()).repo_id self.pull_request = self._api.create_discussion( repo_id=self.repo_id, pull_request=True, title="Test Pull Request" ) self.discussion = self._api.create_discussion( repo_id=self.repo_id, pull_request=False, title="Test Discussion" ) def tearDown(self): self._api.delete_repo(repo_id=self.repo_id) def test_create_discussion(self): discussion = self._api.create_discussion(repo_id=self.repo_id, title=" Test discussion ! ") self.assertEqual(discussion.num, 3) self.assertEqual(discussion.author, USER) self.assertEqual(discussion.is_pull_request, False) self.assertEqual(discussion.title, "Test discussion !") @use_tmp_repo("dataset") def test_create_discussion_space(self, repo_url: RepoUrl): """Regression test for #1463. Computed URL was malformed with `dataset` and `space` repo_types. See https://github.com/huggingface/huggingface_hub/issues/1463. """ discussion = self._api.create_discussion(repo_id=repo_url.repo_id, repo_type="dataset", title="title") self.assertEqual(discussion.url, f"{repo_url}/discussions/1") def test_create_pull_request(self): discussion = self._api.create_discussion(repo_id=self.repo_id, title=" Test PR ! 
", pull_request=True) self.assertEqual(discussion.num, 3) self.assertEqual(discussion.author, USER) self.assertEqual(discussion.is_pull_request, True) self.assertEqual(discussion.title, "Test PR !") model_info = self._api.repo_info(repo_id=self.repo_id, revision="refs/pr/1") self.assertIsInstance(model_info, ModelInfo) def test_get_repo_discussion(self): discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id) self.assertIsInstance(discussions_generator, types.GeneratorType) self.assertListEqual( list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num] ) def test_get_repo_discussion_by_type(self): discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="pull_request") self.assertIsInstance(discussions_generator, types.GeneratorType) self.assertListEqual(list([d.num for d in discussions_generator]), [self.pull_request.num]) discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="discussion") self.assertIsInstance(discussions_generator, types.GeneratorType) self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num]) discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="all") self.assertIsInstance(discussions_generator, types.GeneratorType) self.assertListEqual( list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num] ) def test_get_repo_discussion_by_author(self): discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, author="unknown") self.assertIsInstance(discussions_generator, types.GeneratorType) self.assertListEqual(list([d.num for d in discussions_generator]), []) def test_get_repo_discussion_by_status(self): self._api.change_discussion_status(self.repo_id, self.discussion.num, "closed") discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="open") self.assertIsInstance(discussions_generator, types.GeneratorType) self.assertListEqual(list([d.num for d in discussions_generator]), [self.pull_request.num]) discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="closed") self.assertIsInstance(discussions_generator, types.GeneratorType) self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num]) discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="all") self.assertIsInstance(discussions_generator, types.GeneratorType) self.assertListEqual( list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num] ) @with_production_testing def test_get_repo_discussion_pagination(self): discussions = list( HfApi().get_repo_discussions(repo_id="open-llm-leaderboard/open_llm_leaderboard", repo_type="space") ) assert len(discussions) > 50 def test_get_discussion_details(self): retrieved = self._api.get_discussion_details(repo_id=self.repo_id, discussion_num=2) self.assertEqual(retrieved, self.discussion) def test_edit_discussion_comment(self): def get_first_comment(discussion: DiscussionWithDetails) -> DiscussionComment: return [evt for evt in discussion.events if evt.type == "comment"][0] edited_comment = self._api.edit_discussion_comment( repo_id=self.repo_id, discussion_num=self.pull_request.num, comment_id=get_first_comment(self.pull_request).id, new_content="**Edited** comment 🤗", ) retrieved = self._api.get_discussion_details(repo_id=self.repo_id, 
discussion_num=self.pull_request.num) self.assertEqual(get_first_comment(retrieved).edited, True) self.assertEqual(get_first_comment(retrieved).id, get_first_comment(self.pull_request).id) self.assertEqual(get_first_comment(retrieved).content, "**Edited** comment 🤗") self.assertEqual(get_first_comment(retrieved), edited_comment) def test_comment_discussion(self): new_comment = self._api.comment_discussion( repo_id=self.repo_id, discussion_num=self.discussion.num, comment="""\ # Multi-line comment **With formatting**, including *italic text* & ~strike through~ And even [links](http://hf.co)! 💥🤯 """, ) retrieved = self._api.get_discussion_details(repo_id=self.repo_id, discussion_num=self.discussion.num) self.assertEqual(len(retrieved.events), 2) self.assertIn(new_comment.id, {event.id for event in retrieved.events}) def test_rename_discussion(self): rename_event = self._api.rename_discussion( repo_id=self.repo_id, discussion_num=self.discussion.num, new_title="New title2" ) retrieved = self._api.get_discussion_details(repo_id=self.repo_id, discussion_num=self.discussion.num) self.assertIn(rename_event.id, (event.id for event in retrieved.events)) self.assertEqual(rename_event.old_title, self.discussion.title) self.assertEqual(rename_event.new_title, "New title2") def test_change_discussion_status(self): status_change_event = self._api.change_discussion_status( repo_id=self.repo_id, discussion_num=self.discussion.num, new_status="closed" ) retrieved = self._api.get_discussion_details(repo_id=self.repo_id, discussion_num=self.discussion.num) self.assertIn(status_change_event.id, (event.id for event in retrieved.events)) self.assertEqual(status_change_event.new_status, "closed") with self.assertRaises(ValueError): self._api.change_discussion_status( repo_id=self.repo_id, discussion_num=self.discussion.num, new_status="published" ) def test_merge_pull_request(self): self._api.create_commit( repo_id=self.repo_id, commit_message="Commit some file", operations=[CommitOperationAdd(path_in_repo="file.test", path_or_fileobj=b"Content")], revision=self.pull_request.git_reference, ) self._api.change_discussion_status( repo_id=self.repo_id, discussion_num=self.pull_request.num, new_status="open" ) self._api.merge_pull_request(self.repo_id, self.pull_request.num) retrieved = self._api.get_discussion_details(repo_id=self.repo_id, discussion_num=self.pull_request.num) self.assertEqual(retrieved.status, "merged") self.assertIsNotNone(retrieved.merge_commit_oid) class ActivityApiTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.api = HfApi() # no auth! def test_unlike_missing_repo(self) -> None: with self.assertRaises(RepositoryNotFoundError): self.api.unlike("missing_repo_id", token=TOKEN) def test_list_likes_repos_auth_and_implicit_user(self) -> None: # User is implicit likes = self.api.list_liked_repos(token=TOKEN) self.assertEqual(likes.user, USER) def test_list_likes_repos_auth_and_explicit_user(self) -> None: # User is explicit even if auth likes = self.api.list_liked_repos(user=OTHER_USER, token=TOKEN) self.assertEqual(likes.user, OTHER_USER) @with_production_testing def test_list_repo_likers(self) -> None: # a repo with > 5000 likes all_likers = list( HfApi().list_repo_likers(repo_id="open-llm-leaderboard/open_llm_leaderboard", repo_type="space") ) self.assertIsInstance(all_likers[0], User) self.assertGreater(len(all_likers), 5000) @with_production_testing def test_list_likes_on_production(self) -> None: # Test julien-c likes a lot of repos ! 
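        # `list_liked_repos` returns an object exposing `models`, `datasets` and `spaces`
        # lists plus a `total` counter; the assertions below rely only on those attributes.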
likes = HfApi().list_liked_repos("julien-c") self.assertEqual(len(likes.models) + len(likes.datasets) + len(likes.spaces), likes.total) self.assertGreater(len(likes.models), 0) self.assertGreater(len(likes.datasets), 0) self.assertGreater(len(likes.spaces), 0) class TestSquashHistory(HfApiCommonTest): @use_tmp_repo() def test_super_squash_history_on_branch(self, repo_url: RepoUrl) -> None: # Upload + update file on main repo_id = repo_url.repo_id self._api.upload_file(repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"content") self._api.upload_file(repo_id=repo_id, path_in_repo="lfs.bin", path_or_fileobj=b"content") self._api.upload_file(repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"another_content") # Upload file on a new branch self._api.create_branch(repo_id=repo_id, branch="v0.1", exist_ok=True) self._api.upload_file(repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"foo", revision="v0.1") # Squash history on main self._api.super_squash_history(repo_id=repo_id) # List history squashed_main_commits = self._api.list_repo_commits(repo_id=repo_id, revision="main") branch_commits = self._api.list_repo_commits(repo_id=repo_id, revision="v0.1") # Main branch has been squashed but initial commits still exists on other branch assert len(squashed_main_commits) == 1 assert squashed_main_commits[0].title == "Super-squash branch 'main' using huggingface_hub" assert len(branch_commits) == 5 assert branch_commits[-1].title == "initial commit" # Squash history on branch self._api.super_squash_history(repo_id=repo_id, branch="v0.1") squashed_branch_commits = self._api.list_repo_commits(repo_id=repo_id, revision="v0.1") assert len(squashed_branch_commits) == 1 assert squashed_branch_commits[0].title == "Super-squash branch 'v0.1' using huggingface_hub" @use_tmp_repo() def test_super_squash_history_on_special_ref(self, repo_url: RepoUrl) -> None: """Regression test for https://github.com/huggingface/dataset-viewer/pull/3131. In practice, it doesn't make any sense to super squash a PR as it will not be mergeable anymore. The only case where it's useful is for the dataset-viewer on refs/convert/parquet. 
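        A rough sketch of that dataset-viewer flow (the repo id below is hypothetical; only
        the ref name is the canonical one):

            api = HfApi()
            api.super_squash_history(
                repo_id="user/some-dataset",
                repo_type="dataset",
                branch="refs/convert/parquet",
            )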
""" repo_id = repo_url.repo_id pr = self._api.create_pull_request(repo_id=repo_id, title="Test super squash on PR") # Upload + update file on PR self._api.upload_file( repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"content", revision=pr.git_reference ) self._api.upload_file( repo_id=repo_id, path_in_repo="lfs.bin", path_or_fileobj=b"content", revision=pr.git_reference ) self._api.upload_file( repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"another_content", revision=pr.git_reference ) # Squash history PR self._api.super_squash_history(repo_id=repo_id, branch=pr.git_reference) squashed_branch_commits = self._api.list_repo_commits(repo_id=repo_id, revision=pr.git_reference) assert len(squashed_branch_commits) == 1 class TestListAndPermanentlyDeleteLFSFiles(HfApiCommonTest): @use_tmp_repo() def test_list_and_delete_lfs_files(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id # Main files self._api.upload_file(path_or_fileobj=b"LFS content", path_in_repo="lfs_file.bin", repo_id=repo_id) self._api.upload_file(path_or_fileobj=b"TXT content", path_in_repo="txt_file.txt", repo_id=repo_id) self._api.upload_file(path_or_fileobj=b"LFS content 2", path_in_repo="lfs_file_2.bin", repo_id=repo_id) self._api.upload_file(path_or_fileobj=b"TXT content 2", path_in_repo="txt_file_2.txt", repo_id=repo_id) # Branch files self._api.create_branch(repo_id=repo_id, branch="my-branch") self._api.upload_file( path_or_fileobj=b"LFS content branch", path_in_repo="lfs_file_branch.bin", repo_id=repo_id, revision="my-branch", ) self._api.upload_file( path_or_fileobj=b"TXT content branch", path_in_repo="txt_file_branch.txt", repo_id=repo_id, revision="my-branch", ) # PR files self._api.upload_file( path_or_fileobj=b"LFS content PR", path_in_repo="lfs_file_PR.bin", repo_id=repo_id, create_pr=True ) self._api.upload_file( path_or_fileobj=b"TXT content PR", path_in_repo="txt_file_PR.txt", repo_id=repo_id, create_pr=True ) # List LFS files lfs_files = [file for file in self._api.list_lfs_files(repo_id=repo_id)] assert len(lfs_files) == 4 assert {file.filename for file in lfs_files} == { "lfs_file.bin", "lfs_file_2.bin", "lfs_file_branch.bin", "lfs_file_PR.bin", } # Select LFS files that are on main lfs_files_on_main = [file for file in lfs_files if file.ref == "main"] assert len(lfs_files_on_main) == 2 # Permanently delete LFS files self._api.permanently_delete_lfs_files(repo_id=repo_id, lfs_files=lfs_files_on_main) # LFS files from branch and PR remain lfs_files = [file for file in self._api.list_lfs_files(repo_id=repo_id)] assert len(lfs_files) == 2 assert {file.filename for file in lfs_files} == {"lfs_file_branch.bin", "lfs_file_PR.bin"} # Downloading "lfs_file.bin" fails with EntryNotFoundError files = self._api.list_repo_files(repo_id=repo_id) assert set(files) == {".gitattributes", "txt_file.txt", "txt_file_2.txt"} with pytest.raises(EntryNotFoundError): self._api.hf_hub_download(repo_id=repo_id, filename="lfs_file.bin") @pytest.mark.vcr class TestSpaceAPIProduction(unittest.TestCase): """ Testing Space API is not possible on staging. We use VCR-ed to mimic server requests. """ repo_id: str api: HfApi _BASIC_APP_PY_TEMPLATE = """ import gradio as gr def greet(name): return "Hello " + name + "!!" 
iface = gr.Interface(fn=greet, inputs="text", outputs="text") iface.launch() """.encode() @with_production_testing def setUp(self): super().setUp() # If generating new VCR => use personal token and REMOVE IT from the VCR self.repo_id = "user/tmp_test_space" # no need to be unique as it's a VCRed test self.api = HfApi(token="hf_fake_token", endpoint=ENDPOINT_PRODUCTION) # Create a Space self.api.create_repo(repo_id=self.repo_id, repo_type="space", space_sdk="gradio", private=True, exist_ok=True) self.api.upload_file( path_or_fileobj=self._BASIC_APP_PY_TEMPLATE, repo_id=self.repo_id, repo_type="space", path_in_repo="app.py", ) def tearDown(self): self.api.delete_repo(repo_id=self.repo_id, repo_type="space") super().tearDown() def test_manage_secrets(self) -> None: # Add 3 secrets self.api.add_space_secret(self.repo_id, "foo", "123") self.api.add_space_secret(self.repo_id, "token", "hf_api_123456") self.api.add_space_secret(self.repo_id, "gh_api_key", "******") # Add secret with optional description self.api.add_space_secret(self.repo_id, "bar", "123", description="This is a secret") # Update secret self.api.add_space_secret(self.repo_id, "foo", "456") # Update secret with optional description self.api.add_space_secret(self.repo_id, "foo", "789", description="This is a secret") self.api.add_space_secret(self.repo_id, "bar", "456", description="This is another secret") # Delete secret self.api.delete_space_secret(self.repo_id, "gh_api_key") # Doesn't fail on missing key self.api.delete_space_secret(self.repo_id, "missing_key") def test_manage_variables(self) -> None: # Get variables self.api.get_space_variables(self.repo_id) # Add 3 variables self.api.add_space_variable(self.repo_id, "foo", "123") self.api.add_space_variable(self.repo_id, "MODEL_REPO_ID", "user/repo") # Add 1 variable with optional description self.api.add_space_variable(self.repo_id, "MODEL_PAPER", "arXiv", description="found it there") # Update variable self.api.add_space_variable(self.repo_id, "foo", "456") # Update variable with optional description self.api.add_space_variable(self.repo_id, "foo", "456", description="updated description") # Delete variable self.api.delete_space_variable(self.repo_id, "gh_api_key") # Doesn't fail on missing key self.api.delete_space_variable(self.repo_id, "missing_key") # Returning all variables created variables = self.api.get_space_variables(self.repo_id) self.assertEqual(len(variables), 3) def test_space_runtime(self) -> None: runtime = self.api.get_space_runtime(self.repo_id) # Space has just been created: hardware might not be set yet. self.assertIn(runtime.hardware, (None, SpaceHardware.CPU_BASIC)) self.assertIn(runtime.requested_hardware, (None, SpaceHardware.CPU_BASIC)) # Space is either "BUILDING" (if not yet done) or "NO_APP_FILE" (if building failed) self.assertIn(runtime.stage, (SpaceStage.NO_APP_FILE, SpaceStage.BUILDING)) self.assertIn(runtime.stage, ("NO_APP_FILE", "BUILDING")) # str works as well # Raw response from Hub self.assertIsInstance(runtime.raw, dict) def test_static_space_runtime(self) -> None: """ Regression test for static Spaces. See https://github.com/huggingface/huggingface_hub/pull/1754. 
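        A static Space has no app build to report on, so the call below mainly checks that
        `get_space_runtime` still returns a payload that can be parsed (exposed as `runtime.raw`).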
""" runtime = self.api.get_space_runtime("victor/static-space") self.assertIsInstance(runtime.raw, dict) @with_production_testing def test_pause_and_restart_space(self) -> None: # Upload a fake app.py file self.api.upload_file(path_or_fileobj=b"", path_in_repo="app.py", repo_id=self.repo_id, repo_type="space") # Wait for the Space to be "BUILDING" count = 0 while True: if self.api.get_space_runtime(self.repo_id).stage == SpaceStage.BUILDING: break time.sleep(1.0) count += 1 if count > 10: raise Exception("Space is not building after 10 seconds.") # Pause it runtime_after_pause = self.api.pause_space(self.repo_id) self.assertEqual(runtime_after_pause.stage, SpaceStage.PAUSED) # Restart self.api.restart_space(self.repo_id) time.sleep(0.5) runtime_after_restart = self.api.get_space_runtime(self.repo_id) self.assertNotEqual(runtime_after_restart.stage, SpaceStage.PAUSED) @pytest.mark.usefixtures("fx_cache_dir") class TestCommitInBackground(HfApiCommonTest): cache_dir: Path @use_tmp_repo() def test_commit_to_repo_in_background(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id (self.cache_dir / "file.txt").write_text("content") (self.cache_dir / "lfs.bin").write_text("content") t0 = time.time() upload_future_1 = self._api.upload_file( path_or_fileobj=b"1", path_in_repo="1.txt", repo_id=repo_id, commit_message="Upload 1", run_as_future=True ) upload_future_2 = self._api.upload_file( path_or_fileobj=b"2", path_in_repo="2.txt", repo_id=repo_id, commit_message="Upload 2", run_as_future=True ) upload_future_3 = self._api.upload_folder( repo_id=repo_id, folder_path=self.cache_dir, commit_message="Upload folder", run_as_future=True ) t1 = time.time() # all futures are queued instantly self.assertLessEqual(t1 - t0, 0.01) # wait for the last job to complete upload_future_3.result() # all of them are now complete (ran in order) assert upload_future_1.done() assert upload_future_2.done() assert upload_future_3.done() # 4 commits, sorted in reverse order of creation commits = self._api.list_repo_commits(repo_id=repo_id) self.assertEqual(len(commits), 4) self.assertEqual(commits[0].title, "Upload folder") self.assertEqual(commits[1].title, "Upload 2") self.assertEqual(commits[2].title, "Upload 1") self.assertEqual(commits[3].title, "initial commit") @use_tmp_repo() def test_run_as_future(self, repo_url: RepoUrl) -> None: repo_id = repo_url.repo_id # update repo visibility to private self._api.run_as_future(self._api.update_repo_settings, repo_id=repo_id, private=True) future_1 = self._api.run_as_future(self._api.model_info, repo_id=repo_id) # update repo visibility to public self._api.run_as_future(self._api.update_repo_settings, repo_id=repo_id, private=False) future_2 = self._api.run_as_future(self._api.model_info, repo_id=repo_id) self.assertIsInstance(future_1, Future) self.assertIsInstance(future_2, Future) # Wait for first info future info_1 = future_1.result() self.assertFalse(future_2.done()) # Wait for second info future info_2 = future_2.result() assert future_2.done() # Like/unlike is correct self.assertEqual(info_1.private, True) self.assertEqual(info_2.private, False) class TestDownloadHfApiAlias(unittest.TestCase): def setUp(self) -> None: self.api = HfApi( endpoint="https://hf.co", token="user_token", library_name="cool_one", library_version="1.0.0", user_agent="myself", ) return super().setUp() @patch("huggingface_hub.file_download.hf_hub_download") def test_hf_hub_download_alias(self, mock: Mock) -> None: self.api.hf_hub_download("my_repo_id", "file.txt") 
mock.assert_called_once_with( # Call values repo_id="my_repo_id", filename="file.txt", # HfAPI values endpoint="https://hf.co", library_name="cool_one", library_version="1.0.0", user_agent="myself", token="user_token", # Default values subfolder=None, repo_type=None, revision=None, cache_dir=None, local_dir=None, local_dir_use_symlinks="auto", force_download=False, force_filename=None, proxies=None, etag_timeout=10, resume_download=None, local_files_only=False, headers=None, ) @patch("huggingface_hub._snapshot_download.snapshot_download") def test_snapshot_download_alias(self, mock: Mock) -> None: self.api.snapshot_download("my_repo_id") mock.assert_called_once_with( # Call values repo_id="my_repo_id", # HfAPI values endpoint="https://hf.co", library_name="cool_one", library_version="1.0.0", user_agent="myself", token="user_token", # Default values repo_type=None, revision=None, cache_dir=None, local_dir=None, local_dir_use_symlinks="auto", proxies=None, etag_timeout=10, resume_download=None, force_download=False, local_files_only=False, allow_patterns=None, ignore_patterns=None, max_workers=8, tqdm_class=None, ) class TestSpaceAPIMocked(unittest.TestCase): """ Testing Space hardware requests is resource intensive for the server (need to spawn GPUs). Tests are mocked to check the correct values are sent. """ def setUp(self) -> None: self.api = HfApi(token="fake_token") self.repo_id = "fake_repo_id" get_session_mock = Mock() self.post_mock = get_session_mock().post self.post_mock.return_value.json.return_value = { "url": f"{self.api.endpoint}/spaces/user/repo_id", "stage": "RUNNING", "sdk": "gradio", "sdkVersion": "3.17.0", "hardware": { "current": "t4-medium", "requested": "t4-medium", }, "storage": "large", "gcTimeout": None, } self.delete_mock = get_session_mock().delete self.delete_mock.return_value.json.return_value = { "url": f"{self.api.endpoint}/spaces/user/repo_id", "stage": "RUNNING", "sdk": "gradio", "sdkVersion": "3.17.0", "hardware": { "current": "t4-medium", "requested": "t4-medium", }, "storage": None, "gcTimeout": None, } self.patcher = patch("huggingface_hub.hf_api.get_session", get_session_mock) self.patcher.start() def tearDown(self) -> None: self.patcher.stop() def test_create_space_with_hardware(self) -> None: self.api.create_repo( self.repo_id, repo_type="space", space_sdk="gradio", space_hardware=SpaceHardware.T4_MEDIUM, ) self.post_mock.assert_called_once_with( f"{self.api.endpoint}/api/repos/create", headers=self.api._build_hf_headers(), json={ "name": self.repo_id, "organization": None, "type": "space", "sdk": "gradio", "hardware": "t4-medium", }, ) def test_create_space_with_hardware_and_sleep_time(self) -> None: self.api.create_repo( self.repo_id, repo_type="space", space_sdk="gradio", space_hardware=SpaceHardware.T4_MEDIUM, space_sleep_time=123, ) self.post_mock.assert_called_once_with( f"{self.api.endpoint}/api/repos/create", headers=self.api._build_hf_headers(), json={ "name": self.repo_id, "organization": None, "type": "space", "sdk": "gradio", "hardware": "t4-medium", "sleepTimeSeconds": 123, }, ) def test_create_space_with_storage(self) -> None: self.api.create_repo( self.repo_id, repo_type="space", space_sdk="gradio", space_storage=SpaceStorage.LARGE, ) self.post_mock.assert_called_once_with( f"{self.api.endpoint}/api/repos/create", headers=self.api._build_hf_headers(), json={ "name": self.repo_id, "organization": None, "type": "space", "sdk": "gradio", "storageTier": "large", }, ) def test_create_space_with_secrets_and_variables(self) -> None: 
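# `space_secrets` / `space_variables` take lists of dicts with "key", "value" and an optional
# "description". A minimal sketch of such a call (hypothetical key/value, not executed here):
#   api.create_repo("my-space", repo_type="space", space_sdk="gradio",
#                   space_secrets=[{"key": "HF_TOKEN", "value": "hf_xxx"}])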
self.api.create_repo( self.repo_id, repo_type="space", space_sdk="gradio", space_secrets=[ {"key": "Testsecret", "value": "Testvalue", "description": "Testdescription"}, {"key": "Testsecret2", "value": "Testvalue"}, ], space_variables=[ {"key": "Testvariable", "value": "Testvalue", "description": "Testdescription"}, {"key": "Testvariable2", "value": "Testvalue"}, ], ) self.post_mock.assert_called_once_with( f"{self.api.endpoint}/api/repos/create", headers=self.api._build_hf_headers(), json={ "name": self.repo_id, "organization": None, "type": "space", "sdk": "gradio", "secrets": [ {"key": "Testsecret", "value": "Testvalue", "description": "Testdescription"}, {"key": "Testsecret2", "value": "Testvalue"}, ], "variables": [ {"key": "Testvariable", "value": "Testvalue", "description": "Testdescription"}, {"key": "Testvariable2", "value": "Testvalue"}, ], }, ) def test_duplicate_space(self) -> None: self.api.duplicate_space( self.repo_id, to_id=f"{USER}/new_repo_id", private=True, hardware=SpaceHardware.T4_MEDIUM, storage=SpaceStorage.LARGE, sleep_time=123, secrets=[ {"key": "Testsecret", "value": "Testvalue", "description": "Testdescription"}, {"key": "Testsecret2", "value": "Testvalue"}, ], variables=[ {"key": "Testvariable", "value": "Testvalue", "description": "Testdescription"}, {"key": "Testvariable2", "value": "Testvalue"}, ], ) self.post_mock.assert_called_once_with( f"{self.api.endpoint}/api/spaces/{self.repo_id}/duplicate", headers=self.api._build_hf_headers(), json={ "repository": f"{USER}/new_repo_id", "private": True, "hardware": "t4-medium", "storageTier": "large", "sleepTimeSeconds": 123, "secrets": [ {"key": "Testsecret", "value": "Testvalue", "description": "Testdescription"}, {"key": "Testsecret2", "value": "Testvalue"}, ], "variables": [ {"key": "Testvariable", "value": "Testvalue", "description": "Testdescription"}, {"key": "Testvariable2", "value": "Testvalue"}, ], }, ) def test_request_space_hardware_no_sleep_time(self) -> None: self.api.request_space_hardware(self.repo_id, SpaceHardware.T4_MEDIUM) self.post_mock.assert_called_once_with( f"{self.api.endpoint}/api/spaces/{self.repo_id}/hardware", headers=self.api._build_hf_headers(), json={"flavor": "t4-medium"}, ) def test_request_space_hardware_with_sleep_time(self) -> None: self.api.request_space_hardware(self.repo_id, SpaceHardware.T4_MEDIUM, sleep_time=123) self.post_mock.assert_called_once_with( f"{self.api.endpoint}/api/spaces/{self.repo_id}/hardware", headers=self.api._build_hf_headers(), json={"flavor": "t4-medium", "sleepTimeSeconds": 123}, ) def test_set_space_sleep_time_upgraded_hardware(self) -> None: self.api.set_space_sleep_time(self.repo_id, sleep_time=123) self.post_mock.assert_called_once_with( f"{self.api.endpoint}/api/spaces/{self.repo_id}/sleeptime", headers=self.api._build_hf_headers(), json={"seconds": 123}, ) def test_set_space_sleep_time_cpu_basic(self) -> None: self.post_mock.return_value.json.return_value["hardware"]["requested"] = "cpu-basic" with self.assertWarns(UserWarning): self.api.set_space_sleep_time(self.repo_id, sleep_time=123) def test_request_space_storage(self) -> None: runtime = self.api.request_space_storage(self.repo_id, SpaceStorage.LARGE) self.post_mock.assert_called_once_with( f"{self.api.endpoint}/api/spaces/{self.repo_id}/storage", headers=self.api._build_hf_headers(), json={"tier": "large"}, ) assert runtime.storage == SpaceStorage.LARGE def test_delete_space_storage(self) -> None: runtime = self.api.delete_space_storage(self.repo_id) self.delete_mock.assert_called_once_with( 
f"{self.api.endpoint}/api/spaces/{self.repo_id}/storage", headers=self.api._build_hf_headers(), ) assert runtime.storage is None def test_restart_space_factory_reboot(self) -> None: self.api.restart_space(self.repo_id, factory_reboot=True) self.post_mock.assert_called_once_with( f"{self.api.endpoint}/api/spaces/{self.repo_id}/restart", headers=self.api._build_hf_headers(), params={"factory": "true"}, ) class ListGitRefsTest(unittest.TestCase): @classmethod @with_production_testing def setUpClass(cls) -> None: cls.api = HfApi() return super().setUpClass() def test_list_refs_gpt2(self) -> None: refs = self.api.list_repo_refs("gpt2") self.assertGreater(len(refs.branches), 0) main_branch = [branch for branch in refs.branches if branch.name == "main"][0] self.assertEqual(main_branch.ref, "refs/heads/main") self.assertIsNone(refs.pull_requests) # Can get info by revision self.api.repo_info("gpt2", revision=main_branch.target_commit) def test_list_refs_bigcode(self) -> None: refs = self.api.list_repo_refs("bigcode/admin", repo_type="dataset") self.assertGreater(len(refs.branches), 0) self.assertGreater(len(refs.converts), 0) self.assertIsNone(refs.pull_requests) main_branch = [branch for branch in refs.branches if branch.name == "main"][0] self.assertEqual(main_branch.ref, "refs/heads/main") convert_branch = [branch for branch in refs.converts if branch.name == "parquet"][0] self.assertEqual(convert_branch.ref, "refs/convert/parquet") # Can get info by convert revision self.api.repo_info( "bigcode/admin", repo_type="dataset", revision=convert_branch.target_commit, ) def test_list_refs_with_prs(self) -> None: refs = self.api.list_repo_refs("openchat/openchat_3.5", include_pull_requests=True) self.assertGreater(len(refs.pull_requests), 1) assert refs.pull_requests[0].ref.startswith("refs/pr/") class ListGitCommitsTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.api = HfApi(token=TOKEN) # Create repo (with initial commit) cls.repo_id = cls.api.create_repo(repo_name()).repo_id # Create a commit on `main` branch cls.api.upload_file(repo_id=cls.repo_id, path_or_fileobj=b"content", path_in_repo="content.txt") # Create a commit in a PR cls.api.upload_file(repo_id=cls.repo_id, path_or_fileobj=b"on_pr", path_in_repo="on_pr.txt", create_pr=True) # Create another commit on `main` branch cls.api.upload_file(repo_id=cls.repo_id, path_or_fileobj=b"on_main", path_in_repo="on_main.txt") return super().setUpClass() @classmethod def tearDownClass(cls) -> None: cls.api.delete_repo(cls.repo_id) return super().tearDownClass() def test_list_commits_on_main(self) -> None: commits = self.api.list_repo_commits(self.repo_id) # "on_pr" commit not returned self.assertEqual(len(commits), 3) assert all("on_pr" not in commit.title for commit in commits) # USER is always the author assert all(commit.authors == [USER] for commit in commits) # latest commit first self.assertEqual(commits[0].title, "Upload on_main.txt with huggingface_hub") # Formatted field not returned by default for commit in commits: self.assertIsNone(commit.formatted_title) self.assertIsNone(commit.formatted_message) def test_list_commits_on_pr(self) -> None: commits = self.api.list_repo_commits(self.repo_id, revision="refs/pr/1") # "on_pr" commit returned but not the "on_main" one self.assertEqual(len(commits), 3) assert all("on_main" not in commit.title for commit in commits) self.assertEqual(commits[0].title, "Upload on_pr.txt with huggingface_hub") def test_list_commits_include_formatted(self) -> None: for commit in 
self.api.list_repo_commits(self.repo_id, formatted=True): self.assertIsNotNone(commit.formatted_title) self.assertIsNotNone(commit.formatted_message) def test_list_commits_on_missing_repo(self) -> None: with self.assertRaises(RepositoryNotFoundError): self.api.list_repo_commits("missing_repo_id") def test_list_commits_on_missing_revision(self) -> None: with self.assertRaises(RevisionNotFoundError): self.api.list_repo_commits(self.repo_id, revision="missing_revision") @patch("huggingface_hub.hf_api.build_hf_headers") class HfApiTokenAttributeTest(unittest.TestCase): def test_token_passed(self, mock_build_hf_headers: Mock) -> None: HfApi(token="default token")._build_hf_headers(token="A token") self._assert_token_is(mock_build_hf_headers, "A token") def test_no_token_passed(self, mock_build_hf_headers: Mock) -> None: HfApi(token="default token")._build_hf_headers() self._assert_token_is(mock_build_hf_headers, "default token") def test_token_true_passed(self, mock_build_hf_headers: Mock) -> None: HfApi(token="default token")._build_hf_headers(token=True) self._assert_token_is(mock_build_hf_headers, True) def test_token_false_passed(self, mock_build_hf_headers: Mock) -> None: HfApi(token="default token")._build_hf_headers(token=False) self._assert_token_is(mock_build_hf_headers, False) def test_no_token_at_all(self, mock_build_hf_headers: Mock) -> None: HfApi()._build_hf_headers(token=None) self._assert_token_is(mock_build_hf_headers, None) def _assert_token_is(self, mock_build_hf_headers: Mock, expected_value: str) -> None: self.assertEqual(mock_build_hf_headers.call_args[1]["token"], expected_value) def test_library_name_and_version_are_set(self, mock_build_hf_headers: Mock) -> None: HfApi(library_name="a", library_version="b")._build_hf_headers() self.assertEqual(mock_build_hf_headers.call_args[1]["library_name"], "a") self.assertEqual(mock_build_hf_headers.call_args[1]["library_version"], "b") def test_library_name_and_version_are_overwritten(self, mock_build_hf_headers: Mock) -> None: api = HfApi(library_name="a", library_version="b") api._build_hf_headers(library_name="A", library_version="B") self.assertEqual(mock_build_hf_headers.call_args[1]["library_name"], "A") self.assertEqual(mock_build_hf_headers.call_args[1]["library_version"], "B") def test_user_agent_is_set(self, mock_build_hf_headers: Mock) -> None: HfApi(user_agent={"a": "b"})._build_hf_headers() self.assertEqual(mock_build_hf_headers.call_args[1]["user_agent"], {"a": "b"}) def test_user_agent_is_overwritten(self, mock_build_hf_headers: Mock) -> None: HfApi(user_agent={"a": "b"})._build_hf_headers(user_agent={"A": "B"}) self.assertEqual(mock_build_hf_headers.call_args[1]["user_agent"], {"A": "B"}) @patch("huggingface_hub.constants.ENDPOINT", ENDPOINT_PRODUCTION) class RepoUrlTest(unittest.TestCase): def test_repo_url_class(self): url = RepoUrl("https://huggingface.co/gpt2") # RepoUrl Is a string self.assertIsInstance(url, str) self.assertEqual(url, "https://huggingface.co/gpt2") # Any str-method can be applied self.assertEqual(url.split("/"), "https://huggingface.co/gpt2".split("/")) # String formatting and concatenation work self.assertEqual(f"New repo: {url}", "New repo: https://huggingface.co/gpt2") self.assertEqual("New repo: " + url, "New repo: https://huggingface.co/gpt2") # __repr__ is modified for debugging purposes self.assertEqual( repr(url), "RepoUrl('https://huggingface.co/gpt2'," " endpoint='https://huggingface.co', repo_type='model', repo_id='gpt2')", ) def test_repo_url_endpoint(self): # Implicit endpoint url 
= RepoUrl("https://huggingface.co/gpt2") self.assertEqual(url.endpoint, ENDPOINT_PRODUCTION) # Explicit endpoint url = RepoUrl("https://example.com/gpt2", endpoint="https://example.com") self.assertEqual(url.endpoint, "https://example.com") def test_repo_url_repo_type(self): # Explicit repo type url = RepoUrl("https://huggingface.co/user/repo_name") self.assertEqual(url.repo_type, "model") url = RepoUrl("https://huggingface.co/datasets/user/repo_name") self.assertEqual(url.repo_type, "dataset") url = RepoUrl("https://huggingface.co/spaces/user/repo_name") self.assertEqual(url.repo_type, "space") # Implicit repo type (model) url = RepoUrl("https://huggingface.co/user/repo_name") self.assertEqual(url.repo_type, "model") def test_repo_url_namespace(self): # Canonical model (e.g. no username) url = RepoUrl("https://huggingface.co/gpt2") self.assertIsNone(url.namespace) self.assertEqual(url.repo_id, "gpt2") # "Normal" model url = RepoUrl("https://huggingface.co/dummy_user/dummy_model") self.assertEqual(url.namespace, "dummy_user") self.assertEqual(url.repo_id, "dummy_user/dummy_model") def test_repo_url_url_property(self): # RepoUrl.url returns a pure `str` value url = RepoUrl("https://huggingface.co/gpt2") self.assertEqual(url, "https://huggingface.co/gpt2") self.assertEqual(url.url, "https://huggingface.co/gpt2") self.assertIsInstance(url, RepoUrl) self.assertNotIsInstance(url.url, RepoUrl) def test_repo_url_canonical_model(self): for _id in ("gpt2", "hf://gpt2", "https://huggingface.co/gpt2"): with self.subTest(_id): url = RepoUrl(_id) self.assertEqual(url.repo_id, "gpt2") self.assertEqual(url.repo_type, "model") def test_repo_url_canonical_dataset(self): for _id in ("datasets/squad", "hf://datasets/squad", "https://huggingface.co/datasets/squad"): with self.subTest(_id): url = RepoUrl(_id) self.assertEqual(url.repo_id, "squad") self.assertEqual(url.repo_type, "dataset") def test_repo_url_in_commit_info(self): info = CommitInfo( commit_url="https://huggingface.co/Wauplin/test-repo-id-mixin/commit/52d172a8b276e529d5260d6f3f76c85be5889dee", commit_message="Dummy message", commit_description="Dummy description", oid="52d172a8b276e529d5260d6f3f76c85be5889dee", pr_url=None, ) assert isinstance(info.repo_url, RepoUrl) assert info.repo_url.endpoint == "https://huggingface.co" assert info.repo_url.repo_id == "Wauplin/test-repo-id-mixin" assert info.repo_url.repo_type == "model" class HfApiDuplicateSpaceTest(HfApiCommonTest): @unittest.skip("Duplicating Space doesn't work on staging.") def test_duplicate_space_success(self) -> None: """Check `duplicate_space` works.""" from_repo_name = repo_name() from_repo_id = self._api.create_repo( repo_id=from_repo_name, repo_type="space", space_sdk="static", token=OTHER_TOKEN, ).repo_id self._api.upload_file( path_or_fileobj=b"data", path_in_repo="temp/new_file.md", repo_id=from_repo_id, repo_type="space", token=OTHER_TOKEN, ) to_repo_id = self._api.duplicate_space(from_repo_id).repo_id assert to_repo_id == f"{USER}/{from_repo_name}" assert self._api.list_repo_files(repo_id=from_repo_id, repo_type="space") == [ ".gitattributes", "README.md", "index.html", "style.css", "temp/new_file.md", ] assert self._api.list_repo_files(repo_id=to_repo_id, repo_type="space") == self._api.list_repo_files( repo_id=from_repo_id, repo_type="space" ) self._api.delete_repo(repo_id=from_repo_id, repo_type="space", token=OTHER_TOKEN) self._api.delete_repo(repo_id=to_repo_id, repo_type="space") def test_duplicate_space_from_missing_repo(self) -> None: """Check `duplicate_space` fails 
when the from_repo doesn't exist.""" with self.assertRaises(RepositoryNotFoundError): self._api.duplicate_space(f"{OTHER_USER}/repo_that_does_not_exist") class CollectionAPITest(HfApiCommonTest): def setUp(self) -> None: id = uuid.uuid4() self.title = f"My cool stuff {id}" self.slug_prefix = f"{USER}/my-cool-stuff-{id}" self.slug: Optional[str] = None # Populated by the tests => use to delete in tearDown return super().setUp() def tearDown(self) -> None: if self.slug is not None: # Delete collection even if test failed self._api.delete_collection(self.slug, missing_ok=True) return super().tearDown() @with_production_testing def test_list_collections(self) -> None: item_id = "teknium/OpenHermes-2.5-Mistral-7B" item_type = "model" limit = 3 collections = HfApi().list_collections(item=f"{item_type}s/{item_id}", limit=limit) # Check return type self.assertIsInstance(collections, Iterable) collections = list(collections) # Check length self.assertEqual(len(collections), limit) # Check all collections contain the item for collection in collections: # all items are not necessarily returned when listing collections => retrieve complete one full_collection = HfApi().get_collection(collection.slug) assert any(item.item_id == item_id and item.item_type == item_type for item in full_collection.items) def test_create_collection_with_description(self) -> None: collection = self._api.create_collection(self.title, description="Contains a lot of cool stuff") self.slug = collection.slug self.assertIsInstance(collection, Collection) self.assertEqual(collection.title, self.title) self.assertEqual(collection.description, "Contains a lot of cool stuff") self.assertEqual(collection.items, []) assert collection.slug.startswith(self.slug_prefix) self.assertEqual(collection.url, f"{ENDPOINT_STAGING}/collections/{collection.slug}") @pytest.mark.skip("Creating duplicated collections works on staging") def test_create_collection_exists_ok(self) -> None: # Create collection once without description collection_1 = self._api.create_collection(self.title) self.slug = collection_1.slug # Cannot create twice with same title with self.assertRaises(HTTPError): # already exists self._api.create_collection(self.title) # Can ignore error collection_2 = self._api.create_collection(self.title, description="description", exists_ok=True) self.assertEqual(collection_1.slug, collection_2.slug) self.assertIsNone(collection_1.description) self.assertIsNone(collection_2.description) # Did not get updated!
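# Pattern illustrated by the test above (sketch): `create_collection(title, exists_ok=True)` returns
# the already-existing collection without updating its metadata (e.g. the description stays None).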
def test_create_private_collection(self) -> None: collection = self._api.create_collection(self.title, private=True) self.slug = collection.slug # Get private collection self._api.get_collection(collection.slug) # no error with self.assertRaises(HTTPError): self._api.get_collection(collection.slug, token=OTHER_TOKEN) # not authorized # Get public collection self._api.update_collection_metadata(collection.slug, private=False) self._api.get_collection(collection.slug) # no error self._api.get_collection(collection.slug, token=OTHER_TOKEN) # no error def test_update_collection(self) -> None: # Create collection collection_1 = self._api.create_collection(self.title) self.slug = collection_1.slug # Update metadata new_title = f"New title {uuid.uuid4()}" collection_2 = self._api.update_collection_metadata( collection_slug=collection_1.slug, title=new_title, description="New description", private=True, theme="pink", ) self.assertEqual(collection_2.title, new_title) self.assertEqual(collection_2.description, "New description") self.assertEqual(collection_2.private, True) self.assertEqual(collection_2.theme, "pink") self.assertNotEqual(collection_1.slug, collection_2.slug) # Different slug, same id self.assertEqual(collection_1.slug.split("-")[-1], collection_2.slug.split("-")[-1]) # Works with both slugs, same collection returned self.assertEqual(self._api.get_collection(collection_1.slug).slug, collection_2.slug) self.assertEqual(self._api.get_collection(collection_2.slug).slug, collection_2.slug) def test_delete_collection(self) -> None: collection = self._api.create_collection(self.title) self._api.delete_collection(collection.slug) # Cannot delete twice the same collection with self.assertRaises(HTTPError): # already exists self._api.delete_collection(collection.slug) # Possible to ignore error self._api.delete_collection(collection.slug, missing_ok=True) def test_collection_items(self) -> None: # Create some repos model_id = self._api.create_repo(repo_name()).repo_id dataset_id = self._api.create_repo(repo_name(), repo_type="dataset").repo_id # Create collection + add items to it collection = self._api.create_collection(self.title) self._api.add_collection_item(collection.slug, model_id, "model", note="This is my model") self._api.add_collection_item(collection.slug, dataset_id, "dataset") # note is optional # Check consistency collection = self._api.get_collection(collection.slug) self.assertEqual(len(collection.items), 2) self.assertEqual(collection.items[0].item_id, model_id) self.assertEqual(collection.items[0].item_type, "model") self.assertEqual(collection.items[0].note, "This is my model") self.assertEqual(collection.items[1].item_id, dataset_id) self.assertEqual(collection.items[1].item_type, "dataset") self.assertIsNone(collection.items[1].note) # Add existing item fails (except if ignore error) with self.assertRaises(HTTPError): self._api.add_collection_item(collection.slug, model_id, "model") self._api.add_collection_item(collection.slug, model_id, "model", exists_ok=True) # Add inexistent item fails with self.assertRaises(HTTPError): self._api.add_collection_item(collection.slug, model_id, "dataset") # Update first item self._api.update_collection_item( collection.slug, collection.items[0].item_object_id, note="New note", position=1 ) # Check consistency collection = self._api.get_collection(collection.slug) self.assertEqual(collection.items[0].item_id, dataset_id) # position got updated self.assertEqual(collection.items[1].item_id, model_id) 
self.assertEqual(collection.items[1].note, "New note") # note got updated # Delete last item self._api.delete_collection_item(collection.slug, collection.items[1].item_object_id) self._api.delete_collection_item(collection.slug, collection.items[1].item_object_id, missing_ok=True) # Check consistency collection = self._api.get_collection(collection.slug) self.assertEqual(len(collection.items), 1) # only 1 item remaining self.assertEqual(collection.items[0].item_id, dataset_id) # position got updated # Delete everything self._api.delete_repo(model_id) self._api.delete_repo(dataset_id, repo_type="dataset") self._api.delete_collection(collection.slug) class AccessRequestAPITest(HfApiCommonTest): def setUp(self) -> None: # Setup test with a gated repo super().setUp() self.repo_id = self._api.create_repo(repo_name()).repo_id response = get_session().put( f"{self._api.endpoint}/api/models/{self.repo_id}/settings", json={"gated": "auto"}, headers=self._api._build_hf_headers(), ) hf_raise_for_status(response) def tearDown(self) -> None: self._api.delete_repo(self.repo_id) return super().tearDown() def test_access_requests_normal_usage(self) -> None: # No access requests initially requests = self._api.list_accepted_access_requests(self.repo_id) assert len(requests) == 0 requests = self._api.list_pending_access_requests(self.repo_id) assert len(requests) == 0 requests = self._api.list_rejected_access_requests(self.repo_id) assert len(requests) == 0 # Grant access to a user self._api.grant_access(self.repo_id, OTHER_USER) # User is in accepted list requests = self._api.list_accepted_access_requests(self.repo_id) assert len(requests) == 1 request = requests[0] assert isinstance(request, AccessRequest) assert request.username == OTHER_USER assert request.email is None # email not shared when granted access manually assert request.status == "accepted" assert isinstance(request.timestamp, datetime.datetime) # Cancel access self._api.cancel_access_request(self.repo_id, OTHER_USER) requests = self._api.list_accepted_access_requests(self.repo_id) assert len(requests) == 0 # not accepted anymore requests = self._api.list_pending_access_requests(self.repo_id) assert len(requests) == 1 assert requests[0].username == OTHER_USER # Reject access self._api.reject_access_request(self.repo_id, OTHER_USER, rejection_reason="This is a rejection reason") requests = self._api.list_pending_access_requests(self.repo_id) assert len(requests) == 0 # not pending anymore requests = self._api.list_rejected_access_requests(self.repo_id) assert len(requests) == 1 assert requests[0].username == OTHER_USER # Accept again self._api.accept_access_request(self.repo_id, OTHER_USER) requests = self._api.list_accepted_access_requests(self.repo_id) assert len(requests) == 1 assert requests[0].username == OTHER_USER def test_access_request_error(self): # Grant access to a user self._api.grant_access(self.repo_id, OTHER_USER) # Cannot grant twice with self.assertRaises(HTTPError): self._api.grant_access(self.repo_id, OTHER_USER) # Cannot accept to already accepted with self.assertRaises(HTTPError): self._api.accept_access_request(self.repo_id, OTHER_USER) # Cannot reject to already rejected self._api.reject_access_request(self.repo_id, OTHER_USER, rejection_reason="This is a rejection reason") with self.assertRaises(HTTPError): self._api.reject_access_request(self.repo_id, OTHER_USER, rejection_reason="This is a rejection reason") # Cannot cancel to already cancelled self._api.cancel_access_request(self.repo_id, OTHER_USER) with 
self.assertRaises(HTTPError): self._api.cancel_access_request(self.repo_id, OTHER_USER) @with_production_testing class UserApiTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.api = HfApi() # no auth! def test_user_overview(self) -> None: overview = self.api.get_user_overview("julien-c") assert overview.user_type == "user" assert overview.username == "julien-c" assert overview.num_likes > 10 assert overview.num_upvotes > 10 assert len(overview.orgs) > 0 assert any(org.name == "huggingface" for org in overview.orgs) assert overview.num_following > 300 assert overview.num_followers > 1000 def test_organization_members(self) -> None: members = self.api.list_organization_members("huggingface") assert len(list(members)) > 1 def test_user_followers(self) -> None: followers = self.api.list_user_followers("clem") assert len(list(followers)) > 500 def test_user_following(self) -> None: following = self.api.list_user_following("clem") assert len(list(following)) > 500 class PaperApiTest(unittest.TestCase): @classmethod @with_production_testing def setUpClass(cls) -> None: cls.api = HfApi() return super().setUpClass() def test_papers_by_query(self) -> None: papers = list(self.api.list_papers(query="llama")) assert len(papers) > 0 assert "The Llama 3 Herd of Models" in [paper.title for paper in papers] def test_get_paper_by_id_success(self) -> None: paper = self.api.paper_info("2407.21783") assert paper.title == "The Llama 3 Herd of Models" def test_get_paper_by_id_not_found(self) -> None: with self.assertRaises(HfHubHTTPError) as context: self.api.paper_info("1234.56789") assert context.exception.response.status_code == 404 class WebhookApiTest(HfApiCommonTest): def setUp(self) -> None: super().setUp() self.webhook_url = "https://webhook.site/test" self.watched_items = [ WebhookWatchedItem(type="user", name="julien-c"), # can be either a dataclass {"type": "org", "name": "HuggingFaceH4"}, # or a simple dictionary ] self.domains = ["repo", "discussion"] self.secret = "my-secret" # Create a webhook to be used in the tests self.webhook = self._api.create_webhook( url=self.webhook_url, watched=self.watched_items, domains=self.domains, secret=self.secret ) def tearDown(self) -> None: # Clean up the created webhook self._api.delete_webhook(self.webhook.id) super().tearDown() def test_get_webhook(self) -> None: webhook = self._api.get_webhook(self.webhook.id) self.assertIsInstance(webhook, WebhookInfo) self.assertEqual(webhook.id, self.webhook.id) self.assertEqual(webhook.url, self.webhook_url) def test_list_webhooks(self) -> None: webhooks = self._api.list_webhooks() assert any(webhook.id == self.webhook.id for webhook in webhooks) def test_create_webhook(self) -> None: new_webhook = self._api.create_webhook( url=self.webhook_url, watched=self.watched_items, domains=self.domains, secret=self.secret ) self.assertIsInstance(new_webhook, WebhookInfo) self.assertEqual(new_webhook.url, self.webhook_url) # Clean up the newly created webhook self._api.delete_webhook(new_webhook.id) def test_update_webhook(self) -> None: updated_url = "https://webhook.site/new" updated_webhook = self._api.update_webhook( self.webhook.id, url=updated_url, watched=self.watched_items, domains=self.domains, secret=self.secret ) self.assertEqual(updated_webhook.url, updated_url) def test_enable_webhook(self) -> None: enabled_webhook = self._api.enable_webhook(self.webhook.id) self.assertFalse(enabled_webhook.disabled) def test_disable_webhook(self) -> None: disabled_webhook = 
self._api.disable_webhook(self.webhook.id) assert disabled_webhook.disabled def test_delete_webhook(self) -> None: # Create another webhook to test deletion webhook_to_delete = self._api.create_webhook( url=self.webhook_url, watched=self.watched_items, domains=self.domains, secret=self.secret ) self._api.delete_webhook(webhook_to_delete.id) with self.assertRaises(HTTPError): self._api.get_webhook(webhook_to_delete.id) class TestExpandPropertyType(HfApiCommonTest): @use_tmp_repo(repo_type="model") def test_expand_model_property_type_is_up_to_date(self, repo_url: RepoUrl): self._check_expand_property_is_up_to_date(repo_url) @use_tmp_repo(repo_type="dataset") def test_expand_dataset_property_type_is_up_to_date(self, repo_url: RepoUrl): self._check_expand_property_is_up_to_date(repo_url) @use_tmp_repo(repo_type="space") def test_expand_space_property_type_is_up_to_date(self, repo_url: RepoUrl): self._check_expand_property_is_up_to_date(repo_url) def _check_expand_property_is_up_to_date(self, repo_url: RepoUrl): repo_id = repo_url.repo_id repo_type = repo_url.repo_type property_type = ( ExpandModelProperty_T if repo_type == "model" else (ExpandDatasetProperty_T if repo_type == "dataset" else ExpandSpaceProperty_T) ) property_type_name = ( "ExpandModelProperty_T" if repo_type == "model" else ("ExpandDatasetProperty_T" if repo_type == "dataset" else "ExpandSpaceProperty_T") ) try: self._api.repo_info(repo_id=repo_id, repo_type=repo_type, expand=["does_not_exist"]) raise Exception("Should have raised an exception") except HfHubHTTPError as e: assert e.response.status_code == 400 message = e.response.json()["error"] assert message.startswith('"expand" must be one of ') defined_args = set(get_args(property_type)) expected_args = set(message.replace('"expand" must be one of ', "").strip("[]").split(", ")) expected_args.discard("gitalyUid") # internal one, do not document if defined_args != expected_args: should_be_removed = defined_args - expected_args should_be_added = expected_args - defined_args msg = f"Literal `{property_type_name}` is outdated! This is probably due to a server-side update." if should_be_removed: msg += f"\nArg(s) not supported anymore: {', '.join(should_be_removed)}" if should_be_added: msg += f"\nNew arg(s) to support: {', '.join(should_be_added)}" msg += f"\nPlease open a PR to update `./src/huggingface_hub/hf_api.py` accordingly. `{property_type_name}` should be updated as well as `{repo_type}_info` and `list_{repo_type}s` docstrings." msg += "\nThank you in advance!" 
raise ValueError(msg) class TestLargeUpload(HfApiCommonTest): @use_tmp_repo(repo_type="dataset") def test_upload_large_folder(self, repo_url: RepoUrl) -> None: N_FILES_PER_FOLDER = 4 with SoftTemporaryDirectory() as tmpdir: folder = Path(tmpdir) / "large_folder" # Create 16 LFS files + 16 regular files for i in range(N_FILES_PER_FOLDER): subfolder = folder / f"subfolder_{i}" subfolder.mkdir(parents=True, exist_ok=True) for j in range(N_FILES_PER_FOLDER): (subfolder / f"file_lfs_{i}_{j}.bin").write_bytes(f"content_lfs_{i}_{j}".encode()) (subfolder / f"file_regular_{i}_{j}.txt").write_bytes(f"content_regular_{i}_{j}".encode()) # Upload the folder self._api.upload_large_folder( repo_id=repo_url.repo_id, repo_type=repo_url.repo_type, folder_path=folder, num_workers=4 ) # Check all files have been uploaded uploaded_files = self._api.list_repo_files(repo_url.repo_id, repo_type=repo_url.repo_type) for i in range(N_FILES_PER_FOLDER): for j in range(N_FILES_PER_FOLDER): assert f"subfolder_{i}/file_lfs_{i}_{j}.bin" in uploaded_files assert f"subfolder_{i}/file_regular_{i}_{j}.txt" in uploaded_files class TestHfApiAuthCheck(HfApiCommonTest): @use_tmp_repo(repo_type="dataset") def test_auth_check_success(self, repo_url: RepoUrl) -> None: self._api.auth_check(repo_id=repo_url.repo_id, repo_type=repo_url.repo_type) def test_auth_check_repo_missing(self) -> None: with self.assertRaises(RepositoryNotFoundError): self._api.auth_check(repo_id="username/missing_repo_id") def test_auth_check_gated_repo(self) -> None: repo_id = self._api.create_repo(repo_name()).repo_id response = get_session().put( f"{self._api.endpoint}/api/models/{repo_id}/settings", json={"gated": "auto"}, headers=self._api._build_hf_headers(token=TOKEN), ) hf_raise_for_status(response) with self.assertRaises(GatedRepoError): self._api.auth_check(repo_id=repo_id, token=OTHER_TOKEN) class HfApiInferenceCatalogTest(HfApiCommonTest): def test_list_inference_catalog(self) -> None: models = self._api.list_inference_catalog() # note: @experimental api # Check that server returns a list[str] => at least if it changes in the future, we'll notice assert isinstance(models, List) assert len(models) > 0 assert all(isinstance(model, str) for model in models) @patch("huggingface_hub.hf_api.get_session") def test_create_inference_endpoint_from_catalog(self, mock_get_session: Mock) -> None: mock_response = Mock() mock_response.json.return_value = { "endpoint": { "compute": { "accelerator": "gpu", "id": "aws-us-east-1-nvidia-l4-x1", "instanceSize": "x1", "instanceType": "nvidia-l4", "scaling": { "maxReplica": 1, "measure": {"hardwareUsage": None}, "metric": "hardwareUsage", "minReplica": 0, "scaleToZeroTimeout": 15, }, }, "model": { "env": {}, "framework": "pytorch", "image": { "tgi": { "disableCustomKernels": False, "healthRoute": "/health", "port": 80, "url": "ghcr.io/huggingface/text-generation-inference:3.1.1", } }, "repository": "meta-llama/Llama-3.2-3B-Instruct", "revision": "0cb88a4f764b7a12671c53f0838cd831a0843b95", "secrets": {}, "task": "text-generation", }, "name": "llama-3-2-3b-instruct-eey", "provider": {"region": "us-east-1", "vendor": "aws"}, "status": { "createdAt": "2025-03-07T15:30:13.949Z", "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, "message": "Endpoint waiting to be scheduled", "readyReplica": 0, "state": "pending", "targetReplica": 1, "updatedAt": "2025-03-07T15:30:13.949Z", "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, }, "type": "protected", } } 
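# The payload above is a representative (trimmed) Inference Endpoints API response; the client is
# expected to parse it into an `InferenceEndpoint` object, as asserted at the end of this test.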
mock_get_session.return_value.post.return_value = mock_response endpoint = self._api.create_inference_endpoint_from_catalog( repo_id="meta-llama/Llama-3.2-3B-Instruct", namespace="Wauplin" ) assert isinstance(endpoint, InferenceEndpoint) assert endpoint.name == "llama-3-2-3b-instruct-eey" huggingface_hub-0.31.1/tests/test_hf_file_system.py000066400000000000000000000647121500667546600225050ustar00rootroot00000000000000import copy import datetime import io import os import tempfile import unittest from pathlib import Path from typing import Optional from unittest.mock import patch import fsspec import pytest from huggingface_hub import hf_file_system from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.hf_file_system import HfFileSystem, HfFileSystemFile, HfFileSystemStreamFile from .testing_constants import ENDPOINT_STAGING, TOKEN from .testing_utils import repo_name class HfFileSystemTests(unittest.TestCase): @classmethod def setUpClass(cls): """Register `HfFileSystem` as a `fsspec` filesystem if not already registered.""" if HfFileSystem.protocol not in fsspec.available_protocols(): fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem) def setUp(self): self.hffs = HfFileSystem(endpoint=ENDPOINT_STAGING, token=TOKEN, skip_instance_cache=True) self.api = self.hffs._api # Create dummy repo repo_url = self.api.create_repo(repo_name(), repo_type="dataset") self.repo_id = repo_url.repo_id self.hf_path = f"datasets/{self.repo_id}" # Upload files self.api.upload_file( path_or_fileobj=b"dummy binary data on pr", path_in_repo="data/binary_data_for_pr.bin", repo_id=self.repo_id, repo_type="dataset", create_pr=True, ) self.api.upload_file( path_or_fileobj="dummy text data".encode("utf-8"), path_in_repo="data/text_data.txt", repo_id=self.repo_id, repo_type="dataset", ) self.api.upload_file( path_or_fileobj=b"dummy binary data", path_in_repo="data/binary_data.bin", repo_id=self.repo_id, repo_type="dataset", ) self.text_file = self.hf_path + "/data/text_data.txt" def tearDown(self): self.api.delete_repo(self.repo_id, repo_type="dataset") def test_info(self): root_dir = self.hffs.info(self.hf_path) self.assertEqual(root_dir["type"], "directory") self.assertEqual(root_dir["size"], 0) self.assertTrue(root_dir["name"].endswith(self.repo_id)) self.assertIsNotNone(root_dir["last_commit"]) data_dir = self.hffs.info(self.hf_path + "/data") self.assertEqual(data_dir["type"], "directory") self.assertEqual(data_dir["size"], 0) self.assertTrue(data_dir["name"].endswith("/data")) self.assertIsNotNone(data_dir["last_commit"]) self.assertIsNotNone(data_dir["tree_id"]) text_data_file = self.hffs.info(self.text_file) self.assertEqual(text_data_file["type"], "file") self.assertGreater(text_data_file["size"], 0) # not empty self.assertTrue(text_data_file["name"].endswith("/data/text_data.txt")) self.assertIsNone(text_data_file["lfs"]) self.assertIsNotNone(text_data_file["last_commit"]) self.assertIsNotNone(text_data_file["blob_id"]) self.assertIn("security", text_data_file) # the staging endpoint does not run security checks # cached info self.assertEqual(self.hffs.info(self.text_file), text_data_file) def test_glob(self): self.assertEqual( self.hffs.glob(self.hf_path + "/.gitattributes"), [self.hf_path + "/.gitattributes"], ) self.assertEqual( sorted(self.hffs.glob(self.hf_path + "/*")), sorted([self.hf_path + "/.gitattributes", self.hf_path + "/data"]), ) self.assertEqual( sorted(self.hffs.glob(self.hf_path + "/*", revision="main")), sorted([self.hf_path + 
"/.gitattributes", self.hf_path + "/data"]), ) self.assertEqual( sorted(self.hffs.glob(self.hf_path + "@main" + "/*")), sorted([self.hf_path + "@main" + "/.gitattributes", self.hf_path + "@main" + "/data"]), ) self.assertEqual( self.hffs.glob(self.hf_path + "@refs%2Fpr%2F1" + "/data/*"), [self.hf_path + "@refs%2Fpr%2F1" + "/data/binary_data_for_pr.bin"], ) self.assertEqual( self.hffs.glob(self.hf_path + "@refs/pr/1" + "/data/*"), [self.hf_path + "@refs/pr/1" + "/data/binary_data_for_pr.bin"], ) self.assertEqual( self.hffs.glob(self.hf_path + "/data/*", revision="refs/pr/1"), [self.hf_path + "@refs/pr/1" + "/data/binary_data_for_pr.bin"], ) self.assertIsNone( self.hffs.dircache[self.hf_path + "@main"][0]["last_commit"] ) # no detail -> no last_commit in cache files = self.hffs.glob(self.hf_path + "@main" + "/*", detail=True, expand_info=False) self.assertIsInstance(files, dict) self.assertEqual(len(files), 2) keys = sorted(files) self.assertTrue( files[keys[0]]["name"].endswith("/.gitattributes") and files[keys[1]]["name"].endswith("/data") ) self.assertIsNone( self.hffs.dircache[self.hf_path + "@main"][0]["last_commit"] ) # detail but no expand info -> no last_commit in cache files = self.hffs.glob(self.hf_path + "@main" + "/*", detail=True) self.assertIsInstance(files, dict) self.assertEqual(len(files), 2) keys = sorted(files) self.assertTrue( files[keys[0]]["name"].endswith("/.gitattributes") and files[keys[1]]["name"].endswith("/data") ) self.assertIsNotNone(files[keys[0]]["last_commit"]) def test_url(self): self.assertEqual( self.hffs.url(self.text_file), f"{ENDPOINT_STAGING}/datasets/{self.repo_id}/resolve/main/data/text_data.txt", ) self.assertEqual( self.hffs.url(self.hf_path + "/data"), f"{ENDPOINT_STAGING}/datasets/{self.repo_id}/tree/main/data", ) def test_file_type(self): self.assertTrue( self.hffs.isdir(self.hf_path + "/data") and not self.hffs.isdir(self.hf_path + "/.gitattributes") ) self.assertTrue(self.hffs.isfile(self.text_file) and not self.hffs.isfile(self.hf_path + "/data")) def test_remove_file(self): self.hffs.rm_file(self.text_file) self.assertEqual(self.hffs.glob(self.hf_path + "/data/*"), [self.hf_path + "/data/binary_data.bin"]) self.hffs.rm_file(self.hf_path + "@refs/pr/1" + "/data/binary_data_for_pr.bin") self.assertEqual(self.hffs.glob(self.hf_path + "@refs/pr/1" + "/data/*"), []) def test_remove_directory(self): self.hffs.rm(self.hf_path + "/data", recursive=True) self.assertNotIn(self.hf_path + "/data", self.hffs.ls(self.hf_path)) self.hffs.rm(self.hf_path + "@refs/pr/1" + "/data", recursive=True) self.assertNotIn(self.hf_path + "@refs/pr/1" + "/data", self.hffs.ls(self.hf_path)) def test_read_file(self): with self.hffs.open(self.text_file, "r") as f: self.assertIsInstance(f, io.TextIOWrapper) self.assertIsInstance(f.buffer, HfFileSystemFile) self.assertEqual(f.read(), "dummy text data") def test_stream_file(self): with self.hffs.open(self.hf_path + "/data/binary_data.bin", block_size=0) as f: self.assertIsInstance(f, HfFileSystemStreamFile) self.assertEqual(f.read(), b"dummy binary data") def test_stream_file_retry(self): with self.hffs.open(self.hf_path + "/data/binary_data.bin", block_size=0) as f: self.assertIsInstance(f, HfFileSystemStreamFile) self.assertEqual(f.read(6), b"dummy ") # Simulate that streaming fails mid-way f.response.raw.read = None self.assertEqual(f.read(6), b"binary") self.assertIsNotNone(f.response.raw.read) # a new connection has been created def test_read_file_with_revision(self): with self.hffs.open(self.hf_path + 
"/data/binary_data_for_pr.bin", "rb", revision="refs/pr/1") as f: self.assertEqual(f.read(), b"dummy binary data on pr") def test_write_file(self): data = "new text data" with self.hffs.open(self.hf_path + "/data/new_text_data.txt", "w") as f: f.write(data) self.assertIn(self.hf_path + "/data/new_text_data.txt", self.hffs.glob(self.hf_path + "/data/*")) with self.hffs.open(self.hf_path + "/data/new_text_data.txt", "r") as f: self.assertEqual(f.read(), data) def test_write_file_multiple_chunks(self): data = "a" * (4 << 20) # 4MB with self.hffs.open(self.hf_path + "/data/new_text_data_big.txt", "w") as f: for _ in range(8): # 32MB in total f.write(data) self.assertIn(self.hf_path + "/data/new_text_data_big.txt", self.hffs.glob(self.hf_path + "/data/*")) with self.hffs.open(self.hf_path + "/data/new_text_data_big.txt", "r") as f: for _ in range(8): self.assertEqual(f.read(len(data)), data) @unittest.skip("Not implemented yet") def test_append_file(self): with self.hffs.open(self.text_file, "a") as f: f.write(" appended text") with self.hffs.open(self.text_file, "r") as f: self.assertEqual(f.read(), "dummy text data appended text") def test_copy_file(self): # Non-LFS file self.assertIsNone(self.hffs.info(self.text_file)["lfs"]) self.hffs.cp_file(self.text_file, self.hf_path + "/data/text_data_copy.txt") with self.hffs.open(self.hf_path + "/data/text_data_copy.txt", "r") as f: self.assertEqual(f.read(), "dummy text data") self.assertIsNone(self.hffs.info(self.hf_path + "/data/text_data_copy.txt")["lfs"]) # LFS file self.assertIsNotNone(self.hffs.info(self.hf_path + "/data/binary_data.bin")["lfs"]) self.hffs.cp_file(self.hf_path + "/data/binary_data.bin", self.hf_path + "/data/binary_data_copy.bin") with self.hffs.open(self.hf_path + "/data/binary_data_copy.bin", "rb") as f: self.assertEqual(f.read(), b"dummy binary data") self.assertIsNotNone(self.hffs.info(self.hf_path + "/data/binary_data_copy.bin")["lfs"]) def test_modified_time(self): self.assertIsInstance(self.hffs.modified(self.text_file), datetime.datetime) self.assertIsInstance(self.hffs.modified(self.hf_path + "/data"), datetime.datetime) # should fail on a non-existing file with self.assertRaises(FileNotFoundError): self.hffs.modified(self.hf_path + "/data/not_existing_file.txt") def test_open_if_not_found(self): # Regression test: opening a missing file should raise a FileNotFoundError. This was not the case before when # opening a file in read mode. 
with self.assertRaises(FileNotFoundError): self.hffs.open("hf://missing/repo/not_existing_file.txt", mode="r") with self.assertRaises(FileNotFoundError): self.hffs.open("hf://missing/repo/not_existing_file.txt", mode="w") def test_initialize_from_fsspec(self): fs, _, paths = fsspec.get_fs_token_paths( f"hf://datasets/{self.repo_id}/data/text_data.txt", storage_options={ "endpoint": ENDPOINT_STAGING, "token": TOKEN, }, ) self.assertIsInstance(fs, HfFileSystem) self.assertEqual(fs._api.endpoint, ENDPOINT_STAGING) self.assertEqual(fs.token, TOKEN) self.assertEqual(paths, [self.text_file]) fs, _, paths = fsspec.get_fs_token_paths(f"hf://{self.repo_id}/data/text_data.txt") self.assertIsInstance(fs, HfFileSystem) self.assertEqual(paths, [f"{self.repo_id}/data/text_data.txt"]) def test_list_root_directory_no_revision(self): files = self.hffs.ls(self.hf_path) self.assertEqual(len(files), 2) self.assertEqual(files[0]["type"], "directory") self.assertEqual(files[0]["size"], 0) self.assertTrue(files[0]["name"].endswith("/data")) self.assertIsNotNone(files[0]["last_commit"]) self.assertIsNotNone(files[0]["tree_id"]) self.assertEqual(files[1]["type"], "file") self.assertGreater(files[1]["size"], 0) # not empty self.assertTrue(files[1]["name"].endswith("/.gitattributes")) self.assertIsNotNone(files[1]["last_commit"]) self.assertIsNotNone(files[1]["blob_id"]) self.assertIn("security", files[1]) # the staging endpoint does not run security checks def test_list_data_directory_no_revision(self): files = self.hffs.ls(self.hf_path + "/data") self.assertEqual(len(files), 2) self.assertEqual(files[0]["type"], "file") self.assertGreater(files[0]["size"], 0) # not empty self.assertTrue(files[0]["name"].endswith("/data/binary_data.bin")) self.assertIsNotNone(files[0]["lfs"]) self.assertIn("sha256", files[0]["lfs"]) self.assertIn("size", files[0]["lfs"]) self.assertIn("pointer_size", files[0]["lfs"]) self.assertIsNotNone(files[0]["last_commit"]) self.assertIsNotNone(files[0]["blob_id"]) self.assertIn("security", files[0]) # the staging endpoint does not run security checks self.assertEqual(files[1]["type"], "file") self.assertGreater(files[1]["size"], 0) # not empty self.assertTrue(files[1]["name"].endswith("/data/text_data.txt")) self.assertIsNone(files[1]["lfs"]) self.assertIsNotNone(files[1]["last_commit"]) self.assertIsNotNone(files[1]["blob_id"]) self.assertIn("security", files[1]) # the staging endpoint does not run security checks def test_list_data_file_no_revision(self): files = self.hffs.ls(self.text_file) self.assertEqual(len(files), 1) self.assertEqual(files[0]["type"], "file") self.assertGreater(files[0]["size"], 0) # not empty self.assertTrue(files[0]["name"].endswith("/data/text_data.txt")) self.assertIsNone(files[0]["lfs"]) self.assertIsNotNone(files[0]["last_commit"]) self.assertIsNotNone(files[0]["blob_id"]) self.assertIn("security", files[0]) # the staging endpoint does not run security checks def test_list_data_directory_with_revision(self): files = self.hffs.ls(self.hf_path + "@refs%2Fpr%2F1" + "/data") for test_name, files in { "quoted_rev_in_path": self.hffs.ls(self.hf_path + "@refs%2Fpr%2F1" + "/data"), "rev_in_path": self.hffs.ls(self.hf_path + "@refs/pr/1" + "/data"), "rev_as_arg": self.hffs.ls(self.hf_path + "/data", revision="refs/pr/1"), "quoted_rev_in_path_and_rev_as_arg": self.hffs.ls( self.hf_path + "@refs%2Fpr%2F1" + "/data", revision="refs/pr/1" ), }.items(): with self.subTest(test_name): self.assertEqual(len(files), 1) # only one file in PR self.assertEqual(files[0]["type"], 
"file") self.assertTrue(files[0]["name"].endswith("/data/binary_data_for_pr.bin")) # PR file if "quoted_rev_in_path" in test_name: self.assertIn("@refs%2Fpr%2F1", files[0]["name"]) elif "rev_in_path" in test_name: self.assertIn("@refs/pr/1", files[0]["name"]) def test_list_root_directory_no_revision_no_detail_then_with_detail(self): files = self.hffs.ls(self.hf_path, detail=False) self.assertEqual(len(files), 2) self.assertTrue(files[0].endswith("/data") and files[1].endswith("/.gitattributes")) self.assertIsNone(self.hffs.dircache[self.hf_path][0]["last_commit"]) # no detail -> no last_commit in cache files = self.hffs.ls(self.hf_path, detail=True) self.assertEqual(len(files), 2) self.assertTrue(files[0]["name"].endswith("/data") and files[1]["name"].endswith("/.gitattributes")) self.assertIsNotNone(self.hffs.dircache[self.hf_path][0]["last_commit"]) def test_find_root_directory_no_revision(self): files = self.hffs.find(self.hf_path, detail=False) self.assertEqual( files, self.hffs.ls(self.hf_path, detail=False)[1:] + self.hffs.ls(self.hf_path + "/data", detail=False) ) files = self.hffs.find(self.hf_path, detail=True) self.assertEqual( files, { f["name"]: f for f in self.hffs.ls(self.hf_path, detail=True)[1:] + self.hffs.ls(self.hf_path + "/data", detail=True) }, ) files_with_dirs = self.hffs.find(self.hf_path, withdirs=True, detail=False) self.assertEqual( files_with_dirs, sorted( [self.hf_path] + self.hffs.ls(self.hf_path, detail=False) + self.hffs.ls(self.hf_path + "/data", detail=False) ), ) def test_find_root_directory_no_revision_with_incomplete_cache(self): self.api.upload_file( path_or_fileobj=b"dummy text data 2", path_in_repo="data/sub_data/text_data2.txt", repo_id=self.repo_id, repo_type="dataset", ) self.api.upload_file( path_or_fileobj=b"dummy binary data 2", path_in_repo="data1/binary_data2.bin", repo_id=self.repo_id, repo_type="dataset", ) # Copy the result to make it robust to the cache modifications # See discussion in https://github.com/huggingface/huggingface_hub/pull/2103 # for info on why this is not done in `HfFileSystem.find` by default files = copy.deepcopy(self.hffs.find(self.hf_path, detail=True)) # some directories not in cache self.hffs.dircache.pop(self.hf_path + "/data/sub_data") # some files not expanded self.hffs.dircache[self.hf_path + "/data"][1]["last_commit"] = None out = self.hffs.find(self.hf_path, detail=True) self.assertEqual(out, files) def test_find_data_file_no_revision(self): files = self.hffs.find(self.text_file, detail=False) self.assertEqual(files, [self.text_file]) def test_read_bytes(self): data = self.hffs.read_bytes(self.text_file) self.assertEqual(data, b"dummy text data") def test_read_text(self): data = self.hffs.read_text(self.text_file) self.assertEqual(data, "dummy text data") def test_open_and_read(self): with self.hffs.open(self.text_file, "r") as f: self.assertEqual(f.read(), "dummy text data") def test_partial_read(self): # If partial read => should not download whole file with patch.object(self.hffs, "get_file") as mock: with self.hffs.open(self.text_file, "r") as f: self.assertEqual(f.read(5), "dummy") mock.assert_not_called() def test_get_file_with_temporary_file(self): # Test passing a file object works => happens "in-memory" for posix systems with tempfile.TemporaryFile() as temp_file: self.hffs.get_file(self.text_file, temp_file) temp_file.seek(0) assert temp_file.read() == b"dummy text data" def test_get_file_with_temporary_folder(self): # Test passing a file path works => compatible with hf_transfer with 
tempfile.TemporaryDirectory() as temp_dir: temp_file = os.path.join(temp_dir, "temp_file.txt") self.hffs.get_file(self.text_file, temp_file) with open(temp_file, "rb") as f: assert f.read() == b"dummy text data" def test_get_file_with_kwargs(self): # If custom kwargs are passed, the function should still work but defaults to base implementation with patch.object(hf_file_system, "http_get") as mock: with tempfile.TemporaryDirectory() as temp_dir: temp_file = os.path.join(temp_dir, "temp_file.txt") self.hffs.get_file(self.text_file, temp_file, custom_kwarg=123) mock.assert_not_called() with tempfile.TemporaryDirectory() as temp_dir: temp_file = os.path.join(temp_dir, "temp_file.txt") self.hffs.get_file(self.text_file, temp_file) mock.assert_called_once() def test_get_file_on_folder(self): # Test it works with custom kwargs with tempfile.TemporaryDirectory() as temp_dir: assert not (Path(temp_dir) / "data").exists() self.hffs.get_file(self.hf_path + "/data", temp_dir + "/data") assert (Path(temp_dir) / "data").exists() @pytest.mark.parametrize("path_in_repo", ["", "file.txt", "path/to/file"]) @pytest.mark.parametrize( "root_path,revision,repo_type,repo_id,resolved_revision", [ # Parse without namespace ("gpt2", None, "model", "gpt2", "main"), ("gpt2", "dev", "model", "gpt2", "dev"), ("gpt2@dev", None, "model", "gpt2", "dev"), ("datasets/squad", None, "dataset", "squad", "main"), ("datasets/squad", "dev", "dataset", "squad", "dev"), ("datasets/squad@dev", None, "dataset", "squad", "dev"), # Parse with namespace ("username/my_model", None, "model", "username/my_model", "main"), ("username/my_model", "dev", "model", "username/my_model", "dev"), ("username/my_model@dev", None, "model", "username/my_model", "dev"), ("datasets/username/my_dataset", None, "dataset", "username/my_dataset", "main"), ("datasets/username/my_dataset", "dev", "dataset", "username/my_dataset", "dev"), ("datasets/username/my_dataset@dev", None, "dataset", "username/my_dataset", "dev"), # Parse with hf:// protocol ("hf://gpt2", None, "model", "gpt2", "main"), ("hf://gpt2", "dev", "model", "gpt2", "dev"), ("hf://gpt2@dev", None, "model", "gpt2", "dev"), ("hf://datasets/squad", None, "dataset", "squad", "main"), ("hf://datasets/squad", "dev", "dataset", "squad", "dev"), ("hf://datasets/squad@dev", None, "dataset", "squad", "dev"), # Parse with `refs/convert/parquet` and `refs/pr/(\d)+` revisions. # Regression tests for https://github.com/huggingface/huggingface_hub/issues/1710. 
("datasets/squad@refs/convert/parquet", None, "dataset", "squad", "refs/convert/parquet"), ( "hf://datasets/username/my_dataset@refs/convert/parquet", None, "dataset", "username/my_dataset", "refs/convert/parquet", ), ("gpt2@refs/pr/2", None, "model", "gpt2", "refs/pr/2"), ("gpt2@refs%2Fpr%2F2", None, "model", "gpt2", "refs/pr/2"), ("hf://username/my_model@refs/pr/10", None, "model", "username/my_model", "refs/pr/10"), ("hf://username/my_model@refs/pr/10", "refs/pr/10", "model", "username/my_model", "refs/pr/10"), ("hf://username/my_model@refs%2Fpr%2F10", "refs/pr/10", "model", "username/my_model", "refs/pr/10"), ], ) def test_resolve_path( root_path: str, revision: Optional[str], repo_type: str, repo_id: str, resolved_revision: str, path_in_repo: str, ): fs = HfFileSystem() path = root_path + "/" + path_in_repo if path_in_repo else root_path with mock_repo_info(fs): resolved_path = fs.resolve_path(path, revision=revision) assert ( resolved_path.repo_type, resolved_path.repo_id, resolved_path.revision, resolved_path.path_in_repo, ) == (repo_type, repo_id, resolved_revision, path_in_repo) if "@" in path: assert resolved_path._raw_revision in path @pytest.mark.parametrize("path_in_repo", ["", "file.txt", "path/to/file"]) @pytest.mark.parametrize( "path,revision,expected_path", [ ("hf://datasets/squad@dev", None, "datasets/squad@dev"), ("datasets/squad@refs/convert/parquet", None, "datasets/squad@refs/convert/parquet"), ("hf://username/my_model@refs/pr/10", None, "username/my_model@refs/pr/10"), ("username/my_model", "refs/weirdo", "username/my_model@refs%2Fweirdo"), # not a "special revision" -> encode ], ) def test_unresolve_path(path: str, revision: Optional[str], expected_path: str, path_in_repo: str) -> None: fs = HfFileSystem() path = path + "/" + path_in_repo if path_in_repo else path expected_path = expected_path + "/" + path_in_repo if path_in_repo else expected_path with mock_repo_info(fs): assert fs.resolve_path(path, revision=revision).unresolve() == expected_path def test_resolve_path_with_refs_revision() -> None: """ Testing a very specific edge case where a user has a repo with a revisions named "refs" and a file/directory named "pr/10". We can still process them but the user has to use the `revision` argument to disambiguate between the two. 
""" fs = HfFileSystem() with mock_repo_info(fs): resolved = fs.resolve_path("hf://username/my_model@refs/pr/10", revision="refs") assert resolved.revision == "refs" assert resolved.path_in_repo == "pr/10" assert resolved.unresolve() == "username/my_model@refs/pr/10" def mock_repo_info(fs: HfFileSystem): def _inner(repo_id: str, *, revision: str, repo_type: str, **kwargs): if repo_id not in ["gpt2", "squad", "username/my_dataset", "username/my_model"]: raise RepositoryNotFoundError(repo_id) if revision is not None and revision not in ["main", "dev", "refs"] and not revision.startswith("refs/"): raise RevisionNotFoundError(revision) return patch.object(fs._api, "repo_info", _inner) def test_resolve_path_with_non_matching_revisions(): fs = HfFileSystem() with pytest.raises(ValueError): fs.resolve_path("gpt2@dev", revision="main") @pytest.mark.parametrize("not_supported_path", ["", "foo", "datasets", "datasets/foo"]) def test_access_repositories_lists(not_supported_path): fs = HfFileSystem() with pytest.raises(NotImplementedError): fs.info(not_supported_path) with pytest.raises(NotImplementedError): fs.ls(not_supported_path) with pytest.raises(NotImplementedError): fs.open(not_supported_path) def test_exists_after_repo_deletion(): """Test that exists() correctly reflects repository deletion.""" # Initialize with staging endpoint and skip cache hffs = HfFileSystem(endpoint=ENDPOINT_STAGING, token=TOKEN, skip_instance_cache=True) api = hffs._api # Create a new repo temp_repo_id = repo_name() repo_url = api.create_repo(temp_repo_id) repo_id = repo_url.repo_id assert hffs.exists(repo_id, refresh=True) # Delete the repo api.delete_repo(repo_id=repo_id, repo_type="model") # Verify that the repo no longer exists. assert not hffs.exists(repo_id, refresh=True) huggingface_hub-0.31.1/tests/test_hub_mixin.py000066400000000000000000000462451500667546600214700ustar00rootroot00000000000000import inspect import json import os import unittest from dataclasses import dataclass from pathlib import Path from typing import Dict, Optional, Union, get_type_hints from unittest.mock import Mock, patch import jedi import pytest from huggingface_hub import HfApi, hf_hub_download from huggingface_hub.hub_mixin import ModelHubMixin from huggingface_hub.utils import SoftTemporaryDirectory from .testing_constants import ENDPOINT_STAGING, TOKEN, USER from .testing_utils import repo_name @dataclass class ConfigAsDataclass: foo: int = 10 bar: str = "baz" CONFIG_AS_DATACLASS = ConfigAsDataclass(foo=20, bar="qux") CONFIG_AS_DICT = {"foo": 20, "bar": "qux"} class BaseModel: def _save_pretrained(self, save_directory: Path) -> None: return @classmethod def _from_pretrained( cls, model_id: Union[str, Path], **kwargs, ) -> "BaseModel": # Little hack but in practice NO-ONE is creating 5 inherited classes for their framework :D init_parameters = inspect.signature(cls.__init__).parameters if init_parameters.get("config"): return cls(config=kwargs.get("config")) if init_parameters.get("kwargs"): return cls(**kwargs) return cls() class DummyModelNoConfig(BaseModel, ModelHubMixin): def __init__(self): pass class DummyModelConfigAsDataclass(BaseModel, ModelHubMixin): def __init__(self, config: ConfigAsDataclass): pass class DummyModelConfigAsDict(BaseModel, ModelHubMixin): def __init__(self, config: Dict): pass class DummyModelConfigAsOptionalDataclass(BaseModel, ModelHubMixin): def __init__(self, config: Optional[ConfigAsDataclass] = None): pass class DummyModelConfigAsOptionalDict(BaseModel, ModelHubMixin): def __init__(self, config: 
Optional[Dict] = None): pass class DummyModelWithKwargs(BaseModel, ModelHubMixin): def __init__(self, **kwargs): pass class DummyModelFromPretrainedExpectsConfig(ModelHubMixin): def _save_pretrained(self, save_directory: Path) -> None: return @classmethod def _from_pretrained( cls, model_id: Union[str, Path], config: Optional[Dict] = None, **kwargs, ) -> "BaseModel": return cls(**kwargs) class BaseModelForInheritance( ModelHubMixin, repo_url="https://hf.co/my-repo", paper_url="https://arxiv.org/abs/2304.12244", library_name="my-cool-library", ): pass class DummyModelInherited(BaseModelForInheritance): pass class DummyModelSavingConfig(ModelHubMixin): def _save_pretrained(self, save_directory: Path) -> None: """Implementation that uses `config.json` to serialize the config. This file must not be overwritten by the default config saved by `ModelHubMixin`. """ (save_directory / "config.json").write_text(json.dumps({"custom_config": "custom_config"})) @dataclass class DummyModelThatIsAlsoADataclass(ModelHubMixin): foo: int bar: str @classmethod def _from_pretrained( cls, *, model_id: str, revision: Optional[str], cache_dir: Optional[Union[str, Path]], force_download: bool, proxies: Optional[Dict], resume_download: bool, local_files_only: bool, token: Optional[Union[str, bool]], **model_kwargs, ): return cls(**model_kwargs) class CustomType: def __init__(self, value: str): self.value = value class DummyModelWithCustomTypes( ModelHubMixin, coders={ CustomType: ( lambda x: {"value": x.value}, lambda x: CustomType(x["value"]), ) }, ): def __init__( self, foo: int, bar: str, baz: Union[int, str], custom: CustomType, optional_custom_1: Optional[CustomType], optional_custom_2: Optional[CustomType], custom_default: CustomType = CustomType("default"), **kwargs, ): self.foo = foo self.bar = bar self.baz = baz self.custom = custom self.optional_custom_1 = optional_custom_1 self.optional_custom_2 = optional_custom_2 self.custom_default = custom_default @classmethod def _from_pretrained(cls, **kwargs): return cls(**kwargs) @classmethod def _save_pretrained(cls, save_directory: Path): return @dataclass class DummyDataclass: foo: int bar: str class DummyWithDataclassInputs(ModelHubMixin): def __init__(self, arg1: DummyDataclass, arg2: DummyDataclass): self.arg1 = arg1 self.arg2 = arg2 @classmethod def _from_pretrained(cls, **kwargs): return cls(arg1=kwargs["arg1"], arg2=kwargs["arg2"]) @classmethod def _save_pretrained(cls, save_directory: Path): return @pytest.mark.usefixtures("fx_cache_dir") class HubMixinTest(unittest.TestCase): cache_dir: Path @classmethod def setUpClass(cls): """ Share this valid token in all tests below. 
""" cls._api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) def assert_valid_config_json(self) -> None: # config.json saved correctly with open(self.cache_dir / "config.json") as f: assert json.load(f) == CONFIG_AS_DICT def assert_no_config_json(self) -> None: # config.json not saved files = os.listdir(self.cache_dir) assert "config.json" not in files def test_save_pretrained_no_config(self): model = DummyModelNoConfig() model.save_pretrained(self.cache_dir) self.assert_no_config_json() def test_save_pretrained_as_dataclass_basic(self): model = DummyModelConfigAsDataclass(CONFIG_AS_DATACLASS) model.save_pretrained(self.cache_dir) self.assert_valid_config_json() def test_save_pretrained_as_dict_basic(self): model = DummyModelConfigAsDict(CONFIG_AS_DICT) model.save_pretrained(self.cache_dir) self.assert_valid_config_json() def test_save_pretrained_optional_dataclass(self): model = DummyModelConfigAsOptionalDataclass() model.save_pretrained(self.cache_dir) self.assert_no_config_json() model = DummyModelConfigAsOptionalDataclass(CONFIG_AS_DATACLASS) model.save_pretrained(self.cache_dir) self.assert_valid_config_json() def test_save_pretrained_optional_dict(self): model = DummyModelConfigAsOptionalDict() model.save_pretrained(self.cache_dir) self.assert_no_config_json() model = DummyModelConfigAsOptionalDict(CONFIG_AS_DICT) model.save_pretrained(self.cache_dir) self.assert_valid_config_json() def test_save_pretrained_with_dataclass_config(self): model = DummyModelConfigAsOptionalDataclass() model.save_pretrained(self.cache_dir, config=CONFIG_AS_DATACLASS) self.assert_valid_config_json() def test_save_pretrained_with_dict_config(self): model = DummyModelConfigAsOptionalDict() model.save_pretrained(self.cache_dir, config=CONFIG_AS_DICT) self.assert_valid_config_json() def test_init_accepts_kwargs_no_config(self): """ Test that if `__init__` accepts **kwargs and config file doesn't exist then no 'config' kwargs is passed. Regression test. See https://github.com/huggingface/huggingface_hub/pull/2058. """ model = DummyModelWithKwargs() model.save_pretrained(self.cache_dir) with patch.object( DummyModelWithKwargs, "_from_pretrained", return_value=DummyModelWithKwargs() ) as from_pretrained_mock: model = DummyModelWithKwargs.from_pretrained(self.cache_dir) assert "config" not in from_pretrained_mock.call_args_list[0].kwargs def test_init_accepts_kwargs_with_config(self): """ Test that if `config_inject_mode="as_kwargs"` and config file exists then the 'config' kwarg is passed. Regression test. See https://github.com/huggingface/huggingface_hub/pull/2058. And https://github.com/huggingface/huggingface_hub/pull/2099. """ model = DummyModelFromPretrainedExpectsConfig() model.save_pretrained(self.cache_dir, config=CONFIG_AS_DICT) with patch.object( DummyModelFromPretrainedExpectsConfig, "_from_pretrained", return_value=DummyModelFromPretrainedExpectsConfig(), ) as from_pretrained_mock: DummyModelFromPretrainedExpectsConfig.from_pretrained(self.cache_dir) assert "config" in from_pretrained_mock.call_args_list[0].kwargs def test_init_accepts_kwargs_save_and_load(self): model = DummyModelWithKwargs(something="else") model.save_pretrained(self.cache_dir) assert model._hub_mixin_config == {"something": "else"} with patch.object(DummyModelWithKwargs, "__init__", return_value=None) as init_call_mock: DummyModelWithKwargs.from_pretrained(self.cache_dir) # 'something' is passed to __init__ both as kwarg and in config. 
init_kwargs = init_call_mock.call_args_list[0].kwargs assert init_kwargs["something"] == "else" def test_save_pretrained_with_push_to_hub(self): repo_id = repo_name("save") save_directory = self.cache_dir / repo_id mocked_model = DummyModelConfigAsDataclass(CONFIG_AS_DATACLASS) mocked_model.push_to_hub = Mock() mocked_model._save_pretrained = Mock() # disable _save_pretrained to speed-up # Not pushed to hub mocked_model.save_pretrained(save_directory) mocked_model.push_to_hub.assert_not_called() # Push to hub with repo_id (config is pushed) mocked_model.save_pretrained(save_directory, push_to_hub=True, repo_id="CustomID") mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=CONFIG_AS_DICT, model_card_kwargs={}) # Push to hub with default repo_id (based on dir name) mocked_model.save_pretrained(save_directory, push_to_hub=True) mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=CONFIG_AS_DICT, model_card_kwargs={}) @patch.object(DummyModelNoConfig, "_from_pretrained") def test_from_pretrained_model_id_only(self, from_pretrained_mock: Mock) -> None: model = DummyModelNoConfig.from_pretrained("namespace/repo_name") from_pretrained_mock.assert_called_once() assert model is from_pretrained_mock.return_value @patch.object(DummyModelNoConfig, "_from_pretrained") def test_from_pretrained_model_id_and_revision(self, from_pretrained_mock: Mock) -> None: """Regression test for #1313. See https://github.com/huggingface/huggingface_hub/issues/1313.""" model = DummyModelNoConfig.from_pretrained("namespace/repo_name", revision="123456789") from_pretrained_mock.assert_called_once_with( model_id="namespace/repo_name", revision="123456789", # Revision is passed correctly! cache_dir=None, force_download=False, proxies=None, resume_download=None, local_files_only=False, token=None, ) assert model is from_pretrained_mock.return_value def test_from_pretrained_from_relative_path(self): with SoftTemporaryDirectory(dir=Path(".")) as tmp_relative_dir: relative_save_directory = Path(tmp_relative_dir) / "model" DummyModelConfigAsDataclass(config=CONFIG_AS_DATACLASS).save_pretrained(relative_save_directory) model = DummyModelConfigAsDataclass.from_pretrained(relative_save_directory) assert model._hub_mixin_config == CONFIG_AS_DATACLASS def test_from_pretrained_from_absolute_path(self): save_directory = self.cache_dir / "subfolder" DummyModelConfigAsDataclass(config=CONFIG_AS_DATACLASS).save_pretrained(save_directory) model = DummyModelConfigAsDataclass.from_pretrained(save_directory) assert model._hub_mixin_config == CONFIG_AS_DATACLASS def test_from_pretrained_from_absolute_string_path(self): save_directory = str(self.cache_dir / "subfolder") DummyModelConfigAsDataclass(config=CONFIG_AS_DATACLASS).save_pretrained(save_directory) model = DummyModelConfigAsDataclass.from_pretrained(save_directory) assert model._hub_mixin_config == CONFIG_AS_DATACLASS def test_push_to_hub(self): repo_id = f"{USER}/{repo_name('push_to_hub')}" DummyModelConfigAsDataclass(CONFIG_AS_DATACLASS).push_to_hub(repo_id=repo_id, token=TOKEN) # Test model id exists self._api.model_info(repo_id) # Test config has been pushed to hub tmp_config_path = hf_hub_download( repo_id=repo_id, filename="config.json", use_auth_token=TOKEN, cache_dir=self.cache_dir, ) with open(tmp_config_path) as f: assert json.load(f) == CONFIG_AS_DICT # from_pretrained with correct serialization from_pretrained_kwargs = { "pretrained_model_name_or_path": repo_id, "cache_dir": self.cache_dir, "api_endpoint": ENDPOINT_STAGING, "token": TOKEN, 
} for cls in (DummyModelConfigAsDataclass, DummyModelConfigAsOptionalDataclass): assert cls.from_pretrained(**from_pretrained_kwargs)._hub_mixin_config == CONFIG_AS_DATACLASS for cls in (DummyModelConfigAsDict, DummyModelConfigAsOptionalDict): assert cls.from_pretrained(**from_pretrained_kwargs)._hub_mixin_config == CONFIG_AS_DICT # Delete repo self._api.delete_repo(repo_id=repo_id) def test_save_pretrained_do_not_overwrite_new_config(self): """Regression test for https://github.com/huggingface/huggingface_hub/issues/2102. If `_from_pretrained` does save a config file, we should not overwrite it. """ model = DummyModelSavingConfig() model.save_pretrained(self.cache_dir) # config.json is not overwritten with open(self.cache_dir / "config.json") as f: assert json.load(f) == {"custom_config": "custom_config"} def test_save_pretrained_does_overwrite_legacy_config(self): """Regression test for https://github.com/huggingface/huggingface_hub/issues/2142. If a previously existing config file exists, it should be overwritten. """ # Something existing in the cache dir (self.cache_dir / "config.json").write_text(json.dumps({"something_legacy": 123})) # Save model model = DummyModelWithKwargs(a=1, b=2) model.save_pretrained(self.cache_dir) # config.json IS overwritten with open(self.cache_dir / "config.json") as f: assert json.load(f) == {"a": 1, "b": 2} def test_from_pretrained_when_cls_is_a_dataclass(self): """Regression test for #2157. When the ModelHubMixin class happens to be a dataclass, `__init__` method will accept `**kwargs` when inspecting it. However, due to how dataclasses work, we cannot forward arbitrary kwargs to the `__init__`. This test ensures that the `from_pretrained` method does not raise an error when the class is a dataclass. See https://github.com/huggingface/huggingface_hub/issues/2157. """ (self.cache_dir / "config.json").write_text('{"foo": 42, "bar": "baz", "other": "value"}') model = DummyModelThatIsAlsoADataclass.from_pretrained(self.cache_dir) assert model.foo == 42 assert model.bar == "baz" assert not hasattr(model, "other") def test_from_cls_with_custom_type(self): model = DummyModelWithCustomTypes( 1, bar="bar", baz=1.0, custom=CustomType("custom"), optional_custom_1=CustomType("optional"), optional_custom_2=None, ) model.save_pretrained(self.cache_dir) config = json.loads((self.cache_dir / "config.json").read_text()) assert config == { "foo": 1, "bar": "bar", "baz": 1.0, "custom": {"value": "custom"}, "optional_custom_1": {"value": "optional"}, "optional_custom_2": None, "custom_default": {"value": "default"}, } model_reloaded = DummyModelWithCustomTypes.from_pretrained(self.cache_dir) assert model_reloaded.foo == 1 assert model_reloaded.bar == "bar" assert model_reloaded.baz == 1.0 assert model_reloaded.custom.value == "custom" assert model_reloaded.optional_custom_1 is not None and model_reloaded.optional_custom_1.value == "optional" assert model_reloaded.optional_custom_2 is None assert model_reloaded.custom_default.value == "default" def test_inherited_class(self): """Test MixinInfo attributes are inherited from the parent class.""" model = DummyModelInherited() assert model._hub_mixin_info.repo_url == "https://hf.co/my-repo" assert model._hub_mixin_info.paper_url == "https://arxiv.org/abs/2304.12244" assert model._hub_mixin_info.model_card_data.library_name == "my-cool-library" def test_autocomplete_works_as_expected(self): """Regression test for #2694. Ensure that autocomplete works as expected when inheriting from `ModelHubMixin`. 
See https://github.com/huggingface/huggingface_hub/issues/2694. """ source = """ from huggingface_hub import ModelHubMixin class Dummy(ModelHubMixin): def dummy_example_for_test(self, x: str) -> str: return x a = Dummy() a.dum""".strip() script = jedi.Script(source, path="example.py") source_lines = source.split("\n") completions = script.complete(len(source_lines), len(source_lines[-1])) assert any(completion.name == "dummy_example_for_test" for completion in completions) def test_get_type_hints_works_as_expected(self): """ Ensure that `typing.get_type_hints` works as expected when inheriting from `ModelHubMixin`. See https://github.com/huggingface/huggingface_hub/issues/2727. """ class ModelWithHints(ModelHubMixin): def method_with_hints(self, x: int) -> str: return str(x) assert get_type_hints(ModelWithHints) != {} # Test method type hints on class hints = get_type_hints(ModelWithHints.method_with_hints) assert hints == {"x": int, "return": str} # Test method type hints on instance model = ModelWithHints() assert get_type_hints(model.method_with_hints) == {"x": int, "return": str} def test_with_dataclass_inputs(self): model = DummyWithDataclassInputs( arg1=DummyDataclass(foo=1, bar="1"), arg2=DummyDataclass(foo=2, bar="2"), ) model.save_pretrained(self.cache_dir) config = json.loads((self.cache_dir / "config.json").read_text()) assert config == { "arg1": {"foo": 1, "bar": "1"}, "arg2": {"foo": 2, "bar": "2"}, } model_reloaded = DummyWithDataclassInputs.from_pretrained(self.cache_dir) assert model_reloaded.arg1.foo == 1 assert model_reloaded.arg1.bar == "1" assert model_reloaded.arg2.foo == 2 assert model_reloaded.arg2.bar == "2" huggingface_hub-0.31.1/tests/test_hub_mixin_pytorch.py000066400000000000000000000457241500667546600232410ustar00rootroot00000000000000import json import os import struct import unittest from argparse import Namespace from pathlib import Path from typing import Any, Dict, Optional, TypeVar from unittest.mock import Mock, patch import pytest from huggingface_hub import HfApi, ModelCard, constants, hf_hub_download from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError from huggingface_hub.hub_mixin import ModelHubMixin, PyTorchModelHubMixin from huggingface_hub.serialization._torch import storage_ptr from huggingface_hub.utils import SoftTemporaryDirectory, is_torch_available from .testing_constants import ENDPOINT_STAGING, TOKEN, USER from .testing_utils import repo_name, requires DUMMY_OBJECT = object() DUMMY_MODEL_CARD_TEMPLATE = """ --- {{ card_data }} --- This is a dummy model card. Arxiv ID: 1234.56789 """ DUMMY_MODEL_CARD_TEMPLATE_WITH_CUSTOM_KWARGS = """ --- {{ card_data }} --- This is a dummy model card with kwargs. 
{{ custom_data }} - Code: {{ repo_url }} - Paper: {{ paper_url }} - Docs: {{ docs_url }} """ if is_torch_available(): import torch import torch.nn as nn CONFIG = {"num": 10, "act": "gelu_fast"} class DummyModel(nn.Module, PyTorchModelHubMixin): def __init__(self, **kwargs): super().__init__() self.config = kwargs.pop("config", None) self.l1 = nn.Linear(2, 2) def forward(self, x): return self.l1(x) class DummyModelWithModelCard( nn.Module, PyTorchModelHubMixin, model_card_template=DUMMY_MODEL_CARD_TEMPLATE, language=["en", "zh"], library_name="my-dummy-lib", license="apache-2.0", tags=["tag1", "tag2"], pipeline_tag="text-classification", ): def __init__(self, linear_layer: int = 4): super().__init__() self.l1 = nn.Linear(linear_layer, linear_layer) def forward(self, x): return self.l1(x) class DummyModelNoConfig(nn.Module, PyTorchModelHubMixin): def __init__( self, num_classes: int = 42, state: str = "layernorm", not_jsonable: Any = DUMMY_OBJECT, ): super().__init__() self.num_classes = num_classes self.state = state self.not_jsonable = not_jsonable class DummyModelWithConfigAndKwargs(nn.Module, PyTorchModelHubMixin): def __init__(self, num_classes: int = 42, state: str = "layernorm", config: Optional[Dict] = None, **kwargs): super().__init__() class DummyModelWithModelCardAndCustomKwargs( nn.Module, PyTorchModelHubMixin, model_card_template=DUMMY_MODEL_CARD_TEMPLATE_WITH_CUSTOM_KWARGS, docs_url="https://hf.co/docs/my-repo", paper_url="https://arxiv.org/abs/2304.12244", repo_url="https://hf.co/my-repo", ): def __init__(self, linear_layer: int = 4): super().__init__() class DummyModelWithEncodedConfig( nn.Module, PyTorchModelHubMixin, coders={ Namespace: ( lambda x: vars(x), lambda data: Namespace(**data), ) }, ): # Regression test for https://github.com/huggingface/huggingface_hub/issues/2334 def __init__(self, config: Namespace): super().__init__() self.config = config class DummyModelWithTag1(nn.Module, PyTorchModelHubMixin, tags=["tag1"]): """Used to test tags not shared between sibling classes (only inheritance).""" class DummyModelWithTag2(nn.Module, PyTorchModelHubMixin, tags=["tag2"]): """Used to test tags not shared between sibling classes (only inheritance).""" else: DummyModel = None DummyModelWithModelCard = None DummyModelNoConfig = None DummyModelWithConfigAndKwargs = None DummyModelWithModelCardAndCustomKwargs = None DummyModelWithTag1 = None DummyModelWithTag2 = None @requires("torch") @pytest.mark.usefixtures("fx_cache_dir") class PytorchHubMixinTest(unittest.TestCase): cache_dir: Path @classmethod def setUpClass(cls): """ Share this valid token in all tests below. 
""" cls._api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) def test_save_pretrained_basic(self): DummyModel().save_pretrained(self.cache_dir) files = os.listdir(self.cache_dir) assert set(files) == {"README.md", "model.safetensors"} def test_save_pretrained_with_config(self): DummyModel().save_pretrained(self.cache_dir, config=CONFIG) files = os.listdir(self.cache_dir) assert set(files) == {"README.md", "config.json", "model.safetensors"} def test_save_as_safetensors(self): DummyModel().save_pretrained(self.cache_dir, config=TOKEN) modelFile = self.cache_dir / "model.safetensors" # check for safetensors header to ensure we are saving the model in safetensors format # while an implementation detail, assert as this has safety implications # https://github.com/huggingface/safetensors?tab=readme-ov-file#format with open(modelFile, "rb") as f: header_size = struct.unpack(" None: model = DummyModel.from_pretrained("namespace/repo_name") from_pretrained_mock.assert_called_once() self.assertIs(model, from_pretrained_mock.return_value) def pretend_file_download(self, **kwargs): if kwargs.get("filename") == "config.json": raise HfHubHTTPError("no config") DummyModel().save_pretrained(self.cache_dir) return self.cache_dir / "model.safetensors" @patch("huggingface_hub.hub_mixin.hf_hub_download") def test_from_pretrained_model_from_hub_prefer_safetensor(self, hf_hub_download_mock: Mock) -> None: hf_hub_download_mock.side_effect = self.pretend_file_download model = DummyModel.from_pretrained("namespace/repo_name") hf_hub_download_mock.assert_any_call( repo_id="namespace/repo_name", filename="model.safetensors", revision=None, cache_dir=None, force_download=False, proxies=None, resume_download=None, token=None, local_files_only=False, ) self.assertIsNotNone(model) def pretend_file_download_fallback(self, **kwargs): filename = kwargs.get("filename") if filename == "model.safetensors" or filename == "config.json": raise EntryNotFoundError("not found") class TestMixin(ModelHubMixin): def _save_pretrained(self, save_directory: Path) -> None: torch.save(DummyModel().state_dict(), save_directory / constants.PYTORCH_WEIGHTS_NAME) TestMixin().save_pretrained(self.cache_dir) return self.cache_dir / constants.PYTORCH_WEIGHTS_NAME @patch("huggingface_hub.hub_mixin.hf_hub_download") def test_from_pretrained_model_from_hub_fallback_pickle(self, hf_hub_download_mock: Mock) -> None: hf_hub_download_mock.side_effect = self.pretend_file_download_fallback model = DummyModel.from_pretrained("namespace/repo_name") hf_hub_download_mock.assert_any_call( repo_id="namespace/repo_name", filename="model.safetensors", revision=None, cache_dir=None, force_download=False, proxies=None, resume_download=None, token=None, local_files_only=False, ) hf_hub_download_mock.assert_any_call( repo_id="namespace/repo_name", filename="pytorch_model.bin", revision=None, cache_dir=None, force_download=False, proxies=None, resume_download=None, token=None, local_files_only=False, ) self.assertIsNotNone(model) @patch.object(DummyModel, "_from_pretrained") def test_from_pretrained_model_id_and_revision(self, from_pretrained_mock: Mock) -> None: """Regression test for #1313. See https://github.com/huggingface/huggingface_hub/issues/1313.""" model = DummyModel.from_pretrained("namespace/repo_name", revision="123456789") from_pretrained_mock.assert_called_once_with( model_id="namespace/repo_name", revision="123456789", # Revision is passed correctly! 
cache_dir=None, force_download=False, proxies=None, resume_download=None, local_files_only=False, token=None, ) self.assertIs(model, from_pretrained_mock.return_value) def test_from_pretrained_to_relative_path(self): with SoftTemporaryDirectory(dir=Path(".")) as tmp_relative_dir: relative_save_directory = Path(tmp_relative_dir) / "model" DummyModel().save_pretrained(relative_save_directory, config=CONFIG) model = DummyModel.from_pretrained(relative_save_directory) self.assertDictEqual(model._hub_mixin_config, CONFIG) def test_from_pretrained_to_absolute_path(self): save_directory = self.cache_dir / "subfolder" DummyModel().save_pretrained(save_directory, config=CONFIG) model = DummyModel.from_pretrained(save_directory) self.assertDictEqual(model._hub_mixin_config, CONFIG) def test_from_pretrained_to_absolute_string_path(self): save_directory = str(self.cache_dir / "subfolder") DummyModel().save_pretrained(save_directory, config=CONFIG) model = DummyModel.from_pretrained(save_directory) self.assertDictEqual(model._hub_mixin_config, CONFIG) def test_return_type_hint_from_pretrained(self): self.assertIn( "return", DummyModel.from_pretrained.__annotations__, "`PyTorchModelHubMixin.from_pretrained` does not set a return type annotation.", ) self.assertIsInstance( DummyModel.from_pretrained.__annotations__["return"], TypeVar, "`PyTorchModelHubMixin.from_pretrained` return type annotation is not a TypeVar.", ) self.assertEqual( DummyModel.from_pretrained.__annotations__["return"].__bound__.__forward_arg__, "ModelHubMixin", "`PyTorchModelHubMixin.from_pretrained` return type annotation is not a TypeVar bound by `ModelHubMixin`.", ) def test_push_to_hub(self): repo_id = f"{USER}/{repo_name('push_to_hub')}" DummyModel().push_to_hub(repo_id=repo_id, token=TOKEN, config=CONFIG) # Test model id exists assert self._api.model_info(repo_id).id == repo_id # Test config has been pushed to hub tmp_config_path = hf_hub_download( repo_id=repo_id, filename="config.json", use_auth_token=TOKEN, cache_dir=self.cache_dir, ) with open(tmp_config_path) as f: self.assertDictEqual(json.load(f), CONFIG) # Delete repo self._api.delete_repo(repo_id=repo_id) def test_generate_model_card(self): model = DummyModelWithModelCard() card = model.generate_model_card() assert card.data.language == ["en", "zh"] assert card.data.library_name == "my-dummy-lib" assert card.data.license == "apache-2.0" assert card.data.pipeline_tag == "text-classification" assert card.data.tags == ["model_hub_mixin", "pytorch_model_hub_mixin", "tag1", "tag2"] # Model card template has been used assert "This is a dummy model card" in str(card) model.save_pretrained(self.cache_dir) card_reloaded = ModelCard.load(self.cache_dir / "README.md") assert str(card) == str(card_reloaded) assert card.data == card_reloaded.data def test_load_no_config(self): config_file = self.cache_dir / "config.json" # Test creating model => auto-generated config model = DummyModelNoConfig(num_classes=50) assert model._hub_mixin_config == {"num_classes": 50, "state": "layernorm"} # Test saving model => auto-generated config is saved model.save_pretrained(self.cache_dir) assert config_file.exists() assert json.loads(config_file.read_text()) == {"num_classes": 50, "state": "layernorm"} # Reload model => config is reloaded reloaded = DummyModelNoConfig.from_pretrained(self.cache_dir) assert reloaded.num_classes == 50 assert reloaded.state == "layernorm" assert reloaded._hub_mixin_config == {"num_classes": 50, "state": "layernorm"} # Reload model with custom config => custom config 
is used reloaded_with_default = DummyModelNoConfig.from_pretrained(self.cache_dir, state="other") assert reloaded_with_default.num_classes == 50 assert reloaded_with_default.state == "other" assert reloaded_with_default._hub_mixin_config == {"num_classes": 50, "state": "other"} config_file.unlink() # Remove config file reloaded_with_default.save_pretrained(self.cache_dir) assert json.loads(config_file.read_text()) == {"num_classes": 50, "state": "other"} def test_save_with_non_jsonable_config(self): # Save with a non-jsonable value my_object = object() model = DummyModelNoConfig(not_jsonable=my_object) assert model.not_jsonable is my_object assert "not_jsonable" not in model._hub_mixin_config # Reload with default value model.save_pretrained(self.cache_dir) reloaded_model = DummyModelNoConfig.from_pretrained(self.cache_dir) assert reloaded_model.not_jsonable is DUMMY_OBJECT assert "not_jsonable" not in model._hub_mixin_config # If jsonable value passed by user, it's saved in the config (self.cache_dir / "config.json").unlink() new_model = DummyModelNoConfig(not_jsonable=123) new_model.save_pretrained(self.cache_dir) assert new_model._hub_mixin_config["not_jsonable"] == 123 reloaded_new_model = DummyModelNoConfig.from_pretrained(self.cache_dir) assert reloaded_new_model.not_jsonable == 123 assert reloaded_new_model._hub_mixin_config["not_jsonable"] == 123 def test_save_model_with_shared_tensors(self): """ Regression test for #2086. Shared tensors should be saved correctly. See https://github.com/huggingface/huggingface_hub/pull/2086 for more details. """ class ModelWithSharedTensors(nn.Module, PyTorchModelHubMixin): def __init__(self): super().__init__() self.a = nn.Linear(100, 100) self.b = self.a def forward(self, x): return self.b(self.a(x)) # Save and reload model model = ModelWithSharedTensors() model.save_pretrained(self.cache_dir) reloaded = ModelWithSharedTensors.from_pretrained(self.cache_dir) # Linear layers should share weights and biases in memory state_dict = reloaded.state_dict() a_weight_ptr = storage_ptr(state_dict["a.weight"]) b_weight_ptr = storage_ptr(state_dict["b.weight"]) a_bias_ptr = storage_ptr(state_dict["a.bias"]) b_bias_ptr = storage_ptr(state_dict["b.bias"]) assert a_weight_ptr == b_weight_ptr assert a_bias_ptr == b_bias_ptr def test_save_pretrained_when_config_and_kwargs_are_passed(self): # Test creating model with config and kwargs => all values are saved together in config.json model = DummyModelWithConfigAndKwargs(num_classes=50, state="layernorm", config={"a": 1}, b=2, c=3) model.save_pretrained(self.cache_dir) assert model._hub_mixin_config == {"num_classes": 50, "state": "layernorm", "a": 1, "b": 2, "c": 3} reloaded = DummyModelWithConfigAndKwargs.from_pretrained(self.cache_dir) assert reloaded._hub_mixin_config == model._hub_mixin_config def test_model_card_with_custom_kwargs(self): model_card_kwargs = {"custom_data": "This is a model custom data: 42."} # Test creating model with custom kwargs => custom data is saved in model card model = DummyModelWithModelCardAndCustomKwargs() card = model.generate_model_card(**model_card_kwargs) assert model_card_kwargs["custom_data"] in str(card) assert "Code: https://hf.co/my-repo" in str(card) assert "Paper: https://arxiv.org/abs/2304.12244" in str(card) assert "Docs: https://hf.co/docs/my-repo" in str(card) # Test saving card => model card is saved and restored with custom data model.save_pretrained(self.cache_dir, model_card_kwargs=model_card_kwargs) card_reloaded = ModelCard.load(self.cache_dir / "README.md") 
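        # the README.md written by save_pretrained should round-trip the generated card exactly, custom kwargs included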
assert str(card) == str(card_reloaded) def test_config_with_custom_coders(self): """ Regression test for #2334. When `config` is encoded with custom coders, it should be decoded correctly. See https://github.com/huggingface/huggingface_hub/issues/2334. """ model = DummyModelWithEncodedConfig(Namespace(a=1, b=2)) model.save_pretrained(self.cache_dir) assert model._hub_mixin_config["a"] == 1 assert model._hub_mixin_config["b"] == 2 reloaded = DummyModelWithEncodedConfig.from_pretrained(self.cache_dir) assert isinstance(reloaded.config, Namespace) assert reloaded.config.a == 1 assert reloaded.config.b == 2 def test_inheritance_and_sibling_classes(self): """ Test tags are not shared between sibling classes. Regression test for #2394. See https://github.com/huggingface/huggingface_hub/pull/2394. """ assert DummyModelWithTag1._hub_mixin_info.model_card_data.tags == [ "model_hub_mixin", "pytorch_model_hub_mixin", "tag1", ] assert DummyModelWithTag2._hub_mixin_info.model_card_data.tags == [ "model_hub_mixin", "pytorch_model_hub_mixin", "tag2", ] huggingface_hub-0.31.1/tests/test_inference_api.py000066400000000000000000000132641500667546600222700ustar00rootroot00000000000000# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest from pathlib import Path from unittest.mock import patch import pytest from PIL import Image from huggingface_hub import hf_hub_download from huggingface_hub.inference_api import InferenceApi from .testing_utils import expect_deprecation, with_production_testing @pytest.mark.vcr @with_production_testing class InferenceApiTest(unittest.TestCase): def read(self, filename: str) -> bytes: return Path(filename).read_bytes() @classmethod @with_production_testing def setUpClass(cls) -> None: cls.image_file = hf_hub_download(repo_id="Narsil/image_dummy", repo_type="dataset", filename="lena.png") return super().setUpClass() @expect_deprecation("huggingface_hub.inference_api") def test_simple_inference(self): api = InferenceApi("bert-base-uncased") inputs = "Hi, I think [MASK] is cool" results = api(inputs) self.assertIsInstance(results, list) result = results[0] self.assertIsInstance(result, dict) self.assertTrue("sequence" in result) self.assertTrue("score" in result) @unittest.skip("Model often not loaded") @expect_deprecation("huggingface_hub.inference_api") def test_inference_with_params(self): api = InferenceApi("typeform/distilbert-base-uncased-mnli") inputs = "I bought a device but it is not working and I would like to get reimbursed!" 
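        # extra task parameters (here: the zero-shot candidate labels) are passed as the second positional argument below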
params = {"candidate_labels": ["refund", "legal", "faq"]} result = api(inputs, params) self.assertIsInstance(result, dict) self.assertTrue("sequence" in result) self.assertTrue("scores" in result) @unittest.skip("Model often not loaded") @expect_deprecation("huggingface_hub.inference_api") def test_inference_with_dict_inputs(self): api = InferenceApi("distilbert-base-cased-distilled-squad") inputs = { "question": "What's my name?", "context": "My name is Clara and I live in Berkeley.", } result = api(inputs) self.assertIsInstance(result, dict) self.assertTrue("score" in result) self.assertTrue("answer" in result) @unittest.skip("Model often not loaded") @expect_deprecation("huggingface_hub.inference_api") def test_inference_with_audio(self): api = InferenceApi("facebook/wav2vec2-base-960h") file = hf_hub_download( repo_id="hf-internal-testing/dummy-flac-single-example", repo_type="dataset", filename="example.flac", ) data = self.read(file) result = api(data=data) self.assertIsInstance(result, dict) self.assertTrue("text" in result, f"We received {result} instead") @unittest.skip("Model often not loaded") @expect_deprecation("huggingface_hub.inference_api") def test_inference_with_image(self): api = InferenceApi("google/vit-base-patch16-224") data = self.read(self.image_file) result = api(data=data) self.assertIsInstance(result, list) for classification in result: self.assertIsInstance(classification, dict) self.assertTrue("score" in classification) self.assertTrue("label" in classification) @expect_deprecation("huggingface_hub.inference_api") def test_text_to_image(self): api = InferenceApi("stabilityai/stable-diffusion-2-1") with patch("huggingface_hub.inference_api.get_session") as mock: mock().post.return_value.headers = {"Content-Type": "image/jpeg"} mock().post.return_value.content = self.read(self.image_file) output = api("cat") self.assertIsInstance(output, Image.Image) @expect_deprecation("huggingface_hub.inference_api") def test_text_to_image_raw_response(self): api = InferenceApi("stabilityai/stable-diffusion-2-1") with patch("huggingface_hub.inference_api.get_session") as mock: mock().post.return_value.headers = {"Content-Type": "image/jpeg"} mock().post.return_value.content = self.read(self.image_file) output = api("cat", raw_response=True) # Raw response is returned self.assertEqual(output, mock().post.return_value) @expect_deprecation("huggingface_hub.inference_api") def test_inference_overriding_task(self): api = InferenceApi( "sentence-transformers/paraphrase-albert-small-v2", task="feature-extraction", ) inputs = "This is an example again" result = api(inputs) self.assertIsInstance(result, list) @expect_deprecation("huggingface_hub.inference_api") def test_inference_overriding_invalid_task(self): with self.assertRaises(ValueError, msg="Invalid task invalid-task. Make sure it's valid."): InferenceApi("bert-base-uncased", task="invalid-task") @expect_deprecation("huggingface_hub.inference_api") def test_inference_missing_input(self): api = InferenceApi("deepset/roberta-base-squad2") result = api({"question": "What's my name?"}) self.assertIsInstance(result, dict) self.assertTrue("error" in result) huggingface_hub-0.31.1/tests/test_inference_async_client.py000066400000000000000000000457121500667546600241750ustar00rootroot00000000000000# coding=utf-8 # Copyright 2023-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains tests for AsyncInferenceClient. Tests are run directly with pytest instead of unittest.TestCase as it's much easier to run with asyncio. Not all tasks are tested. We extensively test `text_generation` method since it's the most complex one (has different return types + uses streaming requests on demand). Tests are mostly duplicates from test_inference_text_generation.py`. For completeness we also run a test on a simple task (`test_async_sentence_similarity`) and assume all other tasks work as well. """ import asyncio import inspect from unittest.mock import Mock, patch import pytest from aiohttp import ClientResponseError import huggingface_hub.inference._common from huggingface_hub import ( AsyncInferenceClient, ChatCompletionOutput, ChatCompletionOutputComplete, ChatCompletionOutputMessage, ChatCompletionOutputUsage, ChatCompletionStreamOutput, InferenceClient, InferenceTimeoutError, TextGenerationOutputPrefillToken, ) from huggingface_hub.inference._common import ValidationError as TextGenerationValidationError from huggingface_hub.inference._common import _get_unsupported_text_generation_kwargs from .test_inference_client import CHAT_COMPLETE_NON_TGI_MODEL, CHAT_COMPLETION_MESSAGES, CHAT_COMPLETION_MODEL from .testing_utils import with_production_testing @pytest.fixture(autouse=True) def patch_non_tgi_server(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(huggingface_hub.inference._common, "_UNSUPPORTED_TEXT_GENERATION_KWARGS", {}) @pytest.fixture def tgi_client() -> AsyncInferenceClient: return AsyncInferenceClient(model="openai-community/gpt2") @pytest.mark.asyncio @with_production_testing @pytest.mark.skip("Temporary skipping this test") async def test_async_generate_no_details(tgi_client: AsyncInferenceClient) -> None: response = await tgi_client.text_generation("test", details=False, max_new_tokens=1) assert isinstance(response, str) assert response == "." @pytest.mark.asyncio @with_production_testing @pytest.mark.skip("Temporary skipping this test") async def test_async_generate_with_details(tgi_client: AsyncInferenceClient) -> None: response = await tgi_client.text_generation("test", details=True, max_new_tokens=1, decoder_input_details=True) assert response.generated_text == "." assert response.details.finish_reason == "length" assert response.details.generated_tokens == 1 assert response.details.seed is None assert len(response.details.prefill) == 1 assert response.details.prefill[0] == TextGenerationOutputPrefillToken(id=9288, logprob=None, text="test") assert len(response.details.tokens) == 1 assert response.details.tokens[0].id == 13 assert response.details.tokens[0].text == "." 
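    # "." is an ordinary vocabulary token, so the server is expected not to flag it as special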
assert not response.details.tokens[0].special @pytest.mark.asyncio @with_production_testing @pytest.mark.skip("Temporary skipping this test") async def test_async_generate_best_of(tgi_client: AsyncInferenceClient) -> None: response = await tgi_client.text_generation( "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True, details=True ) assert response.details.seed is not None assert response.details.best_of_sequences is not None assert len(response.details.best_of_sequences) == 1 assert response.details.best_of_sequences[0].seed is not None @pytest.mark.asyncio @with_production_testing @pytest.mark.skip("Temporary skipping this test") async def test_async_generate_validation_error(tgi_client: AsyncInferenceClient) -> None: with pytest.raises(TextGenerationValidationError): await tgi_client.text_generation("test", max_new_tokens=10_000) @pytest.mark.asyncio @pytest.mark.skip("skipping this test, as InferenceAPI seems to not throw an error when sending unsupported params") async def test_async_generate_non_tgi_endpoint(tgi_client: AsyncInferenceClient) -> None: text = await tgi_client.text_generation("0 1 2", model="gpt2", max_new_tokens=10) assert text == " 3 4 5 6 7 8 9 10 11 12" assert _get_unsupported_text_generation_kwargs("gpt2") == ["details", "stop", "watermark", "decoder_input_details"] # Watermark is ignored (+ warning) with pytest.warns(UserWarning): await tgi_client.text_generation("4 5 6", model="gpt2", max_new_tokens=10, watermark=True) # Return as detail even if details=True (+ warning) with pytest.warns(UserWarning): text = await tgi_client.text_generation("0 1 2", model="gpt2", max_new_tokens=10, details=True) assert isinstance(text, str) # Return as stream raises error with pytest.raises(ValueError): await tgi_client.text_generation("0 1 2", model="gpt2", max_new_tokens=10, stream=True) @pytest.mark.skip("Temporary skipping this test") @pytest.mark.asyncio @with_production_testing async def test_async_generate_stream_no_details(tgi_client: AsyncInferenceClient) -> None: responses = [ response async for response in await tgi_client.text_generation("test", max_new_tokens=1, stream=True) ] assert len(responses) == 1 response = responses[0] assert isinstance(response, str) assert response == "." @pytest.mark.skip("Temporary skipping this test") @pytest.mark.asyncio @with_production_testing async def test_async_generate_stream_with_details(tgi_client: AsyncInferenceClient) -> None: responses = [ response async for response in await tgi_client.text_generation("test", max_new_tokens=1, stream=True, details=True) ] assert len(responses) == 1 response = responses[0] assert response.generated_text == "." 
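    # with max_new_tokens=1 the generation is cut off, so the reported finish_reason is "length"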
assert response.details.finish_reason == "length" assert response.details.generated_tokens == 1 assert response.details.seed is None @pytest.mark.skip("Temporary skipping this test") @pytest.mark.asyncio @with_production_testing async def test_async_chat_completion_no_stream() -> None: async_client = AsyncInferenceClient(model=CHAT_COMPLETION_MODEL) output = await async_client.chat_completion(CHAT_COMPLETION_MESSAGES, max_tokens=10) assert isinstance(output.created, int) assert output == ChatCompletionOutput( id="", model="HuggingFaceH4/zephyr-7b-beta", system_fingerprint="3.0.1-sha-bb9095a", usage=ChatCompletionOutputUsage(completion_tokens=10, prompt_tokens=46, total_tokens=56), choices=[ ChatCompletionOutputComplete( finish_reason="length", index=0, message=ChatCompletionOutputMessage( content="Deep learning is a subfield of machine learning that", role="assistant", ), ) ], created=output.created, ) @pytest.mark.skip("Temporary skipping this test") @pytest.mark.asyncio @with_production_testing async def test_async_chat_completion_not_tgi_no_stream() -> None: async_client = AsyncInferenceClient(model=CHAT_COMPLETE_NON_TGI_MODEL) output = await async_client.chat_completion(CHAT_COMPLETION_MESSAGES, max_tokens=10) assert isinstance(output.created, int) assert output == ChatCompletionOutput( choices=[ ChatCompletionOutputComplete( finish_reason="length", index=0, message=ChatCompletionOutputMessage( role="assistant", content="Deep learning isn't even an algorithm though.", tool_calls=None ), logprobs=None, ) ], created=1737562613, id="", model="microsoft/DialoGPT-small", system_fingerprint="3.0.1-sha-bb9095a", usage=ChatCompletionOutputUsage(completion_tokens=10, prompt_tokens=13, total_tokens=23), ) @pytest.mark.skip("Temporary skipping this test") @pytest.mark.asyncio @with_production_testing async def test_async_chat_completion_with_stream() -> None: async_client = AsyncInferenceClient(model=CHAT_COMPLETION_MODEL) output = await async_client.chat_completion(CHAT_COMPLETION_MESSAGES, max_tokens=10, stream=True) all_items = [] generated_text = "" async for item in output: all_items.append(item) assert isinstance(item, ChatCompletionStreamOutput) assert len(item.choices) == 1 if item.choices[0].delta.content is not None: generated_text += item.choices[0].delta.content assert len(all_items) > 0 last_item = all_items[-1] assert last_item.choices[0].finish_reason == "length" @pytest.mark.skip("Temporary skipping this test") @pytest.mark.asyncio @with_production_testing async def test_async_sentence_similarity() -> None: async_client = AsyncInferenceClient(model="sentence-transformers/all-MiniLM-L6-v2") scores = await async_client.sentence_similarity( "Machine learning is so easy.", other_sentences=[ "Deep learning is so straightforward.", "This is so difficult, like rocket science.", "I can't believe how much I struggled with this.", ], ) assert scores == [0.7785724997520447, 0.45876249670982362, 0.29062220454216003] def test_sync_vs_async_signatures() -> None: client = InferenceClient() async_client = AsyncInferenceClient() # Some methods have to be tested separately. special_methods = ["post", "text_generation", "chat_completion"] # Post: this is not automatically tested. No need to test its signature separately. # text-generation/chat-completion: return type changes from Iterable[...] to AsyncIterable[...] 
but input parameters are the same for name in ["text_generation", "chat_completion"]: sync_method = getattr(client, name) assert not inspect.iscoroutinefunction(sync_method) async_method = getattr(async_client, name) assert inspect.iscoroutinefunction(async_method) sync_sig = inspect.signature(sync_method) async_sig = inspect.signature(async_method) assert sync_sig.parameters == async_sig.parameters assert sync_sig.return_annotation != async_sig.return_annotation # Check that all methods are consistent between InferenceClient and AsyncInferenceClient for name in dir(client): if not inspect.ismethod(getattr(client, name)): # not a method continue if name.startswith("_"): # not public method continue if name in special_methods: # tested separately continue # Check that the sync method is not async sync_method = getattr(client, name) assert not inspect.iscoroutinefunction(sync_method) # Check that the async method is async async_method = getattr(async_client, name) # Since some methods are decorated with @_deprecate_arguments, we need to unwrap the async method to get the actual coroutine function # TODO: Remove this once the @_deprecate_arguments decorator is removed from the AsyncInferenceClient methods. assert inspect.iscoroutinefunction(inspect.unwrap(async_method)) # Check that expected inputs and outputs are the same sync_sig = inspect.signature(sync_method) async_sig = inspect.signature(async_method) assert sync_sig.parameters == async_sig.parameters assert sync_sig.return_annotation == async_sig.return_annotation @pytest.mark.asyncio @pytest.mark.skip("Deprecated (get_model_status)") async def test_get_status_too_big_model() -> None: model_status = await AsyncInferenceClient(token=False).get_model_status("facebook/nllb-moe-54b") assert model_status.loaded is False assert model_status.state == "TooBig" assert model_status.compute_type == "cpu" assert model_status.framework == "transformers" @pytest.mark.asyncio @pytest.mark.skip("Deprecated (get_model_status)") async def test_get_status_loaded_model() -> None: model_status = await AsyncInferenceClient(token=False).get_model_status("bigscience/bloom") assert model_status.loaded is True assert model_status.state == "Loaded" assert isinstance(model_status.compute_type, dict) # e.g. 
{'gpu': {'gpu': 'a100', 'count': 8}} assert model_status.framework == "text-generation-inference" @pytest.mark.asyncio @pytest.mark.skip("Deprecated (get_model_status)") async def test_get_status_unknown_model() -> None: with pytest.raises(ClientResponseError): await AsyncInferenceClient(token=False).get_model_status("unknown/model") @pytest.mark.asyncio @pytest.mark.skip("Deprecated (get_model_status)") async def test_get_status_model_as_url() -> None: with pytest.raises(NotImplementedError): await AsyncInferenceClient(token=False).get_model_status("https://unkown/model") @pytest.mark.asyncio @pytest.mark.skip("Deprecated (list_deployed_models)") async def test_list_deployed_models_single_frameworks() -> None: models_by_task = await AsyncInferenceClient().list_deployed_models("text-generation-inference") assert isinstance(models_by_task, dict) for task, models in models_by_task.items(): assert isinstance(task, str) assert isinstance(models, list) for model in models: assert isinstance(model, str) assert "text-generation" in models_by_task assert "HuggingFaceH4/zephyr-7b-beta" in models_by_task["text-generation"] @pytest.mark.asyncio async def test_async_generate_timeout_error(monkeypatch: pytest.MonkeyPatch) -> None: def _mock_aiohttp_client_timeout(*args, **kwargs): raise asyncio.TimeoutError def mock_check_supported_task(*args, **kwargs): return None monkeypatch.setattr( "huggingface_hub.inference._providers.hf_inference._check_supported_task", mock_check_supported_task ) monkeypatch.setattr("aiohttp.ClientSession.post", _mock_aiohttp_client_timeout) with pytest.raises(InferenceTimeoutError): await AsyncInferenceClient(timeout=1).text_generation("test") class CustomException(Exception): """Mock any exception that could happen while making a POST request.""" @pytest.mark.skip("Temporary skipping this test") @pytest.mark.asyncio @with_production_testing async def test_openai_compatibility_base_url_and_api_key(): client = AsyncInferenceClient( base_url="https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct", api_key="my-api-key", ) output = await client.chat.completions.create( model="meta-llama/Meta-Llama-3.1-8B-Instruct", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Count to 10"}, ], stream=False, max_tokens=1024, ) assert "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" in output.choices[0].message.content @pytest.mark.skip("Temporary skipping this test") @pytest.mark.asyncio @with_production_testing async def test_openai_compatibility_without_base_url(): client = AsyncInferenceClient() output = await client.chat.completions.create( model="meta-llama/Meta-Llama-3.1-8B-Instruct", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Count to 10"}, ], stream=False, max_tokens=1024, ) assert "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" in output.choices[0].message.content @pytest.mark.skip("Temporary skipping this test") @pytest.mark.asyncio @with_production_testing async def test_openai_compatibility_with_stream_true(): client = AsyncInferenceClient() output = await client.chat.completions.create( model="meta-llama/Meta-Llama-3.1-8B-Instruct", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Count to 10"}, ], stream=True, max_tokens=1024, ) chunked_text = [ chunk.choices[0].delta.content async for chunk in output if chunk.choices[0].delta.content is not None ] assert len(chunked_text) == 35 output_text = "".join(chunked_text) assert 
"1, 2, 3, 4, 5, 6, 7, 8, 9, 10" in output_text @pytest.mark.skip("Temporary skipping this test") @pytest.mark.asyncio @with_production_testing async def test_http_session_correctly_closed() -> None: """ Regression test for #2493. Async client should close the HTTP session after the request is done. This is always done except for streamed responses if the stream is not fully consumed. Fixed by keeping a list of sessions and closing them all when deleting the client. See https://github.com/huggingface/huggingface_hub/issues/2493. """ client = AsyncInferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct") kwargs = {"prompt": "Hi", "stream": True, "max_new_tokens": 1} # Test create session + close it + check correctly unregistered await client.text_generation(**kwargs) assert len(client._sessions) == 1 await list(client._sessions)[0].close() assert len(client._sessions) == 0 # Test create multiple sessions + close AsyncInferenceClient + check correctly unregistered await client.text_generation(**kwargs) await client.text_generation(**kwargs) await client.text_generation(**kwargs) assert len(client._sessions) == 3 await client.close() assert len(client._sessions) == 0 @pytest.mark.asyncio async def test_use_async_with_inference_client(): with patch("huggingface_hub.AsyncInferenceClient.close") as mock_close: async with AsyncInferenceClient(): pass mock_close.assert_called_once() @pytest.mark.asyncio @patch("aiohttp.ClientSession._request") async def test_client_responses_correctly_closed(request_mock: Mock) -> None: """ Regression test for #2521. Async client must close the ClientResponse objects when exiting the async context manager. Fixed by closing the response objects when the session is closed. See https://github.com/huggingface/huggingface_hub/issues/2521. """ async with AsyncInferenceClient() as client: session = client._get_client_session() response1 = await session.get("http://this-is-a-fake-url.com") response2 = await session.post("http://this-is-a-fake-url.com", json={}) # Response objects are closed when the AsyncInferenceClient is closed response1.close.assert_called_once() response2.close.assert_called_once() @pytest.mark.asyncio async def test_warns_if_client_deleted_with_opened_sessions(): client = AsyncInferenceClient() session = client._get_client_session() with pytest.warns(UserWarning): client.__del__() await session.close() huggingface_hub-0.31.1/tests/test_inference_client.py000066400000000000000000001431611500667546600227750ustar00rootroot00000000000000# Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import io import json import os import string import time from pathlib import Path from typing import List from unittest.mock import MagicMock, patch import numpy as np import pytest from PIL import Image from huggingface_hub import ( ChatCompletionOutput, ChatCompletionOutputComplete, ChatCompletionStreamOutput, DocumentQuestionAnsweringOutputElement, FillMaskOutputElement, ImageClassificationOutputElement, ImageToTextOutput, InferenceClient, ObjectDetectionBoundingBox, ObjectDetectionOutputElement, QuestionAnsweringOutputElement, TableQuestionAnsweringOutputElement, TextClassificationOutputElement, TokenClassificationOutputElement, TranslationOutput, VisualQuestionAnsweringOutputElement, ZeroShotClassificationOutputElement, constants, hf_hub_download, ) from huggingface_hub.errors import HfHubHTTPError, ValidationError from huggingface_hub.inference._client import _open_as_binary from huggingface_hub.inference._common import _stream_chat_completion_response, _stream_text_generation_response from huggingface_hub.inference._providers import get_provider_helper from huggingface_hub.inference._providers.hf_inference import _build_chat_completion_url from .testing_utils import expect_deprecation, with_production_testing # Avoid calling APIs in VCRed tests _RECOMMENDED_MODELS_FOR_VCR = { "black-forest-labs": { "text-to-image": "black-forest-labs/FLUX.1-dev", }, "cerebras": { "conversational": "meta-llama/Llama-3.3-70B-Instruct", }, "together": { "conversational": "meta-llama/Meta-Llama-3-8B-Instruct", "text-generation": "meta-llama/Llama-2-70b-hf", "text-to-image": "stabilityai/stable-diffusion-xl-base-1.0", }, "fal-ai": { "text-to-image": "black-forest-labs/FLUX.1-dev", "automatic-speech-recognition": "openai/whisper-large-v3", }, "fireworks-ai": { "conversational": "meta-llama/Llama-3.3-70B-Instruct", }, "hf-inference": { "audio-classification": "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech", "audio-to-audio": "speechbrain/sepformer-wham", "automatic-speech-recognition": "jonatasgrosman/wav2vec2-large-xlsr-53-english", "conversational": "meta-llama/Llama-3.1-8B-Instruct", "document-question-answering": "naver-clova-ix/donut-base-finetuned-docvqa", "feature-extraction": "facebook/bart-base", "image-classification": "google/vit-base-patch16-224", "image-to-text": "Salesforce/blip-image-captioning-base", "image-segmentation": "facebook/detr-resnet-50-panoptic", "object-detection": "facebook/detr-resnet-50", "sentence-similarity": "sentence-transformers/all-MiniLM-L6-v2", "summarization": "sshleifer/distilbart-cnn-12-6", "table-question-answering": "google/tapas-base-finetuned-wtq", "tabular-classification": "julien-c/wine-quality", "tabular-regression": "scikit-learn/Fish-Weight", "text-classification": "distilbert/distilbert-base-uncased-finetuned-sst-2-english", "text-to-image": "CompVis/stable-diffusion-v1-4", "text-to-speech": "espnet/kan-bayashi_ljspeech_vits", "token-classification": "dbmdz/bert-large-cased-finetuned-conll03-english", "translation": "t5-small", "visual-question-answering": "dandelin/vilt-b32-finetuned-vqa", "zero-shot-classification": "facebook/bart-large-mnli", "zero-shot-image-classification": "openai/clip-vit-base-patch32", }, "hyperbolic": { "text-generation": "meta-llama/Llama-3.1-405B", "conversational": "meta-llama/Llama-3.2-3B-Instruct", "text-to-image": "stabilityai/stable-diffusion-2", }, "nebius": { "conversational": "meta-llama/Llama-3.1-8B-Instruct", "text-generation": "Qwen/Qwen2.5-32B-Instruct", "text-to-image": 
"stabilityai/stable-diffusion-xl-base-1.0", }, "novita": { "text-generation": "NousResearch/Nous-Hermes-Llama2-13b", "conversational": "meta-llama/Llama-3.1-8B-Instruct", }, "replicate": { "text-to-image": "ByteDance/SDXL-Lightning", }, "sambanova": { "conversational": "meta-llama/Llama-3.1-8B-Instruct", }, } CHAT_COMPLETION_MODEL = "HuggingFaceH4/zephyr-7b-beta" CHAT_COMPLETION_MESSAGES = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is deep learning?"}, ] CHAT_COMPLETE_NON_TGI_MODEL = "microsoft/DialoGPT-small" CHAT_COMPLETION_TOOL_INSTRUCTIONS = [ { "role": "system", "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.", }, { "role": "user", "content": "What's the weather like the next 3 days in San Francisco, CA?", }, ] CHAT_COMPLETION_TOOLS = [ # 1 tool to get current weather, 1 to get N-day weather forecast { "type": "function", "function": { "name": "get_current_weather", "description": "Get the current weather", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "The city and state, e.g. San Francisco, CA", }, "format": { "type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location.", }, }, "required": ["location", "format"], }, }, }, { "type": "function", "function": { "name": "get_n_day_weather_forecast", "description": "Get an N-day weather forecast", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "The city and state, e.g. San Francisco, CA", }, "format": { "type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location.", }, "num_days": { "type": "integer", "description": "The number of days to forecast", }, }, "required": ["location", "format", "num_days"], }, }, }, ] CHAT_COMPLETION_RESPONSE_FORMAT_MESSAGE = [ { "role": "user", "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?", }, ] CHAT_COMPLETION_RESPONSE_FORMAT = { "type": "json_object", "value": { "properties": { "location": {"type": "string"}, "activity": {"type": "string"}, "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, "animals": {"type": "array", "items": {"type": "string"}}, }, "required": ["location", "activity", "animals_seen", "animals"], }, } def list_clients(task: str) -> List[pytest.param]: """Get list of clients for a specific task, with proper skip handling.""" clients = [] for provider, tasks in _RECOMMENDED_MODELS_FOR_VCR.items(): if task in tasks: api_key = os.getenv("HF_INFERENCE_TEST_TOKEN") clients.append( pytest.param( (provider, tasks[task], api_key), id=f"{provider},{task}", ) ) return clients @pytest.fixture() @with_production_testing def client(request): """ Fixture to create client with proper skip handling. Note: VCR mode is only accessible through a fixture. 
""" provider, model, api_key = request.param vcr_record_mode = request.config.getoption("--vcr-record") # If we are recording and the api key is not set, skip the test # replaying modes are "all", "new_episodes" and "once" # non replaying modes are "none" and None if vcr_record_mode not in ["none", None] and not api_key: pytest.skip(f"API KEY not set for provider {provider}, skipping test") # If api_key is provided, use it if api_key: return InferenceClient(model=model, provider=provider, token=api_key) # Otherwise use dummy token for VCR playback return InferenceClient(model=model, provider=provider, token="hf_dummy_token") # Define fixtures for the files @pytest.fixture(scope="module") @with_production_testing def audio_file(): return hf_hub_download(repo_id="Narsil/image_dummy", repo_type="dataset", filename="sample1.flac") @pytest.fixture(scope="module") @with_production_testing def image_file(): return hf_hub_download(repo_id="Narsil/image_dummy", repo_type="dataset", filename="lena.png") @pytest.fixture(scope="module") @with_production_testing def document_file(): return hf_hub_download(repo_id="impira/docquery", repo_type="space", filename="contract.jpeg") class TestBase: @pytest.fixture(autouse=True) def setup(self, audio_file, image_file, document_file): self.audio_file = audio_file self.image_file = image_file self.document_file = document_file @pytest.fixture(autouse=True) def mock_recommended_models(self, monkeypatch): def mock_fetch(): return _RECOMMENDED_MODELS_FOR_VCR["hf-inference"] monkeypatch.setattr("huggingface_hub.inference._providers.hf_inference._fetch_recommended_models", mock_fetch) @with_production_testing @pytest.mark.skip("Temporary skipping tests for InferenceClient") class TestInferenceClient(TestBase): @pytest.mark.parametrize("client", list_clients("audio-classification"), indirect=True) def test_audio_classification(self, client: InferenceClient): output = client.audio_classification(self.audio_file) assert isinstance(output, list) assert len(output) > 0 for item in output: assert isinstance(item.score, float) assert isinstance(item.label, str) @pytest.mark.parametrize("client", list_clients("audio-to-audio"), indirect=True) def test_audio_to_audio(self, client: InferenceClient): output = client.audio_to_audio(self.audio_file) assert isinstance(output, list) assert len(output) > 0 for item in output: assert isinstance(item.label, str) assert isinstance(item.blob, bytes) assert item.content_type == "audio/flac" @pytest.mark.parametrize("client", list_clients("automatic-speech-recognition"), indirect=True) def test_automatic_speech_recognition(self, client: InferenceClient): output = client.automatic_speech_recognition(self.audio_file) # Remove punctuation from the output normalized_output = output.text.translate(str.maketrans("", "", string.punctuation)) assert normalized_output.lower().strip() == "a man said to the universe sir i exist" @pytest.mark.parametrize("client", list_clients("conversational"), indirect=True) def test_chat_completion_no_stream(self, client: InferenceClient): output = client.chat_completion(messages=CHAT_COMPLETION_MESSAGES, stream=False) assert isinstance(output, ChatCompletionOutput) assert output.created < time.time() assert isinstance(output.choices, list) assert len(output.choices) == 1 assert isinstance(output.choices[0], ChatCompletionOutputComplete) @pytest.mark.parametrize("client", list_clients("conversational"), indirect=True) def test_chat_completion_with_stream(self, client: InferenceClient): output = list( 
client.chat_completion( messages=CHAT_COMPLETION_MESSAGES, stream=True, max_tokens=20, ) ) assert isinstance(output, list) assert all(isinstance(item, ChatCompletionStreamOutput) for item in output) # All items except the last one have a single choice with role/content delta for item in output[:-1]: assert len(item.choices) == 1 assert item.choices[0].finish_reason is None assert item.choices[0].index == 0 # Last item has a finish reason assert output[-1].choices[0].finish_reason == "length" def test_chat_completion_with_non_tgi(self) -> None: client = InferenceClient(provider="hf-inference") output = client.chat_completion( messages=CHAT_COMPLETION_MESSAGES, model=CHAT_COMPLETE_NON_TGI_MODEL, stream=False, max_tokens=20, ) assert isinstance(output, ChatCompletionOutput) assert output.model == "microsoft/DialoGPT-small" assert len(output.choices) == 1 @pytest.mark.skip(reason="Schema not aligned between providers") @pytest.mark.parametrize("client", list_clients("conversational"), indirect=True) def test_chat_completion_with_tool(self, client: InferenceClient): response = client.chat_completion( messages=CHAT_COMPLETION_TOOL_INSTRUCTIONS, tools=CHAT_COMPLETION_TOOLS, tool_choice="auto", max_tokens=500, ) output = response.choices[0] # Single message before EOS assert output.finish_reason in ["tool_calls", "eos_token", "stop"] assert output.index == 0 assert output.message.role == "assistant" # No content but a tool call assert output.message.content is None assert len(output.message.tool_calls) == 1 # Tool tool_call = output.message.tool_calls[0] assert tool_call.type == "function" # Since tool_choice="auto", we can't know which tool will be called assert tool_call.function.name in ["get_n_day_weather_forecast", "get_current_weather"] args = tool_call.function.arguments if isinstance(args, str): args = json.loads(args) assert args["format"] in ["fahrenheit", "celsius"] assert args["location"] == "San Francisco, CA" assert args["num_days"] == 3 # Now, test with tool_choice="get_current_weather" response = client.chat_completion( messages=CHAT_COMPLETION_TOOL_INSTRUCTIONS, tools=CHAT_COMPLETION_TOOLS, tool_choice={ "type": "function", "function": { "name": "get_current_weather", }, }, max_tokens=500, ) output = response.choices[0] tool_call = output.message.tool_calls[0] assert tool_call.function.name == "get_current_weather" # No need for 'num_days' with this tool assert tool_call.function.arguments == { "format": "fahrenheit", "location": "San Francisco, CA", } @pytest.mark.skip(reason="Schema not aligned between providers") @pytest.mark.parametrize("client", list_clients("conversational"), indirect=True) def test_chat_completion_with_response_format(self, client: InferenceClient): response = client.chat_completion( messages=CHAT_COMPLETION_RESPONSE_FORMAT_MESSAGE, response_format=CHAT_COMPLETION_RESPONSE_FORMAT, max_tokens=500, ) output = response.choices[0].message.content assert json.loads(output) == { "activity": "biking", "animals": ["puppy", "cat", "raccoon"], "animals_seen": 3, "location": "park", } def test_chat_completion_unprocessable_entity(self) -> None: """Regression test for #2225. See https://github.com/huggingface/huggingface_hub/issues/2225. 
""" client = InferenceClient(provider="hf-inference") with pytest.raises(HfHubHTTPError): client.chat_completion( "please output 'Observation'", # Not a list of messages stop=["Observation", "Final Answer"], max_tokens=200, model="meta-llama/Meta-Llama-3-70B-Instruct", ) @pytest.mark.parametrize("client", list_clients("document-question-answering"), indirect=True) def test_document_question_answering(self, client: InferenceClient): output = client.document_question_answering(self.document_file, "What is the purchase amount?") assert output == [ DocumentQuestionAnsweringOutputElement( answer="$1,0000,000,00", end=None, score=None, start=None, ) ] @pytest.mark.parametrize("client", list_clients("feature-extraction"), indirect=True) def test_feature_extraction_with_transformers(self, client: InferenceClient): embedding = client.feature_extraction("Hi, who are you?") assert isinstance(embedding, np.ndarray) assert embedding.shape == (1, 8, 768) @pytest.mark.parametrize("client", list_clients("feature-extraction"), indirect=True) def test_feature_extraction_with_sentence_transformers(self, client: InferenceClient): embedding = client.feature_extraction("Hi, who are you?") assert isinstance(embedding, np.ndarray) assert embedding.shape == (1, 8, 768) @pytest.mark.parametrize("client", list_clients("fill-mask"), indirect=True) def test_fill_mask(self, client: InferenceClient): output = client.fill_mask("The goal of life is .") assert output == [ FillMaskOutputElement( score=0.06897063553333282, sequence="The goal of life is happiness.", token=11098, token_str=" happiness", fill_mask_output_token_str=None, ), FillMaskOutputElement( score=0.06554922461509705, sequence="The goal of life is immortality.", token=45075, token_str=" immortality", fill_mask_output_token_str=None, ), FillMaskOutputElement( score=0.0323575921356678, sequence="The goal of life is yours.", token=14314, token_str=" yours", fill_mask_output_token_str=None, ), FillMaskOutputElement( score=0.02431388944387436, sequence="The goal of life is liberation.", token=22211, token_str=" liberation", fill_mask_output_token_str=None, ), FillMaskOutputElement( score=0.023767812177538872, sequence="The goal of life is simplicity.", token=25342, token_str=" simplicity", fill_mask_output_token_str=None, ), ] def test_hf_inference_get_recommended_model_has_recommendation(self) -> None: from huggingface_hub.inference._providers.hf_inference import HFInferenceTask HFInferenceTask("feature-extraction")._prepare_mapping_info(None).provider_id == "facebook/bart-base" HFInferenceTask("translation")._prepare_mapping_info(None).provider_id == "t5-small" def test_hf_inference_get_recommended_model_no_recommendation(self) -> None: from huggingface_hub.inference._providers.hf_inference import HFInferenceTask with pytest.raises(ValueError): HFInferenceTask("text-generation")._prepare_mapping_info(None) @pytest.mark.parametrize("client", list_clients("image-classification"), indirect=True) def test_image_classification(self, client: InferenceClient): output = client.image_classification(self.image_file) assert output == [ ImageClassificationOutputElement(label="brassiere, bra, bandeau", score=0.11767438799142838), ImageClassificationOutputElement(label="sombrero", score=0.09572819620370865), ImageClassificationOutputElement(label="cowboy hat, ten-gallon hat", score=0.0900089293718338), ImageClassificationOutputElement(label="bonnet, poke bonnet", score=0.06615174561738968), ImageClassificationOutputElement(label="fur coat", score=0.061511047184467316), ] 
@pytest.mark.parametrize("client", list_clients("image-segmentation"), indirect=True) def test_image_segmentation(self, client: InferenceClient): output = client.image_segmentation(self.image_file) assert isinstance(output, list) assert len(output) > 0 for item in output: assert isinstance(item.score, float) assert isinstance(item.label, str) assert isinstance(item.mask, Image.Image) assert item.mask.height == 512 assert item.mask.width == 512 # ERROR 500 from server # Only during tests, not when running locally. Has to be investigated. # def test_image_to_image(self) -> None: # image = self.client.image_to_image(self.image_file, prompt="turn the woman into a man") @pytest.mark.parametrize("client", list_clients("image-to-text"), indirect=True) def test_image_to_text(self, client: InferenceClient): caption = client.image_to_text(self.image_file) assert isinstance(caption, ImageToTextOutput) assert caption.generated_text == "a woman in a hat and dress posing for a photo" @pytest.mark.parametrize("client", list_clients("object-detection"), indirect=True) def test_object_detection(self, client: InferenceClient): output = client.object_detection(self.image_file) assert output == [ ObjectDetectionOutputElement( box=ObjectDetectionBoundingBox( xmin=59, ymin=39, xmax=420, ymax=510, ), label="person", score=0.9486680030822754, ), ObjectDetectionOutputElement( box=ObjectDetectionBoundingBox( xmin=143, ymin=4, xmax=510, ymax=387, ), label="umbrella", score=0.5733323693275452, ), ObjectDetectionOutputElement( box=ObjectDetectionBoundingBox( xmin=60, ymin=162, xmax=413, ymax=510, ), label="person", score=0.5082724094390869, ), ] @pytest.mark.parametrize("client", list_clients("question-answering"), indirect=True) def test_question_answering(self, client: InferenceClient): output = client.question_answering(question="What is the meaning of life?", context="42") assert output == QuestionAnsweringOutputElement(answer="42", end=2, score=1.4291124728060822e-08, start=0) @pytest.mark.parametrize("client", list_clients("sentence-similarity"), indirect=True) def test_sentence_similarity(self, client: InferenceClient): scores = client.sentence_similarity( "Machine learning is so easy.", other_sentences=[ "Deep learning is so straightforward.", "This is so difficult, like rocket science.", "I can't believe how much I struggled with this.", ], ) assert scores == [0.7785724997520447, 0.4587624967098236, 0.29062220454216003] @pytest.mark.parametrize("client", list_clients("summarization"), indirect=True) def test_summarization(self, client: InferenceClient): summary = client.summarization("The sky is blue, the tree is green.") assert isinstance(summary.summary_text, str) @pytest.mark.skip(reason="This model is not available on InferenceAPI") @pytest.mark.parametrize("client", list_clients("tabular-classification"), indirect=True) def test_tabular_classification(self, client: InferenceClient): table = { "fixed_acidity": ["7.4", "7.8", "10.3"], "volatile_acidity": ["0.7", "0.88", "0.32"], "citric_acid": ["0", "0", "0.45"], "residual_sugar": ["1.9", "2.6", "6.4"], "chlorides": ["0.076", "0.098", "0.073"], "free_sulfur_dioxide": ["11", "25", "5"], "total_sulfur_dioxide": ["34", "67", "13"], "density": ["0.9978", "0.9968", "0.9976"], "pH": ["3.51", "3.2", "3.23"], "sulphates": ["0.56", "0.68", "0.82"], "alcohol": ["9.4", "9.8", "12.6"], } output = client.tabular_classification(table=table) assert output == ["5", "5", "5"] @pytest.mark.skip(reason="This model is not available on InferenceAPI") 
@pytest.mark.parametrize("client", list_clients("tabular-regression"), indirect=True) def test_tabular_regression(self, client: InferenceClient): table = { "Height": ["11.52", "12.48", "12.3778"], "Length1": ["23.2", "24", "23.9"], "Length2": ["25.4", "26.3", "26.5"], "Length3": ["30", "31.2", "31.1"], "Species": ["Bream", "Bream", "Bream"], "Width": ["4.02", "4.3056", "4.6961"], } output = client.tabular_regression(table=table) assert output == [110, 120, 130] @pytest.mark.parametrize("client", list_clients("table-question-answering"), indirect=True) def test_table_question_answering(self, client: InferenceClient): table = { "Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"], } query = "How many stars does the transformers repository have?" output = client.table_question_answering(query=query, table=table) assert output == TableQuestionAnsweringOutputElement( answer="AVERAGE > 36542", cells=["36542"], coordinates=[[0, 1]], aggregator="AVERAGE" ) @pytest.mark.parametrize("client", list_clients("text-classification"), indirect=True) def test_text_classification(self, client: InferenceClient): output = client.text_classification("I like you") assert output == [ TextClassificationOutputElement(label="POSITIVE", score=0.9998695850372314), TextClassificationOutputElement(label="NEGATIVE", score=0.00013043530634604394), ] def test_text_generation(self) -> None: """Tested separately in `test_inference_text_generation.py`.""" @pytest.mark.parametrize("client", list_clients("text-to-image"), indirect=True) def test_text_to_image_default(self, client: InferenceClient): image = client.text_to_image("An astronaut riding a horse on the moon.") assert isinstance(image, Image.Image) @pytest.mark.skip(reason="Need to check why fal.ai doesn't take image_size into account") @pytest.mark.parametrize("client", list_clients("text-to-image"), indirect=True) def test_text_to_image_with_parameters(self, client: InferenceClient): image = client.text_to_image("An astronaut riding a horse on the moon.", height=256, width=256) assert isinstance(image, Image.Image) assert image.height == 256 assert image.width == 256 @pytest.mark.parametrize("client", list_clients("text-to-speech"), indirect=True) def test_text_to_speech(self, client: InferenceClient): audio = client.text_to_speech("Hello world") assert isinstance(audio, bytes) @pytest.mark.parametrize("client", list_clients("translation"), indirect=True) def test_translation(self, client: InferenceClient): output = client.translation("Hello world") assert output == TranslationOutput(translation_text="Hallo Welt") @pytest.mark.parametrize("client", list_clients("translation"), indirect=True) def test_translation_with_source_and_target_language(self, client: InferenceClient): output_with_langs = client.translation( "Hello world", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX" ) assert isinstance(output_with_langs, TranslationOutput) with pytest.raises(ValueError): client.translation("Hello world", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX") with pytest.raises(ValueError): client.translation("Hello world", model="facebook/mbart-large-50-many-to-many-mmt", tgt_lang="en_XX") @pytest.mark.parametrize("client", list_clients("token-classification"), indirect=True) def test_token_classification(self, client: InferenceClient): output = client.token_classification(text="My name is Sarah Jessica Parker but you can call me Jessica") assert output == [ 
TokenClassificationOutputElement( score=0.9991335868835449, end=31, entity_group="PER", start=11, word="Sarah Jessica Parker" ), TokenClassificationOutputElement( score=0.9979913234710693, end=59, entity_group="PER", start=52, word="Jessica" ), ] @pytest.mark.parametrize("client", list_clients("visual-question-answering"), indirect=True) def test_visual_question_answering(self, client: InferenceClient): output = client.visual_question_answering(image=self.image_file, question="Who's in the picture?") assert output == [ VisualQuestionAnsweringOutputElement(score=0.9386942982673645, answer="woman"), VisualQuestionAnsweringOutputElement(score=0.3431190550327301, answer="girl"), VisualQuestionAnsweringOutputElement(score=0.08407800644636154, answer="lady"), VisualQuestionAnsweringOutputElement(score=0.05075192078948021, answer="female"), VisualQuestionAnsweringOutputElement(score=0.017771074548363686, answer="man"), ] @pytest.mark.parametrize("client", list_clients("zero-shot-classification"), indirect=True) def test_zero_shot_classification_single_label(self, client: InferenceClient): output = client.zero_shot_classification( "A new model offers an explanation for how the Galilean satellites formed around the solar system's" "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling" " mysteries when he went for a run up a hill in Nice, France.", candidate_labels=["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"], ) assert output == [ ZeroShotClassificationOutputElement(label="scientific discovery", score=0.796166181564331), ZeroShotClassificationOutputElement(label="space & cosmos", score=0.18570725619792938), ZeroShotClassificationOutputElement(label="microbiology", score=0.007308819331228733), ZeroShotClassificationOutputElement(label="archeology", score=0.0062583745457232), ZeroShotClassificationOutputElement(label="robots", score=0.004559362772852182), ] @pytest.mark.parametrize("client", list_clients("zero-shot-classification"), indirect=True) def test_zero_shot_classification_multi_label(self, client: InferenceClient): output = client.zero_shot_classification( text="A new model offers an explanation for how the Galilean satellites formed around the solar system's" "largest world. 
Konstantin Batygin did not set out to solve one of the solar system's most puzzling" " mysteries when he went for a run up a hill in Nice, France.", candidate_labels=["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"], multi_label=True, ) assert output == [ ZeroShotClassificationOutputElement(label="scientific discovery", score=0.9829296469688416), ZeroShotClassificationOutputElement(label="space & cosmos", score=0.7551906108856201), ZeroShotClassificationOutputElement(label="microbiology", score=0.0005462627159431577), ZeroShotClassificationOutputElement(label="archeology", score=0.0004713202069979161), ZeroShotClassificationOutputElement(label="robots", score=0.000304485292872414), ] @pytest.mark.parametrize("client", list_clients("zero-shot-image-classification"), indirect=True) def test_zero_shot_image_classification(self, client: InferenceClient): output = client.zero_shot_image_classification( image=self.image_file, candidate_labels=["tree", "woman", "cat"] ) assert isinstance(output, list) assert len(output) > 0 for item in output: assert isinstance(item.label, str) assert isinstance(item.score, float) class TestOpenAsBinary: @pytest.fixture(autouse=True) def setup(self, audio_file, image_file, document_file): self.audio_file = audio_file self.image_file = image_file self.document_file = document_file def test_open_as_binary_with_none(self) -> None: with _open_as_binary(None) as content: assert content is None def test_open_as_binary_from_str_path(self) -> None: with _open_as_binary(self.image_file) as content: assert isinstance(content, io.BufferedReader) def test_open_as_binary_from_pathlib_path(self) -> None: with _open_as_binary(Path(self.image_file)) as content: assert isinstance(content, io.BufferedReader) def test_open_as_binary_from_url(self) -> None: with _open_as_binary("https://huggingface.co/datasets/Narsil/image_dummy/resolve/main/tree.png") as content: assert isinstance(content, bytes) def test_open_as_binary_opened_file(self) -> None: with Path(self.image_file).open("rb") as f: with _open_as_binary(f) as content: assert content == f assert isinstance(content, io.BufferedReader) def test_open_as_binary_from_bytes(self) -> None: content_bytes = Path(self.image_file).read_bytes() with _open_as_binary(content_bytes) as content: assert content == content_bytes class TestHeadersAndCookies(TestBase): def test_headers_and_cookies(self) -> None: client = InferenceClient(headers={"X-My-Header": "foo"}, cookies={"my-cookie": "bar"}) assert client.headers["X-My-Header"] == "foo" assert client.cookies["my-cookie"] == "bar" @patch("huggingface_hub.inference._client._bytes_to_image") @patch("huggingface_hub.inference._client.get_session") @patch("huggingface_hub.inference._providers.hf_inference._check_supported_task") def test_accept_header_image( self, check_supported_task_mock: MagicMock, get_session_mock: MagicMock, bytes_to_image_mock: MagicMock, ) -> None: """Test that Accept: image/png header is set for image tasks.""" client = InferenceClient(provider="hf-inference") response = client.text_to_image("An astronaut riding a horse") assert response == bytes_to_image_mock.return_value headers = get_session_mock().post.call_args_list[0].kwargs["headers"] assert headers["Accept"] == "image/png" class TestListDeployedModels(TestBase): @expect_deprecation("list_deployed_models") @patch("huggingface_hub.inference._client.get_session") def test_list_deployed_models_main_frameworks_mock(self, get_session_mock: MagicMock) -> None: 
InferenceClient(provider="hf-inference").list_deployed_models() assert len(get_session_mock.return_value.get.call_args_list) == len(constants.MAIN_INFERENCE_API_FRAMEWORKS) @expect_deprecation("list_deployed_models") @patch("huggingface_hub.inference._client.get_session") def test_list_deployed_models_all_frameworks_mock(self, get_session_mock: MagicMock) -> None: InferenceClient(provider="hf-inference").list_deployed_models("all") assert len(get_session_mock.return_value.get.call_args_list) == len(constants.ALL_INFERENCE_API_FRAMEWORKS) @with_production_testing @pytest.mark.skip("Temporary skipping tests for TestOpenAICompatibility") class TestOpenAICompatibility(TestBase): def test_base_url_and_api_key(self): client = InferenceClient( provider="hf-inference", base_url="https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct", api_key=os.getenv("HF_INFERENCE_TEST_TOKEN"), ) output = client.chat.completions.create( model="meta-llama/Meta-Llama-3-8B-Instruct", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Count to 10"}, ], stream=False, max_tokens=1024, ) assert "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" in output.choices[0].message.content def test_without_base_url(self): client = InferenceClient( provider="hf-inference", token=os.getenv("HF_INFERENCE_TEST_TOKEN"), ) output = client.chat.completions.create( model="meta-llama/Meta-Llama-3-8B-Instruct", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Count to 10"}, ], stream=False, max_tokens=1024, ) assert "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" in output.choices[0].message.content def test_with_stream_true(self): client = InferenceClient( provider="hf-inference", token=os.getenv("HF_INFERENCE_TEST_TOKEN"), ) output = client.chat.completions.create( model="meta-llama/Meta-Llama-3-8B-Instruct", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Count to 10"}, ], stream=True, max_tokens=1024, ) chunked_text = [chunk.choices[0].delta.content for chunk in output] assert len(chunked_text) == 30 output_text = "".join(chunked_text) assert "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" in output_text def test_model_and_base_url_mutually_exclusive(self): with pytest.raises(ValueError): InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct", base_url="http://127.0.0.1:8000") def test_token_initialization_from_token(self): # Test with explicit token client = InferenceClient(token="my-token") assert client.token == "my-token" def test_token_initialization_from_api_key(self): # Test with api_key client = InferenceClient(api_key="my-api-key") assert client.token == "my-api-key" def test_token_initialization_cannot_be_both(self): # Test with both token and api_key raises error with pytest.raises(ValueError, match="Received both `token` and `api_key` arguments"): InferenceClient(token="my-token", api_key="my-api-key") def test_token_initialization_default_to_none(self): # Test with token=None (default behavior) client = InferenceClient() assert client.token is None def test_token_initialization_with_token_true(self, mocker): # Test with token=True and token is set with get_token() mocker.patch("huggingface_hub.inference._client.get_token", return_value="my-token") client = InferenceClient(token=True) assert client.token == "my-token" def test_token_initialization_cannot_be_token_false(self): # Test with token=False raises error with pytest.raises(ValueError, match="Cannot use `token=False` to disable 
authentication"): InferenceClient(token=False) @pytest.mark.parametrize( "stop_signal", [ b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] ", ], ) def test_stream_text_generation_response(stop_signal: bytes): data = [ b'data: {"index":1,"token":{"id":4560,"text":" trying","logprob":-2.078125,"special":false},"generated_text":null,"details":null}', b"", # Empty line is skipped b"\n", # Newline is skipped b'data: {"index":2,"token":{"id":311,"text":" to","logprob":-0.026245117,"special":false},"generated_text":" trying to","details":null}', stop_signal, # Stop signal # Won't parse after b'data: {"index":2,"token":{"id":311,"text":" to","logprob":-0.026245117,"special":false},"generated_text":" trying to","details":null}', ] output = list(_stream_text_generation_response(data, details=False)) assert len(output) == 2 assert output == [" trying", " to"] @pytest.mark.parametrize( "stop_signal", [ b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] ", ], ) def test_stream_chat_completion_response(stop_signal: bytes): data = [ b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}', b"", # Empty line is skipped b"\n", # Newline is skipped b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":" Rust"},"logprobs":null,"finish_reason":null}]}', stop_signal, # Stop signal # Won't parse after b'data: {"index":2,"token":{"id":311,"text":" to","logprob":-0.026245117,"special":false},"generated_text":" trying to","details":null}', ] output = list(_stream_chat_completion_response(data)) assert len(output) == 2 assert output[0].choices[0].delta.content == "Both" assert output[1].choices[0].delta.content == " Rust" def test_chat_completion_error_in_stream(): """ Regression test for https://github.com/huggingface/huggingface_hub/issues/2514. When an error is encountered in the stream, it should raise a TextGenerationError (e.g. a ValidationError). """ data = [ b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}', b'data: {"error":"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. 
Given: 6 `inputs` tokens and 4091 `max_new_tokens`","error_type":"validation"}', ] with pytest.raises(ValidationError): for token in _stream_chat_completion_response(data): pass INFERENCE_API_URL = "https://api-inference.huggingface.co/models" INFERENCE_ENDPOINT_URL = "https://rur2d6yoccusjxgn.us-east-1.aws.endpoints.huggingface.cloud" # example LOCAL_TGI_URL = "http://0.0.0.0:8080" @pytest.mark.parametrize( ("model_url", "expected_url"), [ # Inference API ( f"{INFERENCE_API_URL}/username/repo_name", f"{INFERENCE_API_URL}/username/repo_name/v1/chat/completions", ), # Inference Endpoint ( INFERENCE_ENDPOINT_URL, f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions", ), # Inference Endpoint - full url ( f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions", f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions", ), # Inference Endpoint - url with '/v1' (OpenAI compatibility) ( f"{INFERENCE_ENDPOINT_URL}/v1", f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions", ), # Inference Endpoint - url with '/v1/' (OpenAI compatibility) ( f"{INFERENCE_ENDPOINT_URL}/v1/", f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions", ), # Local TGI with trailing '/v1' ( f"{LOCAL_TGI_URL}/v1", f"{LOCAL_TGI_URL}/v1/chat/completions", ), ], ) def test_resolve_chat_completion_url(model_url: str, expected_url: str): url = _build_chat_completion_url(model_url) assert url == expected_url def test_pass_url_as_base_url(): client = InferenceClient( provider="hf-inference", base_url="http://localhost:8082/v1/", ) provider = get_provider_helper("hf-inference", "text-generation", "test-model") request = provider.prepare_request( inputs="The huggingface_hub library is ", parameters={}, headers={}, model=client.model, api_key=None ) assert request.url == "http://localhost:8082/v1/" def test_cannot_pass_token_false(): """Regression test for #2853. It is no longer possible to pass `token=False` to the InferenceClient constructor. This was a legacy behavior, broken since 0.28.x release as passing token=False does not prevent the token from being used. Better to drop this feature altogether and raise an error if `token=False` is passed. See https://github.com/huggingface/huggingface_hub/pull/2853. 
""" with pytest.raises(ValueError): InferenceClient(token=False) class TestBillToOrganization: def test_bill_to_added_to_new_headers(self): client = InferenceClient(bill_to="huggingface_hub") assert client.headers["X-HF-Bill-To"] == "huggingface_hub" def test_bill_to_added_to_existing_headers(self): headers = {"foo": "bar"} client = InferenceClient(bill_to="huggingface_hub", headers=headers) assert client.headers["X-HF-Bill-To"] == "huggingface_hub" assert client.headers["foo"] == "bar" assert headers == {"foo": "bar"} # do not mutate the original headers def test_warning_if_bill_to_already_set(self): headers = {"X-HF-Bill-To": "huggingface"} with pytest.warns(UserWarning, match="Overriding existing 'huggingface' value in headers with 'openai'."): client = InferenceClient(bill_to="openai", headers=headers) assert client.headers["X-HF-Bill-To"] == "openai" assert headers == {"X-HF-Bill-To": "huggingface"} # do not mutate the original headers def test_warning_if_bill_to_with_direct_calls(self): with pytest.warns( UserWarning, match="You've provided an external provider's API key, so requests will be billed directly by the provider.", ): InferenceClient(bill_to="openai", token="replicate_key", provider="replicate") @pytest.mark.parametrize( "client_init_arg, init_kwarg_name, expected_request_url, expected_payload_model", [ # passing a custom endpoint in the model argument pytest.param( "https://my-custom-endpoint.com/custom_path", "model", "https://my-custom-endpoint.com/custom_path/v1/chat/completions", "dummy", id="client_model_is_url", ), # passing a custom endpoint in the base_url argument pytest.param( "https://another-endpoint.com/v1/", "base_url", "https://another-endpoint.com/v1/chat/completions", "dummy", id="client_base_url_is_url", ), # passing a model ID pytest.param( "username/repo_name", "model", "https://router.huggingface.co/hf-inference/models/username/repo_name/v1/chat/completions", "username/repo_name", id="client_model_is_id", ), # passing a custom endpoint in the model argument pytest.param( "https://specific-chat-endpoint.com/v1/chat/completions", "model", "https://specific-chat-endpoint.com/v1/chat/completions", "dummy", id="client_model_is_full_chat_url", ), # passing a localhost URL in the model argument pytest.param( "http://localhost:8080", "model", "http://localhost:8080/v1/chat/completions", "dummy", id="client_model_is_localhost_url", ), # passing a localhost URL in the base_url argument pytest.param( "http://127.0.0.1:8000/custom/path/v1", "base_url", "http://127.0.0.1:8000/custom/path/v1/chat/completions", "dummy", id="client_base_url_is_localhost_ip_with_path", ), ], ) def test_chat_completion_url_resolution( mocker, client_init_arg, init_kwarg_name, expected_request_url, expected_payload_model ): init_kwargs = {init_kwarg_name: client_init_arg, "provider": "hf-inference"} client = InferenceClient(**init_kwargs) mock_response_content = b'{"choices": [{"message": {"content": "Mock response"}}]}' mocker.patch( "huggingface_hub.inference._providers.hf_inference._check_supported_task", return_value=None, ) with patch.object(InferenceClient, "_inner_post", return_value=mock_response_content) as mock_inner_post: client.chat_completion(messages=[{"role": "user", "content": "Hello?"}], stream=False) mock_inner_post.assert_called_once() request_params = mock_inner_post.call_args[0][0] assert request_params.url == expected_request_url assert request_params.json is not None assert request_params.json.get("model") == expected_payload_model 
huggingface_hub-0.31.1/tests/test_inference_endpoints.py000066400000000000000000000267561500667546600235340ustar00rootroot00000000000000from datetime import datetime, timezone from itertools import chain, repeat from unittest.mock import MagicMock, Mock, patch import pytest from huggingface_hub import ( AsyncInferenceClient, HfApi, InferenceClient, InferenceEndpoint, InferenceEndpointError, InferenceEndpointTimeoutError, ) MOCK_INITIALIZING = { "name": "my-endpoint-name", "type": "protected", "accountId": None, "provider": {"vendor": "aws", "region": "us-east-1"}, "compute": { "accelerator": "cpu", "instanceType": "intel-icl", "instanceSize": "x2", "scaling": {"minReplica": 0, "maxReplica": 1}, }, "model": { "repository": "gpt2", "revision": "11c5a3d5811f50298f278a704980280950aedb10", "task": "text-generation", "framework": "pytorch", "image": {"huggingface": {}}, "secret": {"token": "my-token"}, }, "status": { "createdAt": "2023-10-26T12:41:53.263078506Z", "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, "updatedAt": "2023-10-26T12:41:53.263079138Z", "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, "private": None, "state": "pending", "message": "Endpoint waiting to be scheduled", "readyReplica": 0, "targetReplica": 0, }, } MOCK_RUNNING = { "name": "my-endpoint-name", "type": "protected", "accountId": None, "provider": {"vendor": "aws", "region": "us-east-1"}, "compute": { "accelerator": "cpu", "instanceType": "intel-icl", "instanceSize": "x2", "scaling": {"minReplica": 0, "maxReplica": 1}, }, "model": { "repository": "gpt2", "revision": "11c5a3d5811f50298f278a704980280950aedb10", "task": "text-generation", "framework": "pytorch", "image": {"huggingface": {}}, "secrets": {"token": "my-token"}, }, "status": { "createdAt": "2023-10-26T12:41:53.263Z", "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, "updatedAt": "2023-10-26T12:41:53.263Z", "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, "private": None, "state": "running", "message": "Endpoint is ready", "url": "https://vksrvs8pc1xnifhq.us-east-1.aws.endpoints.huggingface.cloud", "readyReplica": 1, "targetReplica": 1, }, } MOCK_FAILED = { "name": "my-endpoint-name", "type": "protected", "accountId": None, "provider": {"vendor": "aws", "region": "us-east-1"}, "compute": { "accelerator": "cpu", "instanceType": "intel-icl", "instanceSize": "x2", "scaling": {"minReplica": 0, "maxReplica": 1}, }, "model": { "repository": "gpt2", "revision": "11c5a3d5811f50298f278a704980280950aedb10", "task": "text-generation", "framework": "pytorch", "image": {"huggingface": {}}, "secrets": {"token": "my-token"}, }, "status": { "createdAt": "2023-10-26T12:41:53.263Z", "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, "updatedAt": "2023-10-26T12:41:53.263Z", "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, "private": None, "state": "failed", "message": "Endpoint failed to deploy", "readyReplica": 0, "targetReplica": 1, }, } # added for test_wait_update function MOCK_UPDATE = { "name": "my-endpoint-name", "type": "protected", "accountId": None, "provider": {"vendor": "aws", "region": "us-east-1"}, "compute": { "accelerator": "cpu", "instanceType": "intel-icl", "instanceSize": "x2", "scaling": {"minReplica": 0, "maxReplica": 1}, }, "model": { "repository": "gpt2", "revision": "11c5a3d5811f50298f278a704980280950aedb10", "task": "text-generation", "framework": "pytorch", "image": {"huggingface": {}}, "secret": {"token": "my-token"}, }, "status": { 
"createdAt": "2023-10-26T12:41:53.263078506Z", "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, "updatedAt": "2023-10-26T12:41:53.263079138Z", "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, "private": None, "state": "updating", "url": "https://vksrvs8pc1xnifhq.us-east-1.aws.endpoints.huggingface.cloud", "message": "Endpoint waiting for the update", "readyReplica": 0, "targetReplica": 1, }, } def test_from_raw_initialization(): """Test InferenceEndpoint is correctly initialized from raw dict.""" endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo") # Main attributes parsed correctly assert endpoint.name == "my-endpoint-name" assert endpoint.namespace == "foo" assert endpoint.repository == "gpt2" assert endpoint.framework == "pytorch" assert endpoint.status == "pending" assert endpoint.revision == "11c5a3d5811f50298f278a704980280950aedb10" assert endpoint.task == "text-generation" assert endpoint.type == "protected" # Datetime parsed correctly assert endpoint.created_at == datetime(2023, 10, 26, 12, 41, 53, 263078, tzinfo=timezone.utc) assert endpoint.updated_at == datetime(2023, 10, 26, 12, 41, 53, 263079, tzinfo=timezone.utc) # Not initialized yet assert endpoint.url is None # Raw dict still accessible assert endpoint.raw == MOCK_INITIALIZING def test_from_raw_with_hf_api(): """Test that the HfApi is correctly passed to the InferenceEndpoint.""" endpoint = InferenceEndpoint.from_raw( MOCK_INITIALIZING, namespace="foo", api=HfApi(library_name="my-library", token="hf_***") ) assert endpoint._api.library_name == "my-library" assert endpoint._api.token == "hf_***" def test_get_client_not_ready(): """Test clients are not created when endpoint is not ready.""" endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo") with pytest.raises(InferenceEndpointError): assert endpoint.client with pytest.raises(InferenceEndpointError): assert endpoint.async_client def test_get_client_ready(): """Test clients are created correctly when endpoint is ready.""" endpoint = InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo", token="my-token") # Endpoint is ready assert endpoint.status == "running" assert endpoint.url == "https://vksrvs8pc1xnifhq.us-east-1.aws.endpoints.huggingface.cloud" # => Client available client = endpoint.client assert isinstance(client, InferenceClient) assert client.token == "my-token" # => AsyncClient available async_client = endpoint.async_client assert isinstance(async_client, AsyncInferenceClient) assert async_client.token == "my-token" @patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint") def test_fetch(mock_get: Mock): endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo") mock_get.return_value = InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo") endpoint.fetch() assert endpoint.status == "running" assert endpoint.url == "https://vksrvs8pc1xnifhq.us-east-1.aws.endpoints.huggingface.cloud" @patch("huggingface_hub._inference_endpoints.get_session") @patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint") def test_wait_until_running(mock_get: Mock, mock_session: Mock): """Test waits until the endpoint is ready.""" endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo") mock_get.side_effect = [ InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo"), InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo"), InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo"), InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo"), 
        InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo"),
        InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo"),
    ]
    mock_session.return_value = Mock()
    mock_session.return_value.get.side_effect = [
        Mock(status_code=400),  # url is provisioned but not yet ready
        Mock(status_code=200),  # endpoint is ready
    ]
    endpoint.wait(refresh_every=0.01)
    assert endpoint.status == "running"
    assert len(mock_get.call_args_list) == 6


@patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint")
def test_wait_timeout(mock_get: Mock):
    """Test wait() raises a timeout error when the endpoint is not ready in time."""
    endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo")
    mock_get.side_effect = [
        InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo"),
        InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo"),
        InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo"),
        InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo"),
    ]

    with pytest.raises(InferenceEndpointTimeoutError):
        endpoint.wait(timeout=0.1, refresh_every=0.05)
    assert endpoint.status == "pending"
    assert len(mock_get.call_args_list) == 2


@patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint")
def test_wait_failed(mock_get: Mock):
    """Test wait() raises an error when the endpoint fails to deploy."""
    endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo")
    mock_get.side_effect = [
        InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo"),
        InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo"),
        InferenceEndpoint.from_raw(MOCK_FAILED, namespace="foo"),
    ]
    with pytest.raises(InferenceEndpointError, match=".*failed to deploy.*"):
        endpoint.wait(refresh_every=0.001)


@patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint")
@patch("huggingface_hub._inference_endpoints.get_session")
def test_wait_update(mock_get_session, mock_get_inference_endpoint):
    """Test that wait() returns when the endpoint transitions to running."""
    endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo")
    # Create an iterator that yields three MOCK_UPDATE responses, and then infinitely yields MOCK_RUNNING responses.
    responses = chain(
        [InferenceEndpoint.from_raw(MOCK_UPDATE, namespace="foo")] * 3,
        repeat(InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo")),
    )
    mock_get_inference_endpoint.side_effect = lambda *args, **kwargs: next(responses)

    # Patch the get_session().get() call to always return a fake response with status_code 200.
fake_response = MagicMock() fake_response.status_code = 200 mock_get_session.return_value.get.return_value = fake_response endpoint.wait(refresh_every=0.05) assert endpoint.status == "running" @patch("huggingface_hub.hf_api.HfApi.pause_inference_endpoint") def test_pause(mock: Mock): """Test `pause` calls the correct alias.""" endpoint = InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo") mock.return_value = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo") endpoint.pause() mock.assert_called_once_with(namespace="foo", name="my-endpoint-name", token=None) @patch("huggingface_hub.hf_api.HfApi.resume_inference_endpoint") def test_resume(mock: Mock): """Test `resume` calls the correct alias.""" endpoint = InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo") mock.return_value = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo") endpoint.resume() mock.assert_called_once_with(namespace="foo", name="my-endpoint-name", token=None, running_ok=True) huggingface_hub-0.31.1/tests/test_inference_providers.py000066400000000000000000001274321500667546600235370ustar00rootroot00000000000000import base64 import logging from typing import Dict from unittest.mock import patch import pytest from pytest import LogCaptureFixture from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters from huggingface_hub.inference._providers import PROVIDERS, get_provider_helper from huggingface_hub.inference._providers._common import ( BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper, recursive_merge, ) from huggingface_hub.inference._providers.black_forest_labs import BlackForestLabsTextToImageTask from huggingface_hub.inference._providers.cohere import CohereConversationalTask from huggingface_hub.inference._providers.fal_ai import ( _POLLING_INTERVAL, FalAIAutomaticSpeechRecognitionTask, FalAITextToImageTask, FalAITextToSpeechTask, FalAITextToVideoTask, ) from huggingface_hub.inference._providers.fireworks_ai import FireworksAIConversationalTask from huggingface_hub.inference._providers.hf_inference import ( HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask, ) from huggingface_hub.inference._providers.hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask from huggingface_hub.inference._providers.nebius import NebiusTextToImageTask from huggingface_hub.inference._providers.novita import NovitaConversationalTask, NovitaTextGenerationTask from huggingface_hub.inference._providers.openai import OpenAIConversationalTask from huggingface_hub.inference._providers.replicate import ReplicateTask, ReplicateTextToSpeechTask from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask from huggingface_hub.inference._providers.together import TogetherTextToImageTask from .testing_utils import assert_in_logs class TestBasicTaskProviderHelper: def test_api_key_from_provider(self): helper = TaskProviderHelper(provider="provider-name", base_url="https://api.provider.com", task="task-name") assert helper._prepare_api_key("sk_provider_key") == "sk_provider_key" def test_api_key_routed(self, mocker): mocker.patch("huggingface_hub.inference._providers._common.get_token", return_value="hf_test_token") helper = TaskProviderHelper(provider="provider-name", base_url="https://api.provider.com", task="task-name") assert helper._prepare_api_key(None) == "hf_test_token" def test_api_key_missing(self): with 
patch("huggingface_hub.inference._providers._common.get_token", return_value=None): helper = TaskProviderHelper( provider="provider-name", base_url="https://api.provider.com", task="task-name" ) with pytest.raises(ValueError, match="You must provide an api_key.*"): helper._prepare_api_key(None) def test_prepare_mapping_info(self, mocker, caplog: LogCaptureFixture): helper = TaskProviderHelper(provider="provider-name", base_url="https://api.provider.com", task="task-name") caplog.set_level(logging.INFO) # Test missing model with pytest.raises(ValueError, match="Please provide an HF model ID.*"): helper._prepare_mapping_info(None) # Test unsupported model mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", return_value={"other-provider": "mapping"}, ) with pytest.raises(ValueError, match="Model test-model is not supported.*"): helper._prepare_mapping_info("test-model") # Test task mismatch mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", return_value={ "provider-name": mocker.Mock( task="other-task", providerId="mapped-id", hf_model_id="test-model", status="live", ) }, ) with pytest.raises(ValueError, match="Model test-model is not supported for task.*"): helper._prepare_mapping_info("test-model") # Test staging model mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", return_value={ "provider-name": mocker.Mock( task="task-name", hf_model_id="test-model", provider_id="mapped-id", status="staging" ) }, ) assert helper._prepare_mapping_info("test-model").provider_id == "mapped-id" assert_in_logs( caplog, "Model test-model is in staging mode for provider provider-name. Meant for test purposes only." ) # Test successful mapping caplog.clear() mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", return_value={ "provider-name": mocker.Mock( task="task-name", hf_model_id="test-model", provider_id="mapped-id", status="live" ) }, ) assert helper._prepare_mapping_info("test-model").provider_id == "mapped-id" assert helper._prepare_mapping_info("test-model").hf_model_id == "test-model" assert helper._prepare_mapping_info("test-model").task == "task-name" assert helper._prepare_mapping_info("test-model").status == "live" assert len(caplog.records) == 0 # Test with loras mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", return_value={ "provider-name": mocker.Mock( task="task-name", hf_model_id="test-model", provider_id="mapped-id", status="live", adapter_weights_path="lora-weights-path", adapter="lora", ) }, ) assert helper._prepare_mapping_info("test-model").adapter_weights_path == "lora-weights-path" assert helper._prepare_mapping_info("test-model").provider_id == "mapped-id" assert helper._prepare_mapping_info("test-model").hf_model_id == "test-model" assert helper._prepare_mapping_info("test-model").task == "task-name" assert helper._prepare_mapping_info("test-model").status == "live" def test_prepare_headers(self): helper = TaskProviderHelper(provider="provider-name", base_url="https://api.provider.com", task="task-name") headers = helper._prepare_headers({"custom": "header"}, "api_key") assert "user-agent" in headers # From build_hf_headers assert headers["custom"] == "header" assert headers["authorization"] == "Bearer api_key" def test_prepare_url(self, mocker): helper = TaskProviderHelper(provider="provider-name", base_url="https://api.provider.com", task="task-name") 
mocker.patch.object(helper, "_prepare_route", return_value="/v1/test-route") # Test HF token routing url = helper._prepare_url("hf_test_token", "test-model") assert url == "https://router.huggingface.co/provider-name/v1/test-route" helper._prepare_route.assert_called_once_with("test-model", "hf_test_token") # Test direct API call helper._prepare_route.reset_mock() url = helper._prepare_url("sk_test_token", "test-model") assert url == "https://api.provider.com/v1/test-route" helper._prepare_route.assert_called_once_with("test-model", "sk_test_token") class TestBlackForestLabsProvider: def test_prepare_headers_bfl_key(self): helper = BlackForestLabsTextToImageTask() headers = helper._prepare_headers({}, "bfl_key") assert "authorization" not in headers assert headers["X-Key"] == "bfl_key" def test_prepare_headers_hf_key(self): """When using HF token, must use Bearer authorization.""" helper = BlackForestLabsTextToImageTask() headers = helper._prepare_headers({}, "hf_test_token") assert headers["authorization"] == "Bearer hf_test_token" assert "X-Key" not in headers def test_prepare_route(self): """Test route preparation.""" helper = BlackForestLabsTextToImageTask() assert helper._prepare_route("username/repo_name", "hf_token") == "/v1/username/repo_name" def test_prepare_url(self): helper = BlackForestLabsTextToImageTask() assert ( helper._prepare_url("hf_test_token", "username/repo_name") == "https://router.huggingface.co/black-forest-labs/v1/username/repo_name" ) def test_prepare_payload_as_dict(self): """Test payload preparation with parameter renaming.""" helper = BlackForestLabsTextToImageTask() payload = helper._prepare_payload_as_dict( "a beautiful cat", { "num_inference_steps": 30, "guidance_scale": 7.5, "width": 512, "height": 512, "seed": 42, }, "username/repo_name", ) assert payload == { "prompt": "a beautiful cat", "steps": 30, # renamed from num_inference_steps "guidance": 7.5, # renamed from guidance_scale "width": 512, "height": 512, "seed": 42, } def test_get_response_success(self, mocker): """Test successful response handling with polling.""" helper = BlackForestLabsTextToImageTask() mock_session = mocker.patch("huggingface_hub.inference._providers.black_forest_labs.get_session") mock_session.return_value.get.side_effect = [ mocker.Mock( json=lambda: {"status": "Ready", "result": {"sample": "https://example.com/image.jpg"}}, raise_for_status=lambda: None, ), mocker.Mock(content=b"image_bytes", raise_for_status=lambda: None), ] response = helper.get_response({"polling_url": "https://example.com/poll"}) assert response == b"image_bytes" assert mock_session.return_value.get.call_count == 2 mock_session.return_value.get.assert_has_calls( [ mocker.call("https://example.com/poll", headers={"Content-Type": "application/json"}), mocker.call("https://example.com/image.jpg"), ] ) class TestCohereConversationalTask: def test_prepare_url(self): helper = CohereConversationalTask() assert helper.task == "conversational" url = helper._prepare_url("cohere_token", "username/repo_name") assert url == "https://api.cohere.com/compatibility/v1/chat/completions" def test_prepare_payload_as_dict(self): helper = CohereConversationalTask() payload = helper._prepare_payload_as_dict( [{"role": "user", "content": "Hello!"}], {}, InferenceProviderMapping( hf_model_id="CohereForAI/command-r7b-12-2024", providerId="CohereForAI/command-r7b-12-2024", task="conversational", status="live", ), ) assert payload == { "messages": [{"role": "user", "content": "Hello!"}], "model": "CohereForAI/command-r7b-12-2024", 
} class TestFalAIProvider: def test_prepare_headers_fal_ai_key(self): """When using direct call, must use Key authorization.""" headers = FalAITextToImageTask()._prepare_headers({}, "fal_ai_key") assert headers["authorization"] == "Key fal_ai_key" def test_prepare_headers_hf_key(self): """When using routed call, must use Bearer authorization.""" headers = FalAITextToImageTask()._prepare_headers({}, "hf_token") assert headers["authorization"] == "Bearer hf_token" def test_prepare_url(self): url = FalAITextToImageTask()._prepare_url("hf_token", "username/repo_name") assert url == "https://router.huggingface.co/fal-ai/username/repo_name" def test_automatic_speech_recognition_payload(self): helper = FalAIAutomaticSpeechRecognitionTask() payload = helper._prepare_payload_as_dict("https://example.com/audio.mp3", {}, "username/repo_name") assert payload == {"audio_url": "https://example.com/audio.mp3"} payload = helper._prepare_payload_as_dict(b"dummy_audio_data", {}, "username/repo_name") assert payload == {"audio_url": f"data:audio/mpeg;base64,{base64.b64encode(b'dummy_audio_data').decode()}"} def test_automatic_speech_recognition_response(self): helper = FalAIAutomaticSpeechRecognitionTask() response = helper.get_response({"text": "Hello world"}) assert response == "Hello world" with pytest.raises(ValueError): helper.get_response({"text": 123}) def test_text_to_image_payload(self): helper = FalAITextToImageTask() payload = helper._prepare_payload_as_dict( "a beautiful cat", {"width": 512, "height": 512}, InferenceProviderMapping( hf_model_id="username/repo_name", providerId="username/repo_name", task="text-to-image", status="live", ), ) assert payload == { "prompt": "a beautiful cat", "image_size": {"width": 512, "height": 512}, } def test_text_to_image_response(self, mocker): helper = FalAITextToImageTask() mock = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session") response = helper.get_response({"images": [{"url": "image_url"}]}) mock.return_value.get.assert_called_once_with("image_url") assert response == mock.return_value.get.return_value.content def test_text_to_speech_payload(self): helper = FalAITextToSpeechTask() payload = helper._prepare_payload_as_dict("Hello world", {}, "username/repo_name") assert payload == {"text": "Hello world"} def test_text_to_speech_response(self, mocker): helper = FalAITextToSpeechTask() mock = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session") response = helper.get_response({"audio": {"url": "audio_url"}}) mock.return_value.get.assert_called_once_with("audio_url") assert response == mock.return_value.get.return_value.content def test_text_to_video_payload(self): helper = FalAITextToVideoTask() payload = helper._prepare_payload_as_dict("a cat walking", {"num_frames": 16}, "username/repo_name") assert payload == {"prompt": "a cat walking", "num_frames": 16} def test_text_to_video_response(self, mocker): helper = FalAITextToVideoTask() mock_session = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session") mock_sleep = mocker.patch("huggingface_hub.inference._providers.fal_ai.time.sleep") mock_session.return_value.get.side_effect = [ # First call: status mocker.Mock(json=lambda: {"status": "COMPLETED"}, headers={"Content-Type": "application/json"}), # Second call: get result mocker.Mock(json=lambda: {"video": {"url": "video_url"}}, headers={"Content-Type": "application/json"}), # Third call: get video content mocker.Mock(content=b"video_content"), ] api_key = helper._prepare_api_key("hf_token") headers = 
helper._prepare_headers({}, api_key) url = helper._prepare_url(api_key, "username/repo_name") request_params = RequestParameters( url=url, headers=headers, task="text-to-video", model="username/repo_name", data=None, json=None, ) response = helper.get_response( b'{"request_id": "test_request_id", "status": "PROCESSING", "response_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id", "status_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id/status"}', request_params, ) # Verify the correct URLs were called assert mock_session.return_value.get.call_count == 3 mock_session.return_value.get.assert_has_calls( [ mocker.call( "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id/status?_subdomain=queue", headers=request_params.headers, ), mocker.call( "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id?_subdomain=queue", headers=request_params.headers, ), mocker.call("video_url"), ] ) mock_sleep.assert_called_once_with(_POLLING_INTERVAL) assert response == b"video_content" class TestFireworksAIConversationalTask: def test_prepare_url(self): helper = FireworksAIConversationalTask() url = helper._prepare_url("fireworks_token", "username/repo_name") assert url == "https://api.fireworks.ai/inference/v1/chat/completions" def test_prepare_payload_as_dict(self): helper = FireworksAIConversationalTask() payload = helper._prepare_payload_as_dict( [{"role": "user", "content": "Hello!"}], {}, InferenceProviderMapping( hf_model_id="meta-llama/Llama-3.1-8B-Instruct", providerId="meta-llama/Llama-3.1-8B-Instruct", task="conversational", status="live", ), ) assert payload == { "messages": [{"role": "user", "content": "Hello!"}], "model": "meta-llama/Llama-3.1-8B-Instruct", } class TestHFInferenceProvider: def test_prepare_mapping_info(self, mocker): helper = HFInferenceTask("text-classification") mocker.patch( "huggingface_hub.inference._providers.hf_inference._check_supported_task", return_value=None, ) mocker.patch( "huggingface_hub.inference._providers.hf_inference._fetch_recommended_models", return_value={"text-classification": "username/repo_name"}, ) assert helper._prepare_mapping_info("username/repo_name").provider_id == "username/repo_name" assert helper._prepare_mapping_info(None).provider_id == "username/repo_name" assert helper._prepare_mapping_info("https://any-url.com").provider_id == "https://any-url.com" def test_prepare_mapping_info_unknown_task(self): with pytest.raises(ValueError, match="Task unknown-task has no recommended model for HF Inference."): HFInferenceTask("unknown-task")._prepare_mapping_info(None) def test_prepare_url(self): helper = HFInferenceTask("text-classification") assert ( helper._prepare_url("hf_test_token", "username/repo_name") == "https://router.huggingface.co/hf-inference/models/username/repo_name" ) assert helper._prepare_url("hf_test_token", "https://any-url.com") == "https://any-url.com" def test_prepare_url_feature_extraction(self): helper = HFInferenceTask("feature-extraction") assert ( helper._prepare_url("hf_test_token", "username/repo_name") == "https://router.huggingface.co/hf-inference/models/username/repo_name/pipeline/feature-extraction" ) def test_prepare_url_sentence_similarity(self): helper = HFInferenceTask("sentence-similarity") assert ( helper._prepare_url("hf_test_token", "username/repo_name") == 
"https://router.huggingface.co/hf-inference/models/username/repo_name/pipeline/sentence-similarity" ) def test_prepare_payload_as_dict(self): helper = HFInferenceTask("text-classification") mapping_info = InferenceProviderMapping( hf_model_id="username/repo_name", providerId="username/repo_name", task="text-classification", status="live", ) assert helper._prepare_payload_as_dict( "dummy text input", parameters={"a": 1, "b": None}, provider_mapping_info=mapping_info, ) == { "inputs": "dummy text input", "parameters": {"a": 1}, } with pytest.raises(ValueError, match="Unexpected binary input for task text-classification."): helper._prepare_payload_as_dict( b"dummy binary data", {}, mapping_info, ) def test_prepare_payload_as_bytes(self): helper = HFInferenceBinaryInputTask("image-classification") mapping_info = InferenceProviderMapping( hf_model_id="username/repo_name", providerId="username/repo_name", task="image-classification", status="live", ) assert ( helper._prepare_payload_as_bytes( b"dummy binary input", parameters={}, provider_mapping_info=mapping_info, extra_payload=None, ) == b"dummy binary input" ) assert ( helper._prepare_payload_as_bytes( b"dummy binary input", parameters={"a": 1, "b": None}, provider_mapping_info=mapping_info, extra_payload={"extra": "payload"}, ) == b'{"inputs": "ZHVtbXkgYmluYXJ5IGlucHV0", "parameters": {"a": 1}, "extra": "payload"}' # base64.b64encode(b"dummy binary input") ) def test_conversational_url(self): helper = HFInferenceConversational() helper._prepare_url( "hf_test_token", "username/repo_name" ) == "https://router.huggingface.co/hf-inference/models/username/repo_name/v1/chat/completions" helper._prepare_url("hf_test_token", "https://any-url.com") == "https://any-url.com/v1/chat/completions" helper._prepare_url("hf_test_token", "https://any-url.com/v1") == "https://any-url.com/v1/chat/completions" def test_prepare_request(self, mocker): mocker.patch( "huggingface_hub.inference._providers.hf_inference._check_supported_task", return_value=None, ) mocker.patch( "huggingface_hub.inference._providers.hf_inference._fetch_recommended_models", return_value={"text-classification": "username/repo_name"}, ) helper = HFInferenceTask("text-classification") request = helper.prepare_request( inputs="this is a dummy input", parameters={}, headers={}, model="username/repo_name", api_key="hf_test_token", ) assert request.url == "https://router.huggingface.co/hf-inference/models/username/repo_name" assert request.task == "text-classification" assert request.model == "username/repo_name" assert request.headers["authorization"] == "Bearer hf_test_token" assert request.json == {"inputs": "this is a dummy input", "parameters": {}} def test_prepare_request_conversational(self, mocker): mocker.patch( "huggingface_hub.inference._providers.hf_inference._check_supported_task", return_value=None, ) mocker.patch( "huggingface_hub.inference._providers.hf_inference._fetch_recommended_models", return_value={"text-classification": "username/repo_name"}, ) helper = HFInferenceConversational() request = helper.prepare_request( inputs=[{"role": "user", "content": "dummy text input"}], parameters={}, headers={}, model="username/repo_name", api_key="hf_test_token", ) assert ( request.url == "https://router.huggingface.co/hf-inference/models/username/repo_name/v1/chat/completions" ) assert request.task == "conversational" assert request.model == "username/repo_name" assert request.json == { "model": "username/repo_name", "messages": [{"role": "user", "content": "dummy text input"}], } 
@pytest.mark.parametrize( "mapped_model,parameters,expected_model", [ ( "username/repo_name", {}, "username/repo_name", ), # URL endpoint with model in parameters - use model from parameters ( "http://localhost:8000/v1/chat/completions", {"model": "username/repo_name"}, "username/repo_name", ), # URL endpoint without model - fallback to dummy ( "http://localhost:8000/v1/chat/completions", {}, "dummy", ), # HTTPS endpoint with model in parameters ( "https://api.example.com/v1/chat/completions", {"model": "username/repo_name"}, "username/repo_name", ), # URL endpoint with other parameters - should still use dummy ( "http://localhost:8000/v1/chat/completions", {"temperature": 0.7, "max_tokens": 100}, "dummy", ), ], ) def test_prepare_payload_as_dict_conversational(self, mapped_model, parameters, expected_model): helper = HFInferenceConversational() messages = [{"role": "user", "content": "Hello!"}] provider_mapping_info = InferenceProviderMapping( hf_model_id=mapped_model, providerId=mapped_model, task="conversational", status="live", ) payload = helper._prepare_payload_as_dict( inputs=messages, parameters=parameters, provider_mapping_info=provider_mapping_info, ) assert payload["model"] == expected_model assert payload["messages"] == messages @pytest.mark.parametrize( "pipeline_tag,tags,task,should_raise", [ # text-generation + no conversational tag -> only text-generation allowed ( "text-generation", [], "text-generation", False, ), ( "text-generation", [], "conversational", True, ), # text-generation + conversational tag -> both tasks allowed ( "text-generation", ["conversational"], "text-generation", False, ), ( "text-generation", ["conversational"], "conversational", False, ), # image-text-to-text + conversational tag -> only conversational allowed ( "image-text-to-text", ["conversational"], "conversational", False, ), ( "image-text-to-text", ["conversational"], "image-text-to-text", True, ), ( "image-text-to-text", [], "conversational", True, ), # text2text-generation only allowed for text-generation task ( "text2text-generation", [], "text-generation", False, ), ( "text2text-generation", [], "conversational", True, ), # Feature-extraction / sentence-similarity are interchangeable for HF Inference ( "sentence-similarity", ["tag1", "feature-extraction", "sentence-similarity"], "feature-extraction", False, ), ( "feature-extraction", ["tag1", "feature-extraction", "sentence-similarity"], "sentence-similarity", False, ), # if pipeline_tag is not feature-extraction or sentence-similarity, raise ("text-generation", ["tag1", "feature-extraction", "sentence-similarity"], "sentence-similarity", True), # Other tasks ( "audio-classification", [], "audio-classification", False, ), ( "audio-classification", [], "text-classification", True, ), ], ) def test_check_supported_task_scenarios(self, mocker, pipeline_tag, tags, task, should_raise): from huggingface_hub.inference._providers.hf_inference import _check_supported_task mock_model_info = mocker.Mock(pipeline_tag=pipeline_tag, tags=tags) mocker.patch("huggingface_hub.hf_api.HfApi.model_info", return_value=mock_model_info) if should_raise: with pytest.raises(ValueError): _check_supported_task("test-model", task) else: _check_supported_task("test-model", task) class TestHyperbolicProvider: def test_prepare_route(self): """Test route preparation for different tasks.""" helper = HyperbolicTextToImageTask() assert helper._prepare_route("username/repo_name", "hf_token") == "/v1/images/generations" helper = HyperbolicTextGenerationTask("text-generation") 
assert helper._prepare_route("username/repo_name", "hf_token") == "/v1/chat/completions" helper = HyperbolicTextGenerationTask("conversational") assert helper._prepare_route("username/repo_name", "hf_token") == "/v1/chat/completions" def test_prepare_payload_conversational(self): """Test payload preparation for conversational task.""" helper = HyperbolicTextGenerationTask("conversational") payload = helper._prepare_payload_as_dict( [{"role": "user", "content": "Hello!"}], {"temperature": 0.7}, InferenceProviderMapping( hf_model_id="meta-llama/Llama-3.2-3B-Instruct", providerId="meta-llama/Llama-3.2-3B-Instruct", task="conversational", status="live", ), ) assert payload == { "messages": [{"role": "user", "content": "Hello!"}], "temperature": 0.7, "model": "meta-llama/Llama-3.2-3B-Instruct", } def test_prepare_payload_text_to_image(self): """Test payload preparation for text-to-image task.""" helper = HyperbolicTextToImageTask() payload = helper._prepare_payload_as_dict( "a beautiful cat", { "num_inference_steps": 30, "guidance_scale": 7.5, "width": 512, "height": 512, "seed": 42, }, InferenceProviderMapping( hf_model_id="stabilityai/sdxl-turbo", providerId="stabilityai/sdxl", task="text-to-image", status="live", ), ) assert payload == { "prompt": "a beautiful cat", "steps": 30, # renamed from num_inference_steps "cfg_scale": 7.5, # renamed from guidance_scale "width": 512, "height": 512, "seed": 42, "model_name": "stabilityai/sdxl", } def test_text_to_image_get_response(self): """Test response handling for text-to-image task.""" helper = HyperbolicTextToImageTask() dummy_image = b"image_bytes" response = helper.get_response({"images": [{"image": base64.b64encode(dummy_image).decode()}]}) assert response == dummy_image class TestNebiusProvider: def test_prepare_route_text_to_image(self): helper = NebiusTextToImageTask() assert helper._prepare_route("username/repo_name", "hf_token") == "/v1/images/generations" def test_prepare_payload_as_dict_text_to_image(self): helper = NebiusTextToImageTask() payload = helper._prepare_payload_as_dict( "a beautiful cat", {"num_inference_steps": 10, "width": 512, "height": 512, "guidance_scale": 7.5}, InferenceProviderMapping( hf_model_id="black-forest-labs/flux-schnell", providerId="black-forest-labs/flux-schnell", task="text-to-image", status="live", ), ) assert payload == { "prompt": "a beautiful cat", "response_format": "b64_json", "width": 512, "height": 512, "num_inference_steps": 10, "model": "black-forest-labs/flux-schnell", } def test_text_to_image_get_response(self): helper = NebiusTextToImageTask() response = helper.get_response({"data": [{"b64_json": base64.b64encode(b"image_bytes").decode()}]}) assert response == b"image_bytes" class TestNovitaProvider: def test_prepare_url_text_generation(self): helper = NovitaTextGenerationTask() url = helper._prepare_url("novita_token", "username/repo_name") assert url == "https://api.novita.ai/v3/openai/completions" def test_prepare_url_conversational(self): helper = NovitaConversationalTask() url = helper._prepare_url("novita_token", "username/repo_name") assert url == "https://api.novita.ai/v3/openai/chat/completions" class TestOpenAIProvider: def test_prepare_url(self): helper = OpenAIConversationalTask() assert helper._prepare_url("sk-XXXXXX", "gpt-4o-mini") == "https://api.openai.com/v1/chat/completions" class TestReplicateProvider: def test_prepare_headers(self): helper = ReplicateTask("text-to-image") headers = helper._prepare_headers({}, "my_replicate_key") headers["Prefer"] == "wait" 
headers["authorization"] == "Bearer my_replicate_key" def test_prepare_route(self): helper = ReplicateTask("text-to-image") # No model version url = helper._prepare_route("black-forest-labs/FLUX.1-schnell", "hf_token") assert url == "/v1/models/black-forest-labs/FLUX.1-schnell/predictions" # Model with specific version url = helper._prepare_route("black-forest-labs/FLUX.1-schnell:1944af04d098ef", "hf_token") assert url == "/v1/predictions" def test_prepare_payload_as_dict(self): helper = ReplicateTask("text-to-image") # No model version payload = helper._prepare_payload_as_dict( "a beautiful cat", {"num_inference_steps": 20}, InferenceProviderMapping( hf_model_id="black-forest-labs/FLUX.1-schnell", providerId="black-forest-labs/FLUX.1-schnell", task="text-to-image", status="live", ), ) assert payload == {"input": {"prompt": "a beautiful cat", "num_inference_steps": 20}} # Model with specific version payload = helper._prepare_payload_as_dict( "a beautiful cat", {"num_inference_steps": 20}, InferenceProviderMapping( hf_model_id="black-forest-labs/FLUX.1-schnell", providerId="black-forest-labs/FLUX.1-schnell:1944af04d098ef", task="text-to-image", status="live", ), ) assert payload == { "input": {"prompt": "a beautiful cat", "num_inference_steps": 20}, "version": "1944af04d098ef", } def test_text_to_speech_payload(self): helper = ReplicateTextToSpeechTask() payload = helper._prepare_payload_as_dict( "Hello world", {}, InferenceProviderMapping( hf_model_id="hexgrad/Kokoro-82M", providerId="hexgrad/Kokoro-82M:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13", task="text-to-speech", status="live", ), ) assert payload == { "input": {"text": "Hello world"}, "version": "f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13", } def test_get_response_timeout(self): helper = ReplicateTask("text-to-image") with pytest.raises(TimeoutError, match="Inference request timed out after 60 seconds."): helper.get_response({"model": "black-forest-labs/FLUX.1-schnell"}) # no 'output' key def test_get_response_single_output(self, mocker): helper = ReplicateTask("text-to-image") mock = mocker.patch("huggingface_hub.inference._providers.replicate.get_session") response = helper.get_response({"output": "https://example.com/image.jpg"}) mock.return_value.get.assert_called_once_with("https://example.com/image.jpg") assert response == mock.return_value.get.return_value.content class TestSambanovaProvider: def test_prepare_url_conversational(self): helper = SambanovaConversationalTask() assert ( helper._prepare_url("sambanova_token", "username/repo_name") == "https://api.sambanova.ai/v1/chat/completions" ) def test_prepare_payload_as_dict_feature_extraction(self): helper = SambanovaFeatureExtractionTask() payload = helper._prepare_payload_as_dict( "Hello world", {"truncate": True}, InferenceProviderMapping( hf_model_id="username/repo_name", providerId="provider-id", task="feature-extraction", status="live", ), ) assert payload == {"input": "Hello world", "model": "provider-id", "truncate": True} def test_prepare_url_feature_extraction(self): helper = SambanovaFeatureExtractionTask() assert ( helper._prepare_url("hf_token", "username/repo_name") == "https://router.huggingface.co/sambanova/v1/embeddings" ) class TestTogetherProvider: def test_prepare_route_text_to_image(self): helper = TogetherTextToImageTask() assert helper._prepare_route("username/repo_name", "hf_token") == "/v1/images/generations" def test_prepare_payload_as_dict_text_to_image(self): helper = TogetherTextToImageTask() payload = 
helper._prepare_payload_as_dict( "a beautiful cat", {"num_inference_steps": 10, "guidance_scale": 1, "width": 512, "height": 512}, InferenceProviderMapping( hf_model_id="black-forest-labs/FLUX.1-schnell", providerId="black-forest-labs/FLUX.1-schnell", task="text-to-image", status="live", ), ) assert payload == { "prompt": "a beautiful cat", "response_format": "base64", "width": 512, "height": 512, "steps": 10, # renamed field "guidance": 1, # renamed field "model": "black-forest-labs/FLUX.1-schnell", } def test_text_to_image_get_response(self): helper = TogetherTextToImageTask() response = helper.get_response({"data": [{"b64_json": base64.b64encode(b"image_bytes").decode()}]}) assert response == b"image_bytes" class TestBaseConversationalTask: def test_prepare_route(self): helper = BaseConversationalTask(provider="test-provider", base_url="https://api.test.com") assert helper._prepare_route("dummy-model", "hf_token") == "/v1/chat/completions" assert helper.task == "conversational" def test_prepare_payload(self): helper = BaseConversationalTask(provider="test-provider", base_url="https://api.test.com") messages = [{"role": "user", "content": "Hello!"}] parameters = {"temperature": 0.7, "max_tokens": 100} payload = helper._prepare_payload_as_dict( inputs=messages, parameters=parameters, provider_mapping_info=InferenceProviderMapping( hf_model_id="test-model", providerId="test-provider-id", task="conversational", status="live", ), ) assert payload == { "messages": messages, "temperature": 0.7, "max_tokens": 100, "model": "test-provider-id", } class TestBaseTextGenerationTask: def test_prepare_route(self): helper = BaseTextGenerationTask(provider="test-provider", base_url="https://api.test.com") assert helper._prepare_route("dummy-model", "hf_token") == "/v1/completions" assert helper.task == "text-generation" def test_prepare_payload(self): helper = BaseTextGenerationTask(provider="test-provider", base_url="https://api.test.com") prompt = "Once upon a time" parameters = {"temperature": 0.7, "max_tokens": 100} payload = helper._prepare_payload_as_dict( inputs=prompt, parameters=parameters, provider_mapping_info=InferenceProviderMapping( hf_model_id="test-model", providerId="test-provider-id", task="text-generation", status="live", ), ) assert payload == { "prompt": prompt, "temperature": 0.7, "max_tokens": 100, "model": "test-provider-id", } @pytest.mark.parametrize( "dict1, dict2, expected", [ # Basic merge with non-overlapping keys ({"a": 1}, {"b": 2}, {"a": 1, "b": 2}), # Overwriting a key ({"a": 1}, {"a": 2}, {"a": 2}), # Empty dict merge ({}, {"a": 1}, {"a": 1}), ({"a": 1}, {}, {"a": 1}), ({}, {}, {}), # Nested dictionary merge ( {"a": {"b": 1}}, {"a": {"c": 2}}, {"a": {"b": 1, "c": 2}}, ), # Overwriting nested dictionary key ( {"a": {"b": 1}}, {"a": {"b": 2}}, {"a": {"b": 2}}, ), # Deep merge ( {"a": {"b": {"c": 1}}}, {"a": {"b": {"d": 2}}}, {"a": {"b": {"c": 1, "d": 2}}}, ), # Overwriting a nested value with a non-dict type ( {"a": {"b": {"c": 1}}}, {"a": {"b": 2}}, {"a": {"b": 2}}, # Overwrites dict with integer ), # Merging dictionaries with different types ( {"a": 1}, {"a": {"b": 2}}, {"a": {"b": 2}}, # Overwrites int with dict ), ], ) def test_recursive_merge(dict1: Dict, dict2: Dict, expected: Dict): initial_dict1 = dict1.copy() initial_dict2 = dict2.copy() assert recursive_merge(dict1, dict2) == expected # does not mutate the inputs assert dict1 == initial_dict1 assert dict2 == initial_dict2 def test_get_provider_helper_auto(mocker): """Test the 'auto' provider selection logic.""" 
mock_provider_a_helper = mocker.Mock(spec=TaskProviderHelper) mock_provider_b_helper = mocker.Mock(spec=TaskProviderHelper) PROVIDERS["provider-a"] = {"test-task": mock_provider_a_helper} PROVIDERS["provider-b"] = {"test-task": mock_provider_b_helper} mocker.patch( "huggingface_hub.inference._providers._fetch_inference_provider_mapping", return_value={ "provider-a": mocker.Mock(), "provider-b": mocker.Mock(), }, ) helper = get_provider_helper(provider="auto", task="test-task", model="test-model") # The helper should be the one from provider-a assert helper is mock_provider_a_helper PROVIDERS.pop("provider-a", None) PROVIDERS.pop("provider-b", None) huggingface_hub-0.31.1/tests/test_inference_text_generation.py000066400000000000000000000162631500667546600247200ustar00rootroot00000000000000# Original implementation taken from the `text-generation` Python client (see https://pypi.org/project/text-generation/ # and https://github.com/huggingface/text-generation-inference/tree/main/clients/python) # # See './src/huggingface_hub/inference/_text_generation.py' for details. import json import unittest from typing import Dict from unittest.mock import MagicMock, patch import pytest from requests import HTTPError from huggingface_hub import InferenceClient, TextGenerationOutputPrefillToken from huggingface_hub.inference._common import ( _UNSUPPORTED_TEXT_GENERATION_KWARGS, GenerationError, IncompleteGenerationError, OverloadedError, raise_text_generation_error, ) from huggingface_hub.inference._common import ValidationError as TextGenerationValidationError from .testing_utils import with_production_testing class TestTextGenerationErrors(unittest.TestCase): def test_generation_error(self): error = _mocked_error({"error_type": "generation", "error": "test"}) with self.assertRaises(GenerationError): raise_text_generation_error(error) def test_incomplete_generation_error(self): error = _mocked_error({"error_type": "incomplete_generation", "error": "test"}) with self.assertRaises(IncompleteGenerationError): raise_text_generation_error(error) def test_overloaded_error(self): error = _mocked_error({"error_type": "overloaded", "error": "test"}) with self.assertRaises(OverloadedError): raise_text_generation_error(error) def test_validation_error(self): error = _mocked_error({"error_type": "validation", "error": "test"}) with self.assertRaises(TextGenerationValidationError): raise_text_generation_error(error) def _mocked_error(payload: Dict) -> MagicMock: error = HTTPError(response=MagicMock()) error.response.json.return_value = payload return error @pytest.mark.skip("Temporary skipping TestTextGenerationClientVCR tests") @with_production_testing @patch.dict("huggingface_hub.inference._common._UNSUPPORTED_TEXT_GENERATION_KWARGS", {}) class TestTextGenerationClientVCR(unittest.TestCase): """Use VCR test to avoid making requests to the prod infra.""" def setUp(self) -> None: self.client = InferenceClient(model="google/flan-t5-xxl") return super().setUp() def test_generate_no_details(self): response = self.client.text_generation("test", details=False, max_new_tokens=1) assert response == "" def test_generate_with_details(self): response = self.client.text_generation("test", details=True, max_new_tokens=1, decoder_input_details=True) assert response.generated_text == "" assert response.details.finish_reason == "length" assert response.details.generated_tokens == 1 assert response.details.seed is None assert len(response.details.prefill) == 1 assert response.details.prefill[0] == 
TextGenerationOutputPrefillToken(id=0, text="", logprob=None) assert len(response.details.tokens) == 1 assert response.details.tokens[0].id == 3 assert response.details.tokens[0].text == " " assert not response.details.tokens[0].special def test_generate_best_of(self): response = self.client.text_generation( "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True, details=True ) assert response.details.seed is not None assert response.details.best_of_sequences is not None assert len(response.details.best_of_sequences) == 1 assert response.details.best_of_sequences[0].seed is not None def test_generate_validation_error(self): with self.assertRaises(TextGenerationValidationError): self.client.text_generation("test", max_new_tokens=10_000) def test_generate_stream_no_details(self): responses = [ response for response in self.client.text_generation("test", max_new_tokens=1, stream=True, details=True) ] assert len(responses) == 1 response = responses[0] assert response.generated_text == "" assert response.details.finish_reason == "length" assert response.details.generated_tokens == 1 assert response.details.seed is None def test_generate_stream_with_details(self): responses = [ response for response in self.client.text_generation("test", max_new_tokens=1, stream=True, details=True) ] assert len(responses) == 1 response = responses[0] assert response.generated_text == "" assert response.details.finish_reason == "length" assert response.details.generated_tokens == 1 assert response.details.seed is None def test_generate_non_tgi_endpoint(self): text = self.client.text_generation("0 1 2", model="gpt2", max_new_tokens=10) self.assertEqual(text, " 3 4 5 6 7 8 9 10 11 12") self.assertIn("gpt2", _UNSUPPORTED_TEXT_GENERATION_KWARGS) # Watermark is ignored (+ warning) with self.assertWarns(UserWarning): self.client.text_generation("4 5 6", model="gpt2", max_new_tokens=10, watermark=True) # Return as detail even if details=True (+ warning) with self.assertWarns(UserWarning): text = self.client.text_generation("0 1 2", model="gpt2", max_new_tokens=10, details=True) self.assertIsInstance(text, str) # Return as stream raises error with self.assertRaises(ValueError): self.client.text_generation("0 1 2", model="gpt2", max_new_tokens=10, stream=True) def test_generate_non_tgi_endpoint_regression_test(self): # Regression test for https://github.com/huggingface/huggingface_hub/issues/2135 with self.assertWarnsRegex(UserWarning, "Ignoring following parameters: return_full_text"): text = self.client.text_generation( prompt="How are you today?", max_new_tokens=20, model="google/flan-t5-large", return_full_text=True ) assert text == "I am at work" def test_generate_with_grammar(self): # Example taken from https://huggingface.co/docs/text-generation-inference/conceptual/guidance#the-grammar-parameter response = self.client.text_generation( prompt="I saw a puppy a cat and a raccoon during my bike ride in the park", max_new_tokens=100, model="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", repetition_penalty=1.3, grammar={ "type": "json", "value": { "properties": { "location": {"type": "string"}, "activity": {"type": "string"}, "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, "animals": {"type": "array", "items": {"type": "string"}}, }, "required": ["location", "activity", "animals_seen", "animals"], }, }, ) assert json.loads(response) == { "activity": "biking", "animals": [], "animals_seen": 3, "location": "park", } 
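# ---------------------------------------------------------------------------
# Editor's sketch (illustrative only, not an upstream test): the grammar test above exercises
# TGI's guidance feature, where a JSON-schema "grammar" constrains the generated text. A minimal
# standalone call looks like the following; it assumes network access and a deployed model, so it
# is guarded behind __main__ and is never collected or executed by pytest.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    client = InferenceClient(model="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1")
    raw = client.text_generation(
        prompt="I saw a puppy a cat and a raccoon during my bike ride in the park",
        max_new_tokens=100,
        grammar={
            "type": "json",
            "value": {
                "properties": {
                    "location": {"type": "string"},
                    "animals": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["location", "animals"],
            },
        },
    )
    print(json.loads(raw))  # the generated text is constrained to match the schema above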
huggingface_hub-0.31.1/tests/test_inference_types.py000066400000000000000000000127551500667546600226670ustar00rootroot00000000000000import inspect import json from typing import List, Optional, Union, get_args, get_origin import pytest import huggingface_hub.inference._generated.types as types from huggingface_hub.inference._generated.types import AutomaticSpeechRecognitionParameters from huggingface_hub.inference._generated.types.base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class DummyType(BaseInferenceType): foo: int bar: str @dataclass_with_extra class DummyNestedType(BaseInferenceType): item: DummyType items: List[DummyType] maybe_items: Optional[List[DummyType]] = None DUMMY_AS_DICT = {"foo": 42, "bar": "baz"} DUMMY_AS_STR = json.dumps(DUMMY_AS_DICT) DUMMY_AS_BYTES = DUMMY_AS_STR.encode() DUMMY_AS_LIST = [DUMMY_AS_DICT] def test_parse_from_bytes(): instance = DummyType.parse_obj(DUMMY_AS_BYTES) assert instance.foo == 42 assert instance.bar == "baz" def test_parse_from_str(): instance = DummyType.parse_obj(DUMMY_AS_STR) assert instance.foo == 42 assert instance.bar == "baz" def test_parse_from_dict(): instance = DummyType.parse_obj(DUMMY_AS_DICT) assert instance.foo == 42 assert instance.bar == "baz" def test_parse_from_list(): instances = DummyType.parse_obj(DUMMY_AS_LIST) assert len(instances) == 1 assert instances[0].foo == 42 assert instances[0].bar == "baz" def test_parse_from_unexpected_type(): with pytest.raises(ValueError): DummyType.parse_obj(42) def test_parse_as_instance_success(): instance = DummyType.parse_obj_as_instance(DUMMY_AS_DICT) assert isinstance(instance, DummyType) def test_parse_as_instance_failure(): with pytest.raises(ValueError): DummyType.parse_obj_as_instance(DUMMY_AS_LIST) def test_parse_as_list_success(): instances = DummyType.parse_obj_as_list(DUMMY_AS_LIST) assert len(instances) == 1 def test_parse_as_list_failure(): with pytest.raises(ValueError): DummyType.parse_obj_as_list(DUMMY_AS_DICT) def test_parse_nested_class(): instance = DummyNestedType.parse_obj( { "item": DUMMY_AS_DICT, "items": DUMMY_AS_LIST, "maybe_items": None, } ) assert instance.item.foo == 42 assert instance.item.bar == "baz" assert len(instance.items) == 1 assert instance.items[0].foo == 42 assert instance.items[0].bar == "baz" assert instance.maybe_items is None def test_all_fields_are_optional(): # all fields are optional => silently accept None if server returns less data than expected instance = DummyNestedType.parse_obj({"maybe_items": [{}, DUMMY_AS_BYTES]}) assert instance.item is None assert instance.items is None assert len(instance.maybe_items) == 2 assert instance.maybe_items[0].foo is None assert instance.maybe_items[0].bar is None assert instance.maybe_items[1].foo == 42 assert instance.maybe_items[1].bar == "baz" def test_normalize_keys(): # all fields are normalized in the dataclasses (by convention) # if server response uses different keys, they will be normalized instance = DummyNestedType.parse_obj({"ItEm": DUMMY_AS_DICT, "Maybe-Items": [DUMMY_AS_DICT]}) assert isinstance(instance.item, DummyType) assert isinstance(instance.maybe_items, list) assert len(instance.maybe_items) == 1 assert isinstance(instance.maybe_items[0], DummyType) def test_optional_are_set_to_none(): for _type in types.BaseInferenceType.__subclasses__(): parameters = inspect.signature(_type).parameters for parameter in parameters.values(): if _is_optional(parameter.annotation): assert parameter.default is None, f"Parameter {parameter} of {_type} should be set to None" 
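def _editor_example_partial_payload():
    """Editor's illustration, not an upstream test (the leading underscore keeps pytest from
    collecting it): the parsing behaviour checked above means a raw JSON server response can be
    fed to `parse_obj` as bytes and may omit fields entirely; missing keys simply become None."""
    instance = DummyNestedType.parse_obj(b'{"item": {"foo": 1, "bar": "x"}}')
    assert instance.item.foo == 1
    assert instance.items is None  # field absent from the payload -> silently set to None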
def test_none_inferred(): """Regression test for https://github.com/huggingface/huggingface_hub/pull/2095""" # Doing this should not fail with # TypeError: __init__() missing 2 required positional arguments: 'generate' and 'return_timestamps' AutomaticSpeechRecognitionParameters() def test_other_fields_are_set(): instance = DummyNestedType.parse_obj( { "item": DUMMY_AS_DICT, "extra": "value", "items": [{"foo": 42, "another_extra": "value", "bar": "baz"}], "maybe_items": None, } ) assert instance.extra == "value" assert instance.items[0].another_extra == "value" assert str(instance.items[0]) == "DummyType(foo=42, bar='baz', another_extra='value')" # extra field always last assert ( repr(instance) # works both with __str__ and __repr__ == ( "DummyNestedType(" "item=DummyType(foo=42, bar='baz'), " "items=[DummyType(foo=42, bar='baz', another_extra='value')], " "maybe_items=None, extra='value'" ")" ) ) def test_other_fields_not_proper_dataclass_fields(): instance_1 = DummyType.parse_obj({"foo": 42, "bar": "baz", "extra": "value1"}) instance_2 = DummyType.parse_obj({"foo": 42, "bar": "baz", "extra": "value2", "another_extra": "value2.1"}) assert instance_1.extra == "value1" assert instance_2.extra == "value2" assert instance_2.another_extra == "value2.1" # extra fields are not part of the dataclass fields # all dataclass methods except __repr__ should work as if the extra fields were not there assert instance_1 == instance_2 def _is_optional(field) -> bool: # Taken from https://stackoverflow.com/a/58841311 return get_origin(field) is Union and type(None) in get_args(field) huggingface_hub-0.31.1/tests/test_init_lazy_loading.py000066400000000000000000000036031500667546600231740ustar00rootroot00000000000000import unittest import jedi class TestHuggingfaceHubInit(unittest.TestCase): @unittest.skip( reason="`jedi.Completion.get_signatures()` output differs between Python 3.12 and earlier versions, affecting test consistency" ) def test_autocomplete_on_root_imports(self) -> None: """Test autocomplete with `huggingface_hub` works with Jedi. Not all autocomplete systems are based on Jedi but if this one works we can assume others do as well. """ source = """from huggingface_hub import c""" script = jedi.Script(source, path="example.py") completions = script.complete(1, len(source)) for completion in completions: if completion.name == "create_commit": # Assert `create_commit` is suggestion from `huggingface_hub` lib self.assertEqual(completion.module_name, "huggingface_hub") # Assert autocomplete knows where `create_commit` lives # It would not be the case with a dynamic import. goto_list = completion.goto() self.assertEqual(len(goto_list), 1) # Assert docstring is find. This means autocomplete can also provide # the help section. signature_list = goto_list[0].get_signatures() self.assertEqual(len(signature_list), 2) # create_commit has 2 signatures (normal and `run_as_future`) self.assertTrue(signature_list[0].docstring().startswith("create_commit(repo_id: str,")) break else: self.fail( "Jedi autocomplete did not suggest `create_commit` to complete the" f" line `{source}`. It is most probable that static imports are not" " correct in `./src/huggingface_hub/__init__.py`. Please run `make" " style` to fix this." 
) huggingface_hub-0.31.1/tests/test_keras_integration.py000066400000000000000000000302511500667546600232040ustar00rootroot00000000000000import json import os import unittest from pathlib import Path import pytest from huggingface_hub import HfApi, hf_hub_download, snapshot_download from huggingface_hub.keras_mixin import ( KerasModelHubMixin, from_pretrained_keras, push_to_hub_keras, save_pretrained_keras, ) from huggingface_hub.utils import is_graphviz_available, is_pydot_available, is_tf_available, logging from .testing_constants import ENDPOINT_STAGING, TOKEN, USER from .testing_utils import repo_name logger = logging.get_logger(__name__) if is_tf_available(): import tensorflow as tf def require_tf(test_case): """ Decorator marking a test that requires TensorFlow, graphviz and pydot. These tests are skipped when TensorFlow, graphviz and pydot are installed. """ if not is_tf_available() or not is_pydot_available() or not is_graphviz_available(): return unittest.skip("test requires Tensorflow, graphviz and pydot.")(test_case) else: return test_case if is_tf_available(): # Define dummy mixin model... class DummyModel(tf.keras.Model, KerasModelHubMixin): def __init__(self, **kwargs): super().__init__() self.l1 = tf.keras.layers.Dense(2, activation="relu") dummy_batch_size = input_dim = 2 self.dummy_inputs = tf.ones([dummy_batch_size, input_dim]) def call(self, x): return self.l1(x) else: DummyModel = None @require_tf @pytest.mark.usefixtures("fx_cache_dir") class CommonKerasTest(unittest.TestCase): cache_dir: Path @classmethod def setUpClass(cls): """ Share this valid token in all tests below. """ cls._api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) class HubMixinTestKeras(CommonKerasTest): def test_save_pretrained(self): model = DummyModel() model(model.dummy_inputs) model.save_pretrained(self.cache_dir) files = os.listdir(self.cache_dir) self.assertTrue("saved_model.pb" in files) self.assertTrue("keras_metadata.pb" in files) self.assertTrue("README.md" in files) self.assertTrue("model.png" in files) self.assertEqual(len(files), 7) model.save_pretrained(self.cache_dir, config={"num": 12, "act": "gelu"}) files = os.listdir(self.cache_dir) self.assertTrue("config.json" in files) self.assertTrue("saved_model.pb" in files) self.assertEqual(len(files), 8) def test_keras_from_pretrained_weights(self): model = DummyModel() model(model.dummy_inputs) model.save_pretrained(self.cache_dir) new_model = DummyModel.from_pretrained(self.cache_dir) # Check the reloaded model's weights match the original model's weights self.assertTrue(tf.reduce_all(tf.equal(new_model.weights[0], model.weights[0]))) # Check a new model's weights are not the same as the reloaded model's weights another_model = DummyModel() another_model(tf.ones([2, 2])) self.assertFalse(tf.reduce_all(tf.equal(new_model.weights[0], another_model.weights[0])).numpy().item()) def test_abs_path_from_pretrained(self): model = DummyModel() model(model.dummy_inputs) model.save_pretrained(self.cache_dir, config={"num": 10, "act": "gelu_fast"}) model = DummyModel.from_pretrained(self.cache_dir) self.assertTrue(model.config == {"num": 10, "act": "gelu_fast"}) def test_push_to_hub_keras_mixin_via_http_basic(self): repo_id = f"{USER}/{repo_name()}" model = DummyModel() model(model.dummy_inputs) model.push_to_hub(repo_id=repo_id, token=TOKEN, config={"num": 7, "act": "gelu_fast"}) # Test model id exists assert self._api.model_info(repo_id).id == repo_id # Test config has been pushed to hub config_path = hf_hub_download( repo_id=repo_id, 
filename="config.json", use_auth_token=TOKEN, cache_dir=self.cache_dir ) with open(config_path) as f: assert json.load(f) == {"num": 7, "act": "gelu_fast"} # Delete tmp file and repo self._api.delete_repo(repo_id=repo_id) @require_tf class HubKerasSequentialTest(CommonKerasTest): def model_init(self): model = tf.keras.models.Sequential() model.add(tf.keras.layers.Dense(2, activation="relu")) model.compile(optimizer="adam", loss="mse") return model def model_fit(self, model): x = tf.constant([[0.44, 0.90], [0.65, 0.39]]) y = tf.constant([[1, 1], [0, 0]]) model.fit(x, y) return model def test_save_pretrained(self): model = self.model_init() with pytest.raises(ValueError, match="Model should be built*"): save_pretrained_keras(model, save_directory=self.cache_dir) model.build((None, 2)) save_pretrained_keras(model, save_directory=self.cache_dir) files = os.listdir(self.cache_dir) self.assertIn("saved_model.pb", files) self.assertIn("keras_metadata.pb", files) self.assertIn("model.png", files) self.assertIn("README.md", files) self.assertEqual(len(files), 7) loaded_model = from_pretrained_keras(self.cache_dir) self.assertIsNone(loaded_model.optimizer) def test_save_pretrained_model_card_fit(self): model = self.model_init() model = self.model_fit(model) save_pretrained_keras(model, save_directory=self.cache_dir) files = os.listdir(self.cache_dir) history = json.loads((self.cache_dir / "history.json").read_text()) self.assertIn("saved_model.pb", files) self.assertIn("keras_metadata.pb", files) self.assertIn("model.png", files) self.assertIn("README.md", files) self.assertIn("history.json", files) self.assertEqual(history, model.history.history) self.assertEqual(len(files), 8) def test_save_model_card_history_removal(self): model = self.model_init() model = self.model_fit(model) history_path = self.cache_dir / "history.json" history_path.write_text("Keras FTW") with pytest.warns(UserWarning, match="`history.json` file already exists, *"): save_pretrained_keras(model, save_directory=self.cache_dir) # assert that it's not the same as old history file and it's overridden self.assertNotEqual("Keras FTW", history_path.read_text()) # Check the history is saved as a json in the repository. files = os.listdir(self.cache_dir) self.assertIn("history.json", files) # Check that there is no "Training Metrics" section in the model card. # This was done in an older version. self.assertNotIn("Training Metrics", (self.cache_dir / "README.md").read_text()) def test_save_pretrained_optimizer_state(self): model = self.model_init() model.build((None, 2)) save_pretrained_keras(model, self.cache_dir, include_optimizer=True) loaded_model = from_pretrained_keras(self.cache_dir) self.assertIsNotNone(loaded_model.optimizer) def test_from_pretrained_weights(self): model = self.model_init() model.build((None, 2)) save_pretrained_keras(model, self.cache_dir) new_model = from_pretrained_keras(self.cache_dir) # Check a new model's weights are not the same as the reloaded model's weights another_model = DummyModel() another_model(tf.ones([2, 2])) self.assertFalse(tf.reduce_all(tf.equal(new_model.weights[0], another_model.weights[0])).numpy().item()) def test_save_pretrained_task_name_deprecation(self): model = self.model_init() model.build((None, 2)) with pytest.warns( FutureWarning, match="`task_name` input argument is deprecated. 
Pass `tags` instead.", ): save_pretrained_keras(model, self.cache_dir, tags=["test"], task_name="test", save_traces=True) def test_abs_path_from_pretrained(self): model = self.model_init() model.build((None, 2)) save_pretrained_keras( model, self.cache_dir, config={"num": 10, "act": "gelu_fast"}, plot_model=True, tags=None ) new_model = from_pretrained_keras(self.cache_dir) self.assertTrue(tf.reduce_all(tf.equal(new_model.weights[0], model.weights[0]))) self.assertTrue(new_model.config == {"num": 10, "act": "gelu_fast"}) def test_push_to_hub_keras_sequential_via_http_basic(self): repo_id = f"{USER}/{repo_name()}" model = self.model_init() model = self.model_fit(model) push_to_hub_keras(model, repo_id=repo_id, token=TOKEN, api_endpoint=ENDPOINT_STAGING) assert self._api.model_info(repo_id).id == repo_id repo_files = self._api.list_repo_files(repo_id) assert "README.md" in repo_files assert "model.png" in repo_files self._api.delete_repo(repo_id=repo_id) def test_push_to_hub_keras_sequential_via_http_plot_false(self): repo_id = f"{USER}/{repo_name()}" model = self.model_init() model = self.model_fit(model) push_to_hub_keras(model, repo_id=repo_id, token=TOKEN, api_endpoint=ENDPOINT_STAGING, plot_model=False) repo_files = self._api.list_repo_files(repo_id) self.assertNotIn("model.png", repo_files) self._api.delete_repo(repo_id=repo_id) def test_push_to_hub_keras_via_http_override_tensorboard(self): """Test log directory is overwritten when pushing a keras model a 2nd time.""" repo_id = f"{USER}/{repo_name()}" log_dir = self.cache_dir / "tb_log_dir" log_dir.mkdir(parents=True, exist_ok=True) (log_dir / "tensorboard.txt").write_text("Keras FTW") model = self.model_init() model.build((None, 2)) push_to_hub_keras(model, repo_id=repo_id, log_dir=log_dir, api_endpoint=ENDPOINT_STAGING, token=TOKEN) log_dir2 = self.cache_dir / "tb_log_dir2" log_dir2.mkdir(parents=True, exist_ok=True) (log_dir2 / "override.txt").write_text("Keras FTW") push_to_hub_keras(model, repo_id=repo_id, log_dir=log_dir2, api_endpoint=ENDPOINT_STAGING, token=TOKEN) files = self._api.list_repo_files(repo_id) self.assertIn("logs/override.txt", files) self.assertNotIn("logs/tensorboard.txt", files) self._api.delete_repo(repo_id=repo_id) def test_push_to_hub_keras_via_http_with_model_kwargs(self): repo_id = f"{USER}/{repo_name()}" model = self.model_init() model = self.model_fit(model) push_to_hub_keras( model, repo_id=repo_id, api_endpoint=ENDPOINT_STAGING, token=TOKEN, include_optimizer=True, save_traces=False, ) assert self._api.model_info(repo_id).id == repo_id snapshot_path = snapshot_download(repo_id=repo_id, cache_dir=self.cache_dir) from_pretrained_keras(snapshot_path) self._api.delete_repo(repo_id) @require_tf class HubKerasFunctionalTest(CommonKerasTest): def model_init(self): inputs = tf.keras.layers.Input(shape=(2,)) outputs = tf.keras.layers.Dense(2, activation="relu")(inputs) model = tf.keras.models.Model(inputs=inputs, outputs=outputs) model.compile(optimizer="adam", loss="mse") return model def model_fit(self, model): x = tf.constant([[0.44, 0.90], [0.65, 0.39]]) y = tf.constant([[1, 1], [0, 0]]) model.fit(x, y) return model def test_save_pretrained(self): model = self.model_init() model.build((None, 2)) self.assertTrue(model.built) save_pretrained_keras(model, self.cache_dir) files = os.listdir(self.cache_dir) self.assertIn("saved_model.pb", files) self.assertIn("keras_metadata.pb", files) self.assertEqual(len(files), 7) def test_save_pretrained_fit(self): model = self.model_init() model = self.model_fit(model) 
save_pretrained_keras(model, self.cache_dir) files = os.listdir(self.cache_dir) self.assertIn("saved_model.pb", files) self.assertIn("keras_metadata.pb", files) self.assertEqual(len(files), 8) huggingface_hub-0.31.1/tests/test_lfs.py000066400000000000000000000213541500667546600202640ustar00rootroot00000000000000import os import unittest from hashlib import sha256 from io import BytesIO from huggingface_hub.lfs import UploadInfo from huggingface_hub.utils import SoftTemporaryDirectory from huggingface_hub.utils._lfs import SliceFileObj class TestUploadInfo(unittest.TestCase): def setUp(self) -> None: self.content = b"RandOm ConTEnT" * 1024 self.size = len(self.content) self.sha = sha256(self.content).digest() self.sample = self.content[:512] def test_upload_info_from_path(self): with SoftTemporaryDirectory() as tmpdir: filepath = os.path.join(tmpdir, "file.bin") with open(filepath, "wb+") as file: file.write(self.content) upload_info = UploadInfo.from_path(filepath) self.assertEqual(upload_info.sample, self.sample) self.assertEqual(upload_info.size, self.size) self.assertEqual(upload_info.sha256, self.sha) def test_upload_info_from_bytes(self): upload_info = UploadInfo.from_bytes(self.content) self.assertEqual(upload_info.sample, self.sample) self.assertEqual(upload_info.size, self.size) self.assertEqual(upload_info.sha256, self.sha) def test_upload_info_from_bytes_io(self): upload_info = UploadInfo.from_fileobj(BytesIO(self.content)) self.assertEqual(upload_info.sample, self.sample) self.assertEqual(upload_info.size, self.size) self.assertEqual(upload_info.sha256, self.sha) class TestSliceFileObj(unittest.TestCase): def setUp(self) -> None: self.content = b"RANDOM self.content uauabciabeubahveb" * 1024 def test_slice_fileobj_BytesIO(self): fileobj = BytesIO(self.content) prev_pos = fileobj.tell() # Test read with SliceFileObj(fileobj, seek_from=24, read_limit=18) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.read(), self.content[24:42]) self.assertEqual(fileobj_slice.tell(), 18) self.assertEqual(fileobj_slice.read(), b"") self.assertEqual(fileobj_slice.tell(), 18) self.assertEqual(fileobj.tell(), prev_pos) with SliceFileObj(fileobj, seek_from=0, read_limit=990) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.read(200), self.content[0:200]) self.assertEqual(fileobj_slice.read(500), self.content[200:700]) self.assertEqual(fileobj_slice.read(200), self.content[700:900]) self.assertEqual(fileobj_slice.read(200), self.content[900:990]) self.assertEqual(fileobj_slice.read(200), b"") # Test seek with whence = os.SEEK_SET with SliceFileObj(fileobj, seek_from=100, read_limit=100) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) fileobj_slice.seek(2, os.SEEK_SET) self.assertEqual(fileobj_slice.tell(), 2) self.assertEqual(fileobj_slice.fileobj.tell(), 102) fileobj_slice.seek(-4, os.SEEK_SET) self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.fileobj.tell(), 100) fileobj_slice.seek(100 + 4, os.SEEK_SET) self.assertEqual(fileobj_slice.tell(), 100) self.assertEqual(fileobj_slice.fileobj.tell(), 200) # Test seek with whence = os.SEEK_CUR with SliceFileObj(fileobj, seek_from=100, read_limit=100) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) fileobj_slice.seek(-5, os.SEEK_CUR) self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.fileobj.tell(), 100) fileobj_slice.seek(50, os.SEEK_CUR) self.assertEqual(fileobj_slice.tell(), 50) 
self.assertEqual(fileobj_slice.fileobj.tell(), 150) fileobj_slice.seek(100, os.SEEK_CUR) self.assertEqual(fileobj_slice.tell(), 100) self.assertEqual(fileobj_slice.fileobj.tell(), 200) fileobj_slice.seek(-300, os.SEEK_CUR) self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.fileobj.tell(), 100) # Test seek with whence = os.SEEK_END with SliceFileObj(fileobj, seek_from=100, read_limit=100) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) fileobj_slice.seek(-5, os.SEEK_END) self.assertEqual(fileobj_slice.tell(), 95) self.assertEqual(fileobj_slice.fileobj.tell(), 195) fileobj_slice.seek(50, os.SEEK_END) self.assertEqual(fileobj_slice.tell(), 100) self.assertEqual(fileobj_slice.fileobj.tell(), 200) fileobj_slice.seek(-200, os.SEEK_END) self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.fileobj.tell(), 100) def test_slice_fileobj_file(self): self.content = b"RANDOM self.content uauabciabeubahveb" * 1024 with SoftTemporaryDirectory() as tmpdir: filepath = os.path.join(tmpdir, "file.bin") with open(filepath, "wb+") as f: f.write(self.content) with open(filepath, "rb") as fileobj: prev_pos = fileobj.tell() # Test read with SliceFileObj(fileobj, seek_from=24, read_limit=18) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.read(), self.content[24:42]) self.assertEqual(fileobj_slice.tell(), 18) self.assertEqual(fileobj_slice.read(), b"") self.assertEqual(fileobj_slice.tell(), 18) self.assertEqual(fileobj.tell(), prev_pos) with SliceFileObj(fileobj, seek_from=0, read_limit=990) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.read(200), self.content[0:200]) self.assertEqual(fileobj_slice.read(500), self.content[200:700]) self.assertEqual(fileobj_slice.read(200), self.content[700:900]) self.assertEqual(fileobj_slice.read(200), self.content[900:990]) self.assertEqual(fileobj_slice.read(200), b"") # Test seek with whence = os.SEEK_SET with SliceFileObj(fileobj, seek_from=100, read_limit=100) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) fileobj_slice.seek(2, os.SEEK_SET) self.assertEqual(fileobj_slice.tell(), 2) self.assertEqual(fileobj_slice.fileobj.tell(), 102) fileobj_slice.seek(-4, os.SEEK_SET) self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.fileobj.tell(), 100) fileobj_slice.seek(100 + 4, os.SEEK_SET) self.assertEqual(fileobj_slice.tell(), 100) self.assertEqual(fileobj_slice.fileobj.tell(), 200) # Test seek with whence = os.SEEK_CUR with SliceFileObj(fileobj, seek_from=100, read_limit=100) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) fileobj_slice.seek(-5, os.SEEK_CUR) self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.fileobj.tell(), 100) fileobj_slice.seek(50, os.SEEK_CUR) self.assertEqual(fileobj_slice.tell(), 50) self.assertEqual(fileobj_slice.fileobj.tell(), 150) fileobj_slice.seek(100, os.SEEK_CUR) self.assertEqual(fileobj_slice.tell(), 100) self.assertEqual(fileobj_slice.fileobj.tell(), 200) fileobj_slice.seek(-300, os.SEEK_CUR) self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.fileobj.tell(), 100) # Test seek with whence = os.SEEK_END with SliceFileObj(fileobj, seek_from=100, read_limit=100) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) fileobj_slice.seek(-5, os.SEEK_END) self.assertEqual(fileobj_slice.tell(), 95) self.assertEqual(fileobj_slice.fileobj.tell(), 195) fileobj_slice.seek(50, os.SEEK_END) self.assertEqual(fileobj_slice.tell(), 100) 
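                    # os.SEEK_END offsets are interpreted relative to the end of the slice
                    # (seek_from + read_limit), not the end of the underlying file, so the
                    # over-shooting seek above clamps to position 100 (offset 200 below).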
self.assertEqual(fileobj_slice.fileobj.tell(), 200) fileobj_slice.seek(-200, os.SEEK_END) self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.fileobj.tell(), 100) huggingface_hub-0.31.1/tests/test_local_folder.py000066400000000000000000000240641500667546600221260ustar00rootroot00000000000000# coding=utf-8 # Copyright 2024-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains tests for the `.cache/huggingface` folder in local directories. See `huggingface_hub/src/_local_folder.py` for the implementation. """ import logging import os import threading import time from pathlib import Path, WindowsPath import pytest from huggingface_hub._local_folder import ( LocalDownloadFileMetadata, LocalDownloadFilePaths, LocalUploadFilePaths, _huggingface_dir, get_local_download_paths, get_local_upload_paths, read_download_metadata, write_download_metadata, ) def test_creates_huggingface_dir_with_gitignore(tmp_path: Path): """Test `.cache/huggingface/` dir is ignored by git.""" local_dir = tmp_path / "path" / "to" / "local" huggingface_dir = _huggingface_dir(local_dir) assert huggingface_dir == local_dir / ".cache" / "huggingface" assert huggingface_dir.exists() # all subdirectories have been created assert huggingface_dir.is_dir() # Whole folder must be ignored assert (huggingface_dir / ".gitignore").exists() assert (huggingface_dir / ".gitignore").read_text() == "*" def test_gitignore_lock_timeout_is_ignored(tmp_path: Path): local_dir = tmp_path / "path" / "to" / "local" threads = [threading.Thread(target=_huggingface_dir, args=(local_dir,)) for _ in range(10)] for thread in threads: thread.start() for thread in threads: thread.join() assert (local_dir / ".cache" / "huggingface" / ".gitignore").exists() assert not (local_dir / ".cache" / "huggingface" / ".gitignore.lock").exists() def test_local_download_paths(tmp_path: Path): """Test local download paths are valid + usable.""" paths = get_local_download_paths(tmp_path, "path/in/repo.txt") # Correct paths (also sanitized on windows) assert isinstance(paths, LocalDownloadFilePaths) assert paths.file_path == tmp_path / "path" / "in" / "repo.txt" assert ( paths.metadata_path == tmp_path / ".cache" / "huggingface" / "download" / "path" / "in" / "repo.txt.metadata" ) assert paths.lock_path == tmp_path / ".cache" / "huggingface" / "download" / "path" / "in" / "repo.txt.lock" # Paths are usable (parent directories have been created) assert paths.file_path.parent.is_dir() assert paths.metadata_path.parent.is_dir() assert paths.lock_path.parent.is_dir() # Incomplete paths are etag-based incomplete_path = paths.incomplete_path("etag123") assert incomplete_path.parent == tmp_path / ".cache" / "huggingface" / "download" / "path" / "in" assert incomplete_path.name.endswith(".etag123.incomplete") assert paths.incomplete_path("etag123").parent.is_dir() # Incomplete paths are unique per file per etag other_paths = get_local_download_paths(tmp_path, "path/in/repo_other.txt") other_incomplete_path = 
other_paths.incomplete_path("etag123") assert incomplete_path != other_incomplete_path # different .incomplete files to prevent concurrency issues def test_local_download_paths_are_recreated_each_time(tmp_path: Path): paths1 = get_local_download_paths(tmp_path, "path/in/repo.txt") assert paths1.file_path.parent.is_dir() assert paths1.metadata_path.parent.is_dir() paths1.file_path.parent.rmdir() paths1.metadata_path.parent.rmdir() paths2 = get_local_download_paths(tmp_path, "path/in/repo.txt") assert paths2.file_path.parent.is_dir() assert paths2.metadata_path.parent.is_dir() @pytest.mark.skipif(os.name != "nt", reason="Windows-specific test.") def test_local_download_paths_long_paths(tmp_path: Path): """Test long path handling on Windows.""" long_file_name = "a" * 255 paths = get_local_download_paths(tmp_path, f"path/long/{long_file_name}.txt") # WindowsPath on Windows platform assert isinstance(paths.file_path, WindowsPath) assert isinstance(paths.lock_path, WindowsPath) assert isinstance(paths.metadata_path, WindowsPath) # Correct path prefixes assert str(paths.file_path).startswith("\\\\?\\") assert str(paths.lock_path).startswith("\\\\?\\") assert str(paths.metadata_path).startswith("\\\\?\\") def test_write_download_metadata(tmp_path: Path): """Test download metadata content is valid.""" # Write metadata write_download_metadata(tmp_path, filename="file.txt", commit_hash="commit_hash", etag="123456789") metadata_path = tmp_path / ".cache" / "huggingface" / "download" / "file.txt.metadata" assert metadata_path.exists() # Metadata is valid with metadata_path.open() as f: assert f.readline() == "commit_hash\n" assert f.readline() == "123456789\n" timestamp = float(f.readline().strip()) assert timestamp <= time.time() # in the past assert timestamp >= time.time() - 1 # but less than 1 seconds ago (we're not that slow) time.sleep(0.2) # for deterministic tests # Overwriting works as expected write_download_metadata(tmp_path, filename="file.txt", commit_hash="commit_hash2", etag="987654321") with metadata_path.open() as f: assert f.readline() == "commit_hash2\n" assert f.readline() == "987654321\n" timestamp2 = float(f.readline().strip()) assert timestamp <= timestamp2 # updated timestamp def test_read_download_metadata_valid_metadata(tmp_path: Path): """Test reading download metadata when metadata is valid.""" # Create file + write correct metadata (tmp_path / "file.txt").write_text("content") write_download_metadata(tmp_path, filename="file.txt", commit_hash="commit_hash", etag="123456789") # Read metadata metadata = read_download_metadata(tmp_path, filename="file.txt") assert isinstance(metadata, LocalDownloadFileMetadata) assert metadata.filename == "file.txt" assert metadata.commit_hash == "commit_hash" assert metadata.etag == "123456789" assert isinstance(metadata.timestamp, float) def test_read_download_metadata_no_metadata(tmp_path: Path): """Test reading download metadata when there is no metadata.""" # No metadata file => return None assert read_download_metadata(tmp_path, filename="file.txt") is None def test_read_download_metadata_corrupted_metadata(tmp_path: Path, caplog: pytest.LogCaptureFixture): """Test reading download metadata when metadata is corrupted.""" # Write corrupted metadata metadata_path = tmp_path / ".cache" / "huggingface" / "download" / "file.txt.metadata" metadata_path.parent.mkdir(parents=True, exist_ok=True) metadata_path.write_text("invalid content") # Corrupted metadata file => delete it + warn + return None with caplog.at_level(logging.WARNING): assert 
read_download_metadata(tmp_path, filename="file.txt") is None assert not metadata_path.exists() assert "Invalid metadata file" in caplog.text def test_read_download_metadata_correct_metadata_missing_file(tmp_path: Path): """Test reading download metadata when metadata is correct but file is missing.""" # Write correct metadata write_download_metadata(tmp_path, filename="file.txt", commit_hash="commit_hash", etag="123456789") # File missing => return None assert read_download_metadata(tmp_path, filename="file.txt") is None def test_read_download_metadata_correct_metadata_but_outdated(tmp_path: Path): """Test reading download metadata when metadata is correct but outdated.""" # Write correct metadata write_download_metadata(tmp_path, filename="file.txt", commit_hash="commit_hash", etag="123456789") time.sleep(2) # We allow for a 1s difference in practice, so let's wait a bit # File is outdated => return None (tmp_path / "file.txt").write_text("content") assert read_download_metadata(tmp_path, filename="file.txt") is None def test_local_upload_paths(tmp_path: Path): """Test local upload paths are valid + usable.""" paths = get_local_upload_paths(tmp_path, "path/in/repo.txt") # Correct paths (also sanitized on windows) assert isinstance(paths, LocalUploadFilePaths) assert paths.file_path == tmp_path / "path" / "in" / "repo.txt" assert paths.metadata_path == tmp_path / ".cache" / "huggingface" / "upload" / "path" / "in" / "repo.txt.metadata" assert paths.lock_path == tmp_path / ".cache" / "huggingface" / "upload" / "path" / "in" / "repo.txt.lock" # Paths are usable (parent directories have been created) assert paths.file_path.parent.is_dir() assert paths.metadata_path.parent.is_dir() assert paths.lock_path.parent.is_dir() def test_local_upload_paths_are_recreated_each_time(tmp_path: Path): paths1 = get_local_upload_paths(tmp_path, "path/in/repo.txt") assert paths1.file_path.parent.is_dir() assert paths1.metadata_path.parent.is_dir() paths1.file_path.parent.rmdir() paths1.metadata_path.parent.rmdir() paths2 = get_local_upload_paths(tmp_path, "path/in/repo.txt") assert paths2.file_path.parent.is_dir() assert paths2.metadata_path.parent.is_dir() @pytest.mark.skipif(os.name != "nt", reason="Windows-specific test.") def test_local_upload_paths_long_paths(tmp_path: Path): """Test long path handling on Windows.""" long_file_name = "a" * 255 paths = get_local_upload_paths(tmp_path, f"path/long/{long_file_name}.txt") # WindowsPath on Windows platform assert isinstance(paths.file_path, WindowsPath) assert isinstance(paths.lock_path, WindowsPath) assert isinstance(paths.metadata_path, WindowsPath) # Correct path prefixes assert str(paths.file_path).startswith("\\\\?\\") assert str(paths.lock_path).startswith("\\\\?\\") assert str(paths.metadata_path).startswith("\\\\?\\") huggingface_hub-0.31.1/tests/test_login_utils.py000066400000000000000000000025531500667546600220300ustar00rootroot00000000000000import subprocess import unittest from typing import Optional from huggingface_hub._login import _set_store_as_git_credential_helper_globally from huggingface_hub.utils import run_subprocess class TestSetGlobalStore(unittest.TestCase): previous_config: Optional[str] def setUp(self) -> None: """Get current global config value.""" try: self.previous_config = run_subprocess("git config --global credential.helper").stdout except subprocess.CalledProcessError: self.previous_config = None # Means global credential.helper value not set run_subprocess("git config --global credential.helper store") def tearDown(self) -> 
None: """Reset global config value.""" if self.previous_config is None: run_subprocess("git config --global --unset credential.helper") else: run_subprocess(f"git config --global credential.helper {self.previous_config}") def test_set_store_as_git_credential_helper_globally(self) -> None: """Test `_set_store_as_git_credential_helper_globally` works as expected. Previous value from the machine is restored after the test. """ _set_store_as_git_credential_helper_globally() new_config = run_subprocess("git config --global credential.helper").stdout self.assertEqual(new_config, "store\n") huggingface_hub-0.31.1/tests/test_offline_utils.py000066400000000000000000000023751500667546600223440ustar00rootroot00000000000000from io import BytesIO import pytest import requests from huggingface_hub.file_download import http_get from .testing_utils import ( OfflineSimulationMode, RequestWouldHangIndefinitelyError, offline, ) def test_offline_with_timeout(): with offline(OfflineSimulationMode.CONNECTION_TIMES_OUT): with pytest.raises(RequestWouldHangIndefinitelyError): requests.request("GET", "https://huggingface.co") with pytest.raises(requests.exceptions.ConnectTimeout): requests.request("GET", "https://huggingface.co", timeout=1.0) with pytest.raises(requests.exceptions.ConnectTimeout): http_get("https://huggingface.co", BytesIO()) def test_offline_with_connection_error(): with offline(OfflineSimulationMode.CONNECTION_FAILS): with pytest.raises(requests.exceptions.ConnectionError): requests.request("GET", "https://huggingface.co") with pytest.raises(requests.exceptions.ConnectionError): http_get("https://huggingface.co", BytesIO()) def test_offline_with_datasets_offline_mode_enabled(): with offline(OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1): with pytest.raises(ConnectionError): http_get("https://huggingface.co", BytesIO()) huggingface_hub-0.31.1/tests/test_repocard.py000066400000000000000000001046721500667546600213040ustar00rootroot00000000000000# Copyright 2021 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
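# Tests for repo cards (RepoCard, ModelCard, DatasetCard, SpaceCard) and for the
# metadata_load / metadata_save / metadata_eval_result / metadata_update helpers.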
import copy import os import re import unittest from pathlib import Path import pytest import yaml from huggingface_hub import ( DatasetCard, DatasetCardData, EvalResult, ModelCard, ModelCardData, RepoCard, SpaceCard, SpaceCardData, constants, get_hf_file_metadata, hf_hub_url, metadata_eval_result, metadata_load, metadata_save, metadata_update, ) from huggingface_hub.errors import EntryNotFoundError from huggingface_hub.file_download import hf_hub_download from huggingface_hub.hf_api import HfApi from huggingface_hub.repocard import REGEX_YAML_BLOCK from huggingface_hub.repocard_data import CardData from huggingface_hub.utils import SoftTemporaryDirectory, is_jinja_available from .testing_constants import ENDPOINT_STAGING, TOKEN, USER from .testing_utils import repo_name, with_production_testing SAMPLE_CARDS_DIR = Path(__file__).parent / "fixtures/cards" ROUND_TRIP_MODELCARD_CASE = """ --- language: no datasets: CLUECorpusSmall widget: - text: 北京是[MASK]国的首都。 --- # Title """ DUMMY_MODELCARD = """ --- license: mit datasets: - foo - bar --- Hello """ DUMMY_MODELCARD_TARGET = """--- meaning_of_life: 42 --- Hello """ DUMMY_MODELCARD_TARGET_WITH_EMOJI = """--- emoji: 🎁 --- Hello """ DUMMY_MODELCARD_TARGET_NO_YAML = """--- meaning_of_life: 42 --- Hello """ DUMMY_NEW_MODELCARD_TARGET = """--- meaning_of_life: 42 --- """ DUMMY_MODELCARD_TARGET_NO_TAGS = """ Hello """ DUMMY_MODELCARD_EVAL_RESULT = """--- model-index: - name: RoBERTa fine-tuned on ReactionGIF results: - task: type: text-classification name: Text Classification dataset: name: ReactionGIF type: julien-c/reactiongif config: default split: test metrics: - type: accuracy value: 0.2662102282047272 name: Accuracy config: default verified: true --- """ DUMMY_MODELCARD_NO_TEXT_CONTENT = """--- license: cc-by-sa-4.0 --- """ DUMMY_MODELCARD_EVAL_RESULT_BOTH_VERIFIED_AND_UNVERIFIED = """--- model-index: - name: RoBERTa fine-tuned on ReactionGIF results: - task: type: text-classification name: Text Classification dataset: name: ReactionGIF type: julien-c/reactiongif config: default split: test metrics: - type: accuracy value: 0.2662102282047272 name: Accuracy config: default verified: false - task: type: text-classification name: Text Classification dataset: name: ReactionGIF type: julien-c/reactiongif config: default split: test metrics: - type: accuracy value: 0.6666666666666666 name: Accuracy config: default verified: true --- This is a test model card. """ DUMMY_MODELCARD_EMPTY_METADATA = """ --- --- # Begin of markdown after an empty metadata. Some cool dataset card. """ DUMMY_MODEL_CARD_TEMPLATE = """ --- {{ card_data }} --- Custom template passed as a string. {{ repo_url | default("[More Information Needed]", true) }} """ def require_jinja(test_case): """ Decorator marking a test that requires Jinja2. These tests are skipped when Jinja2 is not installed. 
""" if not is_jinja_available(): return unittest.skip("test requires Jinja2.")(test_case) else: return test_case @pytest.mark.usefixtures("fx_cache_dir") class RepocardMetadataTest(unittest.TestCase): cache_dir: Path def setUp(self) -> None: self.filepath = self.cache_dir / constants.REPOCARD_NAME def test_metadata_load(self): self.filepath.write_text(DUMMY_MODELCARD) data = metadata_load(self.filepath) self.assertDictEqual(data, {"license": "mit", "datasets": ["foo", "bar"]}) def test_metadata_save(self): self.filepath.write_text(DUMMY_MODELCARD) metadata_save(self.filepath, {"meaning_of_life": 42}) content = self.filepath.read_text() self.assertEqual(content, DUMMY_MODELCARD_TARGET) def test_metadata_save_with_emoji_character(self): self.filepath.write_text(DUMMY_MODELCARD) metadata_save(self.filepath, {"emoji": "🎁"}) content = self.filepath.read_text(encoding="utf-8") self.assertEqual(content, DUMMY_MODELCARD_TARGET_WITH_EMOJI) def test_metadata_save_from_file_no_yaml(self): self.filepath.write_text("Hello\n") metadata_save(self.filepath, {"meaning_of_life": 42}) content = self.filepath.read_text() self.assertEqual(content, DUMMY_MODELCARD_TARGET_NO_YAML) def test_metadata_save_new_file(self): metadata_save(self.filepath, {"meaning_of_life": 42}) content = self.filepath.read_text() self.assertEqual(content, DUMMY_NEW_MODELCARD_TARGET) def test_no_metadata_returns_none(self): self.filepath.write_text(DUMMY_MODELCARD_TARGET_NO_TAGS) data = metadata_load(self.filepath) self.assertEqual(data, None) def test_empty_metadata_returns_none_with_metadata_load(self): self.filepath.write_text(DUMMY_MODELCARD_EMPTY_METADATA) data = metadata_load(self.filepath) self.assertEqual(data, None) def test_empty_metadata_returns_none_with_repocard_load(self): self.filepath.write_text(DUMMY_MODELCARD_EMPTY_METADATA) self.assertIsNone(metadata_load(self.filepath)) self.assertEqual(RepoCard.load(self.filepath).data.to_dict(), {}) def test_metadata_eval_result(self): data = metadata_eval_result( model_pretty_name="RoBERTa fine-tuned on ReactionGIF", task_pretty_name="Text Classification", task_id="text-classification", metrics_pretty_name="Accuracy", metrics_id="accuracy", metrics_value=0.2662102282047272, metrics_config="default", metrics_verified=True, dataset_pretty_name="ReactionGIF", dataset_id="julien-c/reactiongif", dataset_config="default", dataset_split="test", ) metadata_save(self.filepath, data) content = self.filepath.read_text().splitlines() self.assertEqual(content, DUMMY_MODELCARD_EVAL_RESULT.splitlines()) @with_production_testing def test_load_from_hub_if_repo_id_or_path_is_a_dir(monkeypatch, tmp_path): """If `repo_id_or_path` happens to be both a `repo_id` and a local directory, the card must be loaded from the Hub. Path can only be a file path. Regression test for https://github.com/huggingface/huggingface_hub/issues/2768. 
""" repo_id = "openai-community/gpt2" monkeypatch.chdir(tmp_path) test_dir = tmp_path / "openai-community" test_dir.mkdir() model_dir = test_dir / "gpt2" model_dir.mkdir() card = RepoCard.load(repo_id) assert "GPT-2" in str(card) # loaded from Hub assert Path(repo_id).is_dir() class RepocardMetadataUpdateTest(unittest.TestCase): def setUp(self) -> None: self.token = TOKEN self.api = HfApi(token=TOKEN) self.repo_id = self.api.create_repo(repo_name()).repo_id self.api.upload_file( path_or_fileobj=DUMMY_MODELCARD_EVAL_RESULT.encode(), repo_id=self.repo_id, path_in_repo=constants.REPOCARD_NAME, ) self.existing_metadata = yaml.safe_load(DUMMY_MODELCARD_EVAL_RESULT.strip().strip("-")) def tearDown(self) -> None: self.api.delete_repo(repo_id=self.repo_id) def _get_remote_card(self) -> str: return hf_hub_download(repo_id=self.repo_id, filename=constants.REPOCARD_NAME) def test_update_dataset_name(self): new_datasets_data = {"datasets": ["test/test_dataset"]} metadata_update(self.repo_id, new_datasets_data, token=self.token) hf_hub_download(repo_id=self.repo_id, filename=constants.REPOCARD_NAME) updated_metadata = metadata_load(self._get_remote_card()) expected_metadata = copy.deepcopy(self.existing_metadata) expected_metadata.update(new_datasets_data) self.assertDictEqual(updated_metadata, expected_metadata) def test_update_existing_result_with_overwrite(self): new_metadata = copy.deepcopy(self.existing_metadata) new_metadata["model-index"][0]["results"][0]["metrics"][0]["value"] = 0.2862102282047272 metadata_update(self.repo_id, new_metadata, token=self.token, overwrite=True) updated_metadata = metadata_load(self._get_remote_card()) self.assertDictEqual(updated_metadata, new_metadata) def test_update_verify_token(self): """Tests whether updating the verification token updates in-place. Regression test for https://github.com/huggingface/huggingface_hub/issues/1210 """ new_metadata = copy.deepcopy(self.existing_metadata) new_metadata["model-index"][0]["results"][0]["metrics"][0]["verifyToken"] = "1234" metadata_update(self.repo_id, new_metadata, token=self.token, overwrite=True) updated_metadata = metadata_load(self._get_remote_card()) self.assertDictEqual(updated_metadata, new_metadata) def test_metadata_update_upstream(self): new_metadata = copy.deepcopy(self.existing_metadata) new_metadata["model-index"][0]["results"][0]["metrics"][0]["value"] = 0.1 # download first, then update path = self._get_remote_card() metadata_update(self.repo_id, new_metadata, token=self.token, overwrite=True) self.assertNotEqual(metadata_load(path), new_metadata) self.assertEqual(metadata_load(path), self.existing_metadata) def test_update_existing_result_without_overwrite(self): new_metadata = copy.deepcopy(self.existing_metadata) new_metadata["model-index"][0]["results"][0]["metrics"][0]["value"] = 0.2862102282047272 with pytest.raises( ValueError, match=( "You passed a new value for the existing metric 'name: Accuracy, type:" " accuracy'. Set `overwrite=True` to overwrite existing metrics." ), ): metadata_update(self.repo_id, new_metadata, token=self.token, overwrite=False) def test_update_existing_field_without_overwrite(self): new_datasets_data = {"datasets": ["Open-Orca/OpenOrca"]} metadata_update(self.repo_id, new_datasets_data, token=self.token) with pytest.raises( ValueError, match=( "You passed a new value for the existing meta data field 'datasets'." " Set `overwrite=True` to overwrite existing metadata." 
), ): new_datasets_data = {"datasets": ["HuggingFaceH4/no_robots"]} metadata_update(self.repo_id, new_datasets_data, token=self.token, overwrite=False) def test_update_new_result_existing_dataset(self): new_result = metadata_eval_result( model_pretty_name="RoBERTa fine-tuned on ReactionGIF", task_pretty_name="Text Classification", task_id="text-classification", metrics_pretty_name="Recall", metrics_id="recall", metrics_value=0.7762102282047272, metrics_config="default", metrics_verified=False, dataset_pretty_name="ReactionGIF", dataset_id="julien-c/reactiongif", dataset_config="default", dataset_split="test", ) metadata_update(self.repo_id, new_result, token=self.token, overwrite=False) expected_metadata = copy.deepcopy(self.existing_metadata) expected_metadata["model-index"][0]["results"][0]["metrics"].append( new_result["model-index"][0]["results"][0]["metrics"][0] ) updated_metadata = metadata_load(self._get_remote_card()) self.assertDictEqual(updated_metadata, expected_metadata) def test_update_new_result_new_dataset(self): new_result = metadata_eval_result( model_pretty_name="RoBERTa fine-tuned on ReactionGIF", task_pretty_name="Text Classification", task_id="text-classification", metrics_pretty_name="Accuracy", metrics_id="accuracy", metrics_value=0.2662102282047272, metrics_config="default", metrics_verified=False, dataset_pretty_name="ReactionJPEG", dataset_id="julien-c/reactionjpeg", dataset_config="default", dataset_split="test", ) metadata_update(self.repo_id, new_result, token=self.token, overwrite=False) expected_metadata = copy.deepcopy(self.existing_metadata) expected_metadata["model-index"][0]["results"].append(new_result["model-index"][0]["results"][0]) updated_metadata = metadata_load(self._get_remote_card()) self.assertDictEqual(updated_metadata, expected_metadata) def test_update_metadata_on_empty_text_content(self) -> None: """Test `update_metadata` on a model card that has metadata but no text content Regression test for https://github.com/huggingface/huggingface_hub/issues/1010 """ # Create modelcard with metadata but empty text content self.api.upload_file( path_or_fileobj=DUMMY_MODELCARD_NO_TEXT_CONTENT.encode(), path_in_repo=constants.REPOCARD_NAME, repo_id=self.repo_id, ) metadata_update(self.repo_id, {"tag": "test"}, token=self.token) # Check update went fine updated_metadata = metadata_load(self._get_remote_card()) expected_metadata = {"license": "cc-by-sa-4.0", "tag": "test"} self.assertDictEqual(updated_metadata, expected_metadata) def test_update_with_existing_name(self): new_metadata = copy.deepcopy(self.existing_metadata) new_metadata["model-index"][0].pop("name") new_metadata["model-index"][0]["results"][0]["metrics"][0]["value"] = 0.2862102282047272 metadata_update(self.repo_id, new_metadata, token=self.token, overwrite=True) card_data = ModelCard.load(self.repo_id) self.assertEqual(card_data.data.model_name, self.existing_metadata["model-index"][0]["name"]) def test_update_without_existing_name(self): # delete existing metadata self.api.upload_file(path_or_fileobj="# Test".encode(), repo_id=self.repo_id, path_in_repo="README.md") new_metadata = copy.deepcopy(self.existing_metadata) new_metadata["model-index"][0].pop("name") metadata_update(self.repo_id, new_metadata, token=self.token, overwrite=True) card_data = ModelCard.load(self.repo_id) self.assertEqual(card_data.data.model_name, self.repo_id) def test_update_with_both_verified_and_unverified_metric(self): """Regression test for #1185. See https://github.com/huggingface/huggingface_hub/issues/1185. 
""" self.api.upload_file( path_or_fileobj=DUMMY_MODELCARD_EVAL_RESULT_BOTH_VERIFIED_AND_UNVERIFIED.encode(), repo_id=self.repo_id, path_in_repo="README.md", ) card = ModelCard.load(self.repo_id) metadata = card.data.to_dict() metadata_update(self.repo_id, metadata=metadata, overwrite=True, token=self.token) new_card = ModelCard.load(self.repo_id) self.assertEqual(len(new_card.data.eval_results), 2) first_result = new_card.data.eval_results[0] second_result = new_card.data.eval_results[1] # One is verified, the other not self.assertFalse(first_result.verified) self.assertTrue(second_result.verified) # Result values are different self.assertEqual(first_result.metric_value, 0.2662102282047272) self.assertEqual(second_result.metric_value, 0.6666666666666666) class TestMetadataUpdateOnMissingCard(unittest.TestCase): def setUp(self) -> None: """ Share this valid token in all tests below. """ self._token = TOKEN self._api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) self._repo_id = f"{USER}/{repo_name()}" def test_metadata_update_missing_readme_on_model(self) -> None: self._api.create_repo(self._repo_id) metadata_update(self._repo_id, {"tag": "this_is_a_test"}, token=self._token) model_card = ModelCard.load(self._repo_id, token=self._token) # Created a card with default template + metadata self.assertIn("# Model Card for Model ID", str(model_card)) self.assertEqual(model_card.data.to_dict(), {"tag": "this_is_a_test"}) self._api.delete_repo(self._repo_id) def test_metadata_update_missing_readme_on_dataset(self) -> None: self._api.create_repo(self._repo_id, repo_type="dataset") metadata_update( self._repo_id, {"tag": "this is a dataset test"}, token=self._token, repo_type="dataset", ) dataset_card = DatasetCard.load(self._repo_id, token=self._token) # Created a card with default template + metadata self.assertIn("# Dataset Card for Dataset Name", str(dataset_card)) self.assertEqual(dataset_card.data.to_dict(), {"tag": "this is a dataset test"}) self._api.delete_repo(self._repo_id, repo_type="dataset") def test_metadata_update_missing_readme_on_space(self) -> None: self._api.create_repo(self._repo_id, repo_type="space", space_sdk="static") self._api.delete_file("README.md", self._repo_id, repo_type="space") with self.assertRaises(ValueError): # Cannot create a default readme on a space repo (should be automatically # created on the Hub). 
metadata_update( self._repo_id, {"tag": "this is a space test"}, token=self._token, repo_type="space", ) self._api.delete_repo(self._repo_id, repo_type="space") class TestCaseWithHfApi(unittest.TestCase): _api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) class RepoCardTest(TestCaseWithHfApi): def test_load_repocard_from_file(self): sample_path = SAMPLE_CARDS_DIR / "sample_simple.md" card = RepoCard.load(sample_path) self.assertEqual( card.data.to_dict(), { "language": ["en"], "license": "mit", "library_name": "pytorch-lightning", "tags": ["pytorch", "image-classification"], "datasets": ["beans"], "metrics": ["acc"], }, ) self.assertTrue( card.text.strip().startswith("# my-cool-model"), "Card text not loaded properly", ) def test_change_repocard_data(self): sample_path = SAMPLE_CARDS_DIR / "sample_simple.md" card = RepoCard.load(sample_path) card.data.language = ["fr"] with SoftTemporaryDirectory() as tempdir: updated_card_path = Path(tempdir) / "updated.md" card.save(updated_card_path) updated_card = RepoCard.load(updated_card_path) self.assertEqual(updated_card.data.language, ["fr"], "Card data not updated properly") @require_jinja def test_repo_card_from_default_template(self): card = RepoCard.from_template( card_data=CardData( language="en", license="mit", library_name="pytorch", tags=["image-classification", "resnet"], datasets="imagenet", metrics=["acc", "f1"], ), model_id=None, ) self.assertIsInstance(card, RepoCard) self.assertTrue( card.text.strip().startswith("# Model Card for Model ID"), "Default model name not set correctly", ) @require_jinja def test_repo_card_from_default_template_with_model_id(self): card = RepoCard.from_template( card_data=CardData( language="en", license="mit", library_name="pytorch", tags=["image-classification", "resnet"], datasets="imagenet", metrics=["acc", "f1"], ), model_id="my-cool-model", ) self.assertTrue( card.text.strip().startswith("# Model Card for my-cool-model"), "model_id not properly set in card template", ) @require_jinja def test_repo_card_from_custom_template_path(self): # Template is passed as a path (not a raw string) template_path = SAMPLE_CARDS_DIR / "sample_template.md" card = RepoCard.from_template( card_data=CardData( language="en", license="mit", library_name="pytorch", tags="text-classification", datasets="glue", metrics="acc", ), template_path=template_path, some_data="asdf", ) self.assertTrue( card.text.endswith("asdf"), "Custom template didn't set jinja variable correctly", ) @require_jinja def test_repo_card_from_custom_template_string(self): # Template is passed as a raw string (not a path) card = RepoCard.from_template( card_data=CardData(language="en", license="mit"), template_str=DUMMY_MODEL_CARD_TEMPLATE, ) assert "Custom template passed as a string." in str(card) def test_repo_card_data_must_be_dict(self): sample_path = SAMPLE_CARDS_DIR / "sample_invalid_card_data.md" with pytest.raises(ValueError, match="repo card metadata block should be a dict"): RepoCard(sample_path.read_text()) def test_repo_card_without_metadata(self): sample_path = SAMPLE_CARDS_DIR / "sample_no_metadata.md" with self.assertLogs("huggingface_hub", level="WARNING") as warning_logs: card = RepoCard(sample_path.read_text()) self.assertTrue( any( "Repo card metadata block was not found. Setting CardData to empty." 
in log for log in warning_logs.output ) ) self.assertEqual(card.data, CardData()) def test_validate_repocard(self): sample_path = SAMPLE_CARDS_DIR / "sample_simple.md" card = RepoCard.load(sample_path) card.validate() card.data.license = "asdf" with pytest.raises(ValueError, match='- Error: "license" must be one of'): card.validate() def test_push_to_hub(self): repo_id = f"{USER}/{repo_name('push-card')}" self._api.create_repo(repo_id) card_data = CardData( language="en", license="mit", library_name="pytorch", tags=["text-classification"], datasets="glue", metrics="acc", ) # Mock what RepoCard.from_template does so we can test w/o Jinja2 content = f"---\n{card_data.to_yaml()}\n---\n\n# MyModel\n\nHello, world!" card = RepoCard(content) # Check this file doesn't exist (sanity check) readme_url = hf_hub_url(repo_id, "README.md") with self.assertRaises(EntryNotFoundError): get_hf_file_metadata(readme_url) # Push the card up to README.md in the repo card.push_to_hub(repo_id, token=TOKEN) # No error should occur now, as README.md should exist get_hf_file_metadata(readme_url) self._api.delete_repo(repo_id=repo_id) def test_push_and_create_pr(self): repo_id = f"{USER}/{repo_name('pr-card')}" self._api.create_repo(repo_id) card_data = CardData( language="en", license="mit", library_name="pytorch", tags=["text-classification"], datasets="glue", metrics="acc", ) # Mock what RepoCard.from_template does so we can test w/o Jinja2 content = f"---\n{card_data.to_yaml()}\n---\n\n# MyModel\n\nHello, world!" card = RepoCard(content) discussions = list(self._api.get_repo_discussions(repo_id)) self.assertEqual(len(discussions), 0) card.push_to_hub(repo_id, token=TOKEN, create_pr=True) discussions = list(self._api.get_repo_discussions(repo_id)) self.assertEqual(len(discussions), 1) self._api.delete_repo(repo_id=repo_id) def test_preserve_windows_linebreaks(self): card_path = SAMPLE_CARDS_DIR / "sample_windows_line_breaks.md" card = RepoCard.load(card_path) self.assertIn("\r\n", str(card)) def test_preserve_linebreaks_when_saving(self): card_path = SAMPLE_CARDS_DIR / "sample_simple.md" card = RepoCard.load(card_path) with SoftTemporaryDirectory() as tmpdir: tmpfile = os.path.join(tmpdir, "readme.md") card.save(tmpfile) card2 = RepoCard.load(tmpfile) self.assertEqual(str(card), str(card2)) def test_updating_text_updates_content(self): sample_path = SAMPLE_CARDS_DIR / "sample_simple.md" card = RepoCard.load(sample_path) card.text = "Hello, world!" line_break = "\r\n" if os.name == "nt" else "\n" self.assertEqual( card.content, # line_break depends on platform. Correctly set when using RepoCard.save(...) 
to avoid diffs f"---\n{card.data.to_yaml()}\n---\nHello, world!".replace("\n", line_break), ) class TestRegexYamlBlock(unittest.TestCase): def test_match_with_leading_whitespace(self): self.assertIsNotNone(REGEX_YAML_BLOCK.search(" \n---\nmetadata: 1\n---")) def test_match_without_leading_whitespace(self): self.assertIsNotNone(REGEX_YAML_BLOCK.search("---\nmetadata: 1\n---")) def test_does_not_match_with_leading_text(self): self.assertIsNone(REGEX_YAML_BLOCK.search("something\n---\nmetadata: 1\n---")) class ModelCardTest(TestCaseWithHfApi): def test_model_card_with_invalid_model_index(self): """Test raise an error when loading a card that has invalid model-index.""" sample_path = SAMPLE_CARDS_DIR / "sample_invalid_model_index.md" with self.assertRaises(ValueError): ModelCard.load(sample_path) def test_model_card_with_invalid_model_index_and_ignore_error(self): """Test trigger a warning when loading a card that has invalid model-index and `ignore_metadata_errors=True` Some information is lost. """ sample_path = SAMPLE_CARDS_DIR / "sample_invalid_model_index.md" with self.assertLogs("huggingface_hub", level="WARNING") as warning_logs: card = ModelCard.load(sample_path, ignore_metadata_errors=True) self.assertTrue( any("Invalid model-index. Not loading eval results into CardData." in log for log in warning_logs.output) ) self.assertIsNone(card.data.eval_results) def test_model_card_with_model_index(self): """Test that loading a model card with multiple evaluations is consistent with `metadata_load`. Regression test for https://github.com/huggingface/huggingface_hub/issues/1208 """ sample_path = SAMPLE_CARDS_DIR / "sample_simple_model_index.md" card = ModelCard.load(sample_path) metadata = metadata_load(sample_path) self.assertDictEqual(card.data.to_dict(), metadata) def test_load_model_card_from_file(self): sample_path = SAMPLE_CARDS_DIR / "sample_simple.md" card = ModelCard.load(sample_path) self.assertIsInstance(card, ModelCard) self.assertEqual( card.data.to_dict(), { "language": ["en"], "license": "mit", "library_name": "pytorch-lightning", "tags": ["pytorch", "image-classification"], "datasets": ["beans"], "metrics": ["acc"], }, ) self.assertTrue( card.text.strip().startswith("# my-cool-model"), "Card text not loaded properly", ) @require_jinja def test_model_card_from_custom_template(self): template_path = SAMPLE_CARDS_DIR / "sample_template.md" card = ModelCard.from_template( card_data=ModelCardData( language="en", license="mit", library_name="pytorch", tags="text-classification", datasets="glue", metrics="acc", ), template_path=template_path, some_data="asdf", ) self.assertIsInstance(card, ModelCard) self.assertTrue( card.text.endswith("asdf"), "Custom template didn't set jinja variable correctly", ) @require_jinja def test_model_card_from_template_eval_results(self): template_path = SAMPLE_CARDS_DIR / "sample_template.md" card = ModelCard.from_template( card_data=ModelCardData( eval_results=[ EvalResult( task_type="text-classification", task_name="Text Classification", dataset_type="julien-c/reactiongif", dataset_name="ReactionGIF", dataset_config="default", dataset_split="test", metric_type="accuracy", metric_value=0.2662102282047272, metric_name="Accuracy", metric_config="default", verified=True, ), ], model_name="RoBERTa fine-tuned on ReactionGIF", ), template_path=template_path, some_data="asdf", ) self.assertIsInstance(card, ModelCard) self.assertTrue(card.text.endswith("asdf")) self.assertTrue(card.data.to_dict().get("eval_results") is None) self.assertEqual(str(card)[: 
len(DUMMY_MODELCARD_EVAL_RESULT)], DUMMY_MODELCARD_EVAL_RESULT) def test_preserve_order_load_save(self): model_card = ModelCard(DUMMY_MODELCARD) model_card.data.license = "test" self.assertEqual(model_card.content, "---\nlicense: test\ndatasets:\n- foo\n- bar\n---\n\nHello\n") class DatasetCardTest(TestCaseWithHfApi): def test_load_datasetcard_from_file(self): sample_path = SAMPLE_CARDS_DIR / "sample_datasetcard_simple.md" card = DatasetCard.load(sample_path) self.assertEqual( card.data.to_dict(), { "annotations_creators": ["crowdsourced", "expert-generated"], "language_creators": ["found"], "language": ["en"], "license": ["bsd-3-clause"], "multilinguality": ["monolingual"], "size_categories": ["n<1K"], "task_categories": ["image-segmentation"], "task_ids": ["semantic-segmentation"], "pretty_name": "Sample Segmentation", }, ) self.assertIsInstance(card, DatasetCard) self.assertIsInstance(card.data, DatasetCardData) self.assertTrue(card.text.strip().startswith("# Dataset Card for")) @require_jinja def test_dataset_card_from_default_template(self): card_data = DatasetCardData( language="en", license="mit", ) # Here we check default title when pretty_name not provided. card = DatasetCard.from_template(card_data) self.assertTrue(card.text.strip().startswith("# Dataset Card for Dataset Name")) card_data = DatasetCardData( language="en", license="mit", pretty_name="My Cool Dataset", ) # Here we pass the card data as kwargs as well so template picks up pretty_name. card = DatasetCard.from_template(card_data, **card_data.to_dict()) self.assertTrue(card.text.strip().startswith("# Dataset Card for My Cool Dataset")) self.assertIsInstance(card, DatasetCard) @require_jinja def test_dataset_card_from_default_template_with_template_variables(self): card_data = DatasetCardData( language="en", license="mit", pretty_name="My Cool Dataset", ) # Here we pass the card data as kwargs as well so template picks up pretty_name. card = DatasetCard.from_template( card_data, repo="https://github.com/huggingface/huggingface_hub", paper="https://arxiv.org/pdf/1910.03771.pdf", dataset_summary=( "This is a test dataset card to check if the template variables " "in the dataset card template are working." 
), ) self.assertTrue(card.text.strip().startswith("# Dataset Card for My Cool Dataset")) self.assertIsInstance(card, DatasetCard) matches = re.findall(r"Repository:\*\* https://github\.com/huggingface/huggingface_hub", str(card)) self.assertEqual(matches[0], "Repository:** https://github.com/huggingface/huggingface_hub") @require_jinja def test_dataset_card_from_custom_template(self): card = DatasetCard.from_template( card_data=DatasetCardData( language="en", license="mit", pretty_name="My Cool Dataset", ), template_path=SAMPLE_CARDS_DIR / "sample_datasetcard_template.md", pretty_name="My Cool Dataset", some_data="asdf", ) self.assertIsInstance(card, DatasetCard) # Title this time is just # {{ pretty_name }} self.assertTrue(card.text.strip().startswith("# My Cool Dataset")) # some_data is at the bottom of the template, so should end with whatever we passed to it self.assertTrue(card.text.strip().endswith("asdf")) @with_production_testing class SpaceCardTest(TestCaseWithHfApi): def test_load_spacecard_from_hub(self) -> None: card = SpaceCard.load("multimodalart/dreambooth-training") self.assertIsInstance(card, SpaceCard) self.assertIsInstance(card.data, SpaceCardData) self.assertEqual(card.data.title, "Dreambooth Training") self.assertIsNone(card.data.app_port) huggingface_hub-0.31.1/tests/test_repocard_data.py000066400000000000000000000304771500667546600222760ustar00rootroot00000000000000import unittest import pytest import yaml from huggingface_hub import SpaceCardData from huggingface_hub.repocard_data import ( CardData, DatasetCardData, EvalResult, ModelCardData, eval_results_to_model_index, model_index_to_eval_results, ) OPEN_LLM_LEADERBOARD_URL = "https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard" DUMMY_METADATA_WITH_MODEL_INDEX = """ language: en license: mit library_name: timm tags: - pytorch - image-classification datasets: - beans metrics: - acc model-index: - name: my-cool-model results: - task: type: image-classification dataset: type: beans name: Beans metrics: - type: acc value: 0.9 source: name: Open LLM Leaderboard url: https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard """ class BaseCardDataTest(unittest.TestCase): def test_metadata_behave_as_dict(self): metadata = CardData(foo="bar") # .get and __getitem__ self.assertEqual(metadata.get("foo"), "bar") self.assertEqual(metadata.get("FOO"), None) # case sensitive self.assertEqual(metadata["foo"], "bar") with self.assertRaises(KeyError): # case sensitive _ = metadata["FOO"] # __setitem__ metadata["foo"] = "BAR" self.assertEqual(metadata.get("foo"), "BAR") self.assertEqual(metadata["foo"], "BAR") # __contains__ self.assertTrue("foo" in metadata) self.assertFalse("FOO" in metadata) # default value # Should return default when key is not in metadata self.assertEqual(metadata.get("FOO", "default"), "default") # Should return default when key is in metadata but value is None metadata.FOO = None self.assertEqual(metadata.get("FOO", "default"), "default") # export self.assertEqual(str(metadata), "foo: BAR") # .pop self.assertEqual(metadata.pop("foo"), "BAR") class ModelCardDataTest(unittest.TestCase): def test_eval_results_to_model_index(self): expected_results = yaml.safe_load(DUMMY_METADATA_WITH_MODEL_INDEX) eval_results = [ EvalResult( task_type="image-classification", dataset_type="beans", dataset_name="Beans", metric_type="acc", metric_value=0.9, source_name="Open LLM Leaderboard", source_url=OPEN_LLM_LEADERBOARD_URL, ), ] model_index = eval_results_to_model_index("my-cool-model", 
eval_results) self.assertEqual(model_index, expected_results["model-index"]) def test_model_index_to_eval_results(self): model_index = [ { "name": "my-cool-model", "results": [ { "task": { "type": "image-classification", }, "dataset": { "type": "cats_vs_dogs", "name": "Cats vs. Dogs", }, "metrics": [ { "type": "acc", "value": 0.85, }, { "type": "f1", "value": 0.9, }, ], }, { "task": { "type": "image-classification", }, "dataset": { "type": "beans", "name": "Beans", }, "metrics": [ { "type": "acc", "value": 0.9, "verified": True, "verifyToken": 1234, } ], "source": { "name": "Open LLM Leaderboard", "url": OPEN_LLM_LEADERBOARD_URL, }, }, ], } ] model_name, eval_results = model_index_to_eval_results(model_index) self.assertEqual(len(eval_results), 3) self.assertEqual(model_name, "my-cool-model") self.assertEqual(eval_results[0].dataset_type, "cats_vs_dogs") self.assertIsNone(eval_results[0].source_name) self.assertIsNone(eval_results[0].source_url) self.assertEqual(eval_results[1].metric_type, "f1") self.assertEqual(eval_results[1].metric_value, 0.9) self.assertIsNone(eval_results[1].source_name) self.assertIsNone(eval_results[1].source_url) self.assertEqual(eval_results[2].task_type, "image-classification") self.assertEqual(eval_results[2].dataset_type, "beans") self.assertEqual(eval_results[2].verified, True) self.assertEqual(eval_results[2].verify_token, 1234) self.assertEqual(eval_results[2].source_name, "Open LLM Leaderboard") self.assertEqual(eval_results[2].source_url, OPEN_LLM_LEADERBOARD_URL) def test_card_data_requires_model_name_for_eval_results(self): with pytest.raises(ValueError, match="`eval_results` requires `model_name` to be set."): ModelCardData( eval_results=[ EvalResult( task_type="image-classification", dataset_type="beans", dataset_name="Beans", metric_type="acc", metric_value=0.9, ), ], ) data = ModelCardData( model_name="my-cool-model", eval_results=[ EvalResult( task_type="image-classification", dataset_type="beans", dataset_name="Beans", metric_type="acc", metric_value=0.9, ), ], ) model_index = eval_results_to_model_index(data.model_name, data.eval_results) self.assertEqual(model_index[0]["name"], "my-cool-model") self.assertEqual(model_index[0]["results"][0]["task"]["type"], "image-classification") def test_arbitrary_incoming_card_data(self): data = ModelCardData( model_name="my-cool-model", eval_results=[ EvalResult( task_type="image-classification", dataset_type="beans", dataset_name="Beans", metric_type="acc", metric_value=0.9, ), ], some_arbitrary_kwarg="some_value", ) self.assertEqual(data.some_arbitrary_kwarg, "some_value") data_dict = data.to_dict() self.assertEqual(data_dict["some_arbitrary_kwarg"], "some_value") def test_eval_result_with_incomplete_source(self): # Source url without name: ok EvalResult( task_type="image-classification", dataset_type="beans", dataset_name="Beans", metric_type="acc", metric_value=0.9, source_url=OPEN_LLM_LEADERBOARD_URL, ) # Source name without url: not ok with self.assertRaises(ValueError): EvalResult( task_type="image-classification", dataset_type="beans", dataset_name="Beans", metric_type="acc", metric_value=0.9, source_name="Open LLM Leaderboard", ) def test_model_card_unique_tags(self): data = ModelCardData(tags=["tag2", "tag1", "tag2", "tag3"]) assert data.tags == ["tag2", "tag1", "tag3"] def test_remove_top_level_none_values(self): as_obj = ModelCardData(tags=["tag1", None], foo={"bar": 3, "baz": None}, pipeline_tag=None) as_dict = as_obj.to_dict() assert as_obj.tags == ["tag1", None] assert as_dict["tags"] == 
["tag1", None] # none value inside list should be kept assert as_obj.foo == {"bar": 3, "baz": None} assert as_dict["foo"] == {"bar": 3, "baz": None} # none value inside dict should be kept assert as_obj.pipeline_tag is None assert "pipeline_tag" not in as_dict # top level none value should be removed def test_eval_results_requires_evalresult_type(self): with pytest.raises(ValueError, match="should be of type `EvalResult` or a list of `EvalResult`"): ModelCardData(model_name="my-cool-model", eval_results="this is not an EvalResult") with pytest.raises(ValueError, match="should be of type `EvalResult` or a list of `EvalResult`"): ModelCardData(model_name="my-cool-model", eval_results=["accuracy: 0.9", "f1: 0.85"]) data = ModelCardData( model_name="my-cool-model", eval_results="this is not an EvalResult", ignore_metadata_errors=True, ) assert data.eval_results is not None and data.eval_results == "this is not an EvalResult" def test_model_name_required_with_eval_results(self): with pytest.raises(ValueError, match="`eval_results` requires `model_name` to be set"): ModelCardData( eval_results=[ EvalResult( task_type="image-classification", dataset_type="beans", dataset_name="Beans", metric_type="acc", metric_value=0.9, ), ], ) eval_results = [ EvalResult( task_type="image-classification", dataset_type="beans", dataset_name="Beans", metric_type="acc", metric_value=0.9, ), ] data = ModelCardData( eval_results=eval_results, ignore_metadata_errors=True, ) assert data.eval_results is not None and data.eval_results == eval_results class DatasetCardDataTest(unittest.TestCase): def test_train_eval_index_keys_updated(self): train_eval_index = [ { "config": "plain_text", "task": "text-classification", "task_id": "binary_classification", "splits": {"train_split": "train", "eval_split": "test"}, "col_mapping": {"text": "text", "label": "target"}, "metrics": [ { "type": "accuracy", "name": "Accuracy", }, {"type": "f1", "name": "F1 macro", "args": {"average": "macro"}}, ], } ] card_data = DatasetCardData( language="en", license="mit", pretty_name="My Cool Dataset", train_eval_index=train_eval_index, ) # The init should have popped this out of kwargs and into train_eval_index attr self.assertEqual(card_data.train_eval_index, train_eval_index) # Underlying train_eval_index gets converted to train-eval-index in DatasetCardData._to_dict. # So train_eval_index should be None in the dict self.assertTrue(card_data.to_dict().get("train_eval_index") is None) # And train-eval-index should be in the dict self.assertEqual(card_data.to_dict()["train-eval-index"], train_eval_index) class SpaceCardDataTest(unittest.TestCase): def test_space_card_data(self) -> None: card_data = SpaceCardData( title="Dreambooth Training", license="mit", sdk="gradio", duplicated_from="multimodalart/dreambooth-training", ) self.assertEqual( card_data.to_dict(), { "title": "Dreambooth Training", "sdk": "gradio", "license": "mit", "duplicated_from": "multimodalart/dreambooth-training", }, ) self.assertIsNone(card_data.tags) # SpaceCardData has some default attributes huggingface_hub-0.31.1/tests/test_repository.py000066400000000000000000001006611500667546600217160ustar00rootroot00000000000000# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import os import time import unittest from pathlib import Path import pytest import requests from huggingface_hub import RepoUrl from huggingface_hub.hf_api import HfApi from huggingface_hub.repository import ( Repository, is_tracked_upstream, is_tracked_with_lfs, ) from huggingface_hub.utils import SoftTemporaryDirectory, logging, run_subprocess from .testing_constants import ENDPOINT_STAGING, TOKEN from .testing_utils import ( expect_deprecation, repo_name, use_tmp_repo, with_production_testing, ) logger = logging.get_logger(__name__) @pytest.mark.usefixtures("fx_cache_dir") class RepositoryTestAbstract(unittest.TestCase): cache_dir: Path repo_path: Path # This content is 5MB (under 10MB) small_content = json.dumps([100] * int(1e6)) # This content is 20MB (over 10MB) large_content = json.dumps([100] * int(4e6)) # This content is binary (contains the null character) binary_content = "\x00\x00\x00\x00" _api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) @classmethod def setUp(self) -> None: self.repo_path = self.cache_dir / "working_dir" self.repo_path.mkdir() def _create_dummy_files(self): # Create dummy files # one is lfs-tracked, the other is not. small_file = self.repo_path / "dummy.txt" small_file.write_text(self.small_content) binary_file = self.repo_path / "model.bin" binary_file.write_text(self.binary_content) class TestRepositoryShared(RepositoryTestAbstract): """Tests in this class shares a single repo on the Hub (common to all tests). These tests must not push data to it. """ @classmethod def setUpClass(cls): """ Share this valid token in all tests below. 
""" super().setUpClass() cls.repo_url = cls._api.create_repo(repo_id=repo_name()) cls.repo_id = cls.repo_url.repo_id cls._api.upload_file( path_or_fileobj=cls.binary_content.encode(), path_in_repo="random_file.txt", repo_id=cls.repo_id, ) @classmethod def tearDownClass(cls): cls._api.delete_repo(repo_id=cls.repo_id) @expect_deprecation("Repository") def test_clone_from_repo_url(self): Repository(self.repo_path, clone_from=self.repo_url) @expect_deprecation("Repository") def test_clone_from_repo_id(self): Repository(self.repo_path, clone_from=self.repo_id) @expect_deprecation("Repository") def test_clone_from_repo_name_no_namespace_fails(self): with self.assertRaises(EnvironmentError): Repository(self.repo_path, clone_from=self.repo_id.split("/")[1], token=TOKEN) @expect_deprecation("Repository") def test_clone_from_not_hf_url(self): # Should not error out Repository(self.repo_path, clone_from="https://hf.co/hf-internal-testing/huggingface-hub-dummy-repository") @expect_deprecation("Repository") def test_clone_from_missing_repo(self): """If the repo does not exist an EnvironmentError is raised.""" with self.assertRaises(EnvironmentError): Repository(self.repo_path, clone_from="missing_repo") @expect_deprecation("Repository") @with_production_testing def test_clone_from_prod_canonical_repo_id(self): Repository(self.repo_path, clone_from="bert-base-cased", skip_lfs_files=True) @expect_deprecation("Repository") @with_production_testing def test_clone_from_prod_canonical_repo_url(self): Repository(self.repo_path, clone_from="https://huggingface.co/bert-base-cased", skip_lfs_files=True) @expect_deprecation("Repository") def test_init_from_existing_local_clone(self): run_subprocess(["git", "clone", self.repo_url, str(self.repo_path)]) repo = Repository(self.repo_path) repo.lfs_track(["*.pdf"]) repo.lfs_enable_largefiles() repo.git_pull() @expect_deprecation("Repository") def test_init_failure(self): with self.assertRaises(ValueError): Repository(self.repo_path) @expect_deprecation("Repository") def test_init_clone_in_empty_folder(self): repo = Repository(self.repo_path, clone_from=self.repo_url) repo.lfs_track(["*.pdf"]) repo.lfs_enable_largefiles() repo.git_pull() self.assertIn("random_file.txt", os.listdir(self.repo_path)) @expect_deprecation("Repository") def test_git_lfs_filename(self): run_subprocess("git init", folder=self.repo_path) repo = Repository(self.repo_path) large_file = self.repo_path / "large_file[].txt" large_file.write_text(self.large_content) repo.git_add() repo.lfs_track([large_file.name]) self.assertFalse(is_tracked_with_lfs(large_file)) repo.lfs_track([large_file.name], filename=True) self.assertTrue(is_tracked_with_lfs(large_file)) @expect_deprecation("Repository") def test_init_clone_in_nonempty_folder(self): self._create_dummy_files() with self.assertRaises(EnvironmentError): Repository(self.repo_path, clone_from=self.repo_url) @expect_deprecation("Repository") def test_init_clone_in_nonempty_linked_git_repo_with_token(self): Repository(self.repo_path, clone_from=self.repo_url, token=TOKEN) Repository(self.repo_path, clone_from=self.repo_url, token=TOKEN) @expect_deprecation("Repository") def test_is_tracked_upstream(self): Repository(self.repo_path, clone_from=self.repo_id) self.assertTrue(is_tracked_upstream(self.repo_path)) @expect_deprecation("Repository") def test_push_errors_on_wrong_checkout(self): repo = Repository(self.repo_path, clone_from=self.repo_id) head_commit_ref = run_subprocess("git show --oneline -s", folder=self.repo_path).stdout.split()[0] 
repo.git_checkout(head_commit_ref) with self.assertRaises(OSError): with repo.commit("New commit"): with open("new_file", "w+") as f: f.write("Ok") class TestRepositoryUniqueRepos(RepositoryTestAbstract): """Tests in this class use separated repos on the Hub (i.e. 1 test = 1 repo). These tests can push data to it. """ def setUp(self): super().setUp() self.repo_url = self._api.create_repo(repo_id=repo_name()) self.repo_id = self.repo_url.repo_id self._api.upload_file( path_or_fileobj=self.binary_content.encode(), path_in_repo="random_file.txt", repo_id=self.repo_id ) def tearDown(self): self._api.delete_repo(repo_id=self.repo_id) @expect_deprecation("Repository") def clone_repo(self, **kwargs) -> Repository: if "local_dir" not in kwargs: kwargs["local_dir"] = self.repo_path if "clone_from" not in kwargs: kwargs["clone_from"] = self.repo_url if "token" not in kwargs: kwargs["token"] = TOKEN if "git_user" not in kwargs: kwargs["git_user"] = "ci" if "git_email" not in kwargs: kwargs["git_email"] = "ci@dummy.com" return Repository(**kwargs) @use_tmp_repo() @expect_deprecation("Repository") def test_init_clone_in_nonempty_non_linked_git_repo(self, repo_url: RepoUrl): self.clone_repo() # Try and clone another repository within the same directory. # Should error out due to mismatched remotes. with self.assertRaises(EnvironmentError): Repository(self.repo_path, clone_from=repo_url) def test_init_clone_in_nonempty_linked_git_repo(self): # Clone the repository to disk self.clone_repo() # Add to the remote repository without doing anything to the local repository. self._api.upload_file( path_or_fileobj=self.binary_content.encode(), path_in_repo="random_file_3.txt", repo_id=self.repo_id ) # Cloning the repository in the same directory should not result in a git pull. self.clone_repo(clone_from=self.repo_url) self.assertNotIn("random_file_3.txt", os.listdir(self.repo_path)) def test_init_clone_in_nonempty_linked_git_repo_unrelated_histories(self): # Clone the repository to disk repo = self.clone_repo() # Create and commit file locally (self.repo_path / "random_file_3.txt").write_text("hello world") repo.git_add() repo.git_commit("Unrelated commit") # Add to the remote repository without doing anything to the local repository. self._api.upload_file( path_or_fileobj=self.binary_content.encode(), path_in_repo="random_file_3.txt", repo_id=self.repo_url.repo_id, ) # The repo should initialize correctly as the remote is the same, even with unrelated historied self.clone_repo() def test_add_commit_push(self): repo = self.clone_repo() self._create_dummy_files() repo.git_add() repo.git_commit() url = repo.git_push() # Check that the returned commit url # actually exists. r = requests.head(url) r.raise_for_status() def test_add_commit_push_non_blocking(self): repo = self.clone_repo() self._create_dummy_files() repo.git_add() repo.git_commit() url, result = repo.git_push(blocking=False) # Check background process if result._process.poll() is None: self.assertEqual(result.status, -1) while not result.is_done: time.sleep(0.5) self.assertTrue(result.is_done) self.assertEqual(result.status, 0) # Check that the returned commit url # actually exists. 
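# A HEAD request is enough here: we only care that the commit URL resolves (2xx status), not about its content.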
r = requests.head(url) r.raise_for_status() def test_context_manager_non_blocking(self): repo = self.clone_repo() with repo.commit("New commit", blocking=False): (self.repo_path / "dummy.txt").write_text("hello world") while repo.commands_in_progress: time.sleep(1) self.assertEqual(len(repo.commands_in_progress), 0) self.assertEqual(len(repo.command_queue), 1) self.assertEqual(repo.command_queue[-1].status, 0) self.assertEqual(repo.command_queue[-1].is_done, True) self.assertEqual(repo.command_queue[-1].title, "push") @unittest.skip("This is a flaky and legacy test") def test_add_commit_push_non_blocking_process_killed(self): repo = self.clone_repo() # Far too big file: will take forever (self.repo_path / "dummy.txt").write_text(str([[[1] * 10000] * 1000] * 10)) repo.git_add(auto_lfs_track=True) repo.git_commit() _, result = repo.git_push(blocking=False) result._process.kill() while result._process.poll() is None: time.sleep(0.5) self.assertTrue(result.is_done) self.assertEqual(result.status, -9) def test_commit_context_manager(self): # Clone and commit from a first folder folder_1 = self.repo_path / "folder_1" clone = self.clone_repo(local_dir=folder_1) with clone.commit("Commit"): with open("dummy.txt", "w") as f: f.write("hello") with open("model.bin", "w") as f: f.write("hello") # Clone in second folder. Check existence of committed files folder_2 = self.repo_path / "folder_2" self.clone_repo(local_dir=folder_2) files = os.listdir(folder_2) self.assertTrue("dummy.txt" in files) self.assertTrue("model.bin" in files) def test_clone_skip_lfs_files(self): # Upload LFS file self._api.upload_file(path_or_fileobj=b"Bin file", path_in_repo="file.bin", repo_id=self.repo_id) repo = self.clone_repo(skip_lfs_files=True) file_bin = self.repo_path / "file.bin" self.assertTrue(file_bin.read_text().startswith("version")) repo.git_pull(lfs=True) self.assertEqual(file_bin.read_text(), "Bin file") def test_commits_on_correct_branch(self): repo = self.clone_repo() branch = repo.current_branch repo.git_checkout("new-branch", create_branch_ok=True) repo.git_checkout(branch) with repo.commit("New commit"): with open("file.txt", "w+") as f: f.write("Ok") repo.git_checkout("new-branch") with repo.commit("New commit"): with open("new_file.txt", "w+") as f: f.write("Ok") with SoftTemporaryDirectory() as tmp: clone = self.clone_repo(local_dir=tmp) files = os.listdir(clone.local_dir) self.assertTrue("file.txt" in files) self.assertFalse("new_file.txt" in files) clone.git_checkout("new-branch") files = os.listdir(clone.local_dir) self.assertFalse("file.txt" in files) self.assertTrue("new_file.txt" in files) def test_repo_checkout_push(self): repo = self.clone_repo() repo.git_checkout("new-branch", create_branch_ok=True) repo.git_checkout("main") (self.repo_path / "file.txt").write_text("OK") repo.push_to_hub("Commit #1") repo.git_checkout("new-branch", create_branch_ok=True) (self.repo_path / "new_file.txt").write_text("OK") repo.push_to_hub("Commit #2") with SoftTemporaryDirectory() as tmp: clone = self.clone_repo(local_dir=tmp) files = os.listdir(clone.local_dir) self.assertTrue("file.txt" in files) self.assertFalse("new_file.txt" in files) clone.git_checkout("new-branch") files = os.listdir(clone.local_dir) self.assertFalse("file.txt" in files) self.assertTrue("new_file.txt" in files) def test_repo_checkout_commit_context_manager(self): repo = self.clone_repo() with repo.commit("Commit #1", branch="new-branch"): with open(os.path.join(repo.local_dir, "file.txt"), "w+") as f: f.write("Ok") with 
repo.commit("Commit #2", branch="main"): with open(os.path.join(repo.local_dir, "new_file.txt"), "w+") as f: f.write("Ok") # Maintains lastly used branch with repo.commit("Commit #3"): with open(os.path.join(repo.local_dir, "new_file-2.txt"), "w+") as f: f.write("Ok") with SoftTemporaryDirectory() as tmp: clone = self.clone_repo(local_dir=tmp) files = os.listdir(clone.local_dir) self.assertFalse("file.txt" in files) self.assertTrue("new_file-2.txt" in files) self.assertTrue("new_file.txt" in files) clone.git_checkout("new-branch") files = os.listdir(clone.local_dir) self.assertTrue("file.txt" in files) self.assertFalse("new_file.txt" in files) self.assertFalse("new_file-2.txt" in files) def test_add_tag(self): repo = self.clone_repo() repo.add_tag("v4.6.0", remote="origin") self.assertTrue(repo.tag_exists("v4.6.0", remote="origin")) def test_add_annotated_tag(self): repo = self.clone_repo() repo.add_tag("v4.5.0", message="This is an annotated tag", remote="origin") # Unfortunately git offers no built-in way to check the annotated # message of a remote tag. # In order to check that the remote tag was correctly annotated, # we delete the local tag before pulling the remote tag (which # should be the same). We then check that this tag is correctly # annotated. repo.delete_tag("v4.5.0") self.assertTrue(repo.tag_exists("v4.5.0", remote="origin")) self.assertFalse(repo.tag_exists("v4.5.0")) # Tag still exists on remote run_subprocess("git pull --tags", folder=self.repo_path) self.assertTrue(repo.tag_exists("v4.5.0")) # Tag is annotated result = run_subprocess("git tag -n9", folder=self.repo_path).stdout.strip() self.assertIn("This is an annotated tag", result) def test_delete_tag(self): repo = self.clone_repo() repo.add_tag("v4.6.0", message="This is an annotated tag", remote="origin") self.assertTrue(repo.tag_exists("v4.6.0", remote="origin")) repo.delete_tag("v4.6.0") self.assertFalse(repo.tag_exists("v4.6.0")) self.assertTrue(repo.tag_exists("v4.6.0", remote="origin")) repo.delete_tag("v4.6.0", remote="origin") self.assertFalse(repo.tag_exists("v4.6.0", remote="origin")) def test_lfs_prune(self): repo = self.clone_repo() with repo.commit("Committing LFS file"): with open("file.bin", "w+") as f: f.write("Random string 1") with repo.commit("Committing LFS file"): with open("file.bin", "w+") as f: f.write("Random string 2") root_directory = self.repo_path / ".git" / "lfs" git_lfs_files_size = sum(f.stat().st_size for f in root_directory.glob("**/*") if f.is_file()) repo.lfs_prune() post_prune_git_lfs_files_size = sum(f.stat().st_size for f in root_directory.glob("**/*") if f.is_file()) # Size of the directory holding LFS files was reduced self.assertLess(post_prune_git_lfs_files_size, git_lfs_files_size) def test_lfs_prune_git_push(self): repo = self.clone_repo() with repo.commit("Committing LFS file"): with open("file.bin", "w+") as f: f.write("Random string 1") root_directory = self.repo_path / ".git" / "lfs" git_lfs_files_size = sum(f.stat().st_size for f in root_directory.glob("**/*") if f.is_file()) with open(os.path.join(repo.local_dir, "file.bin"), "w+") as f: f.write("Random string 2") repo.git_add() repo.git_commit("New commit") repo.git_push(auto_lfs_prune=True) post_prune_git_lfs_files_size = sum(f.stat().st_size for f in root_directory.glob("**/*") if f.is_file()) # Size of the directory holding LFS files is the exact same self.assertEqual(post_prune_git_lfs_files_size, git_lfs_files_size) class TestRepositoryOffline(RepositoryTestAbstract): """Class to test `Repository` object on 
local folders only (no cloning from Hub).""" repo: Repository @classmethod @expect_deprecation("Repository") def setUp(self) -> None: super().setUp() run_subprocess("git init", folder=self.repo_path) self.repo = Repository(self.repo_path, git_user="ci", git_email="ci@dummy.ci") git_attributes_path = self.repo_path / ".gitattributes" git_attributes_path.write_text("*.pt filter=lfs diff=lfs merge=lfs -text") self.repo.git_add(".gitattributes") self.repo.git_commit("Add .gitattributes") def test_is_tracked_with_lfs(self): txt_1 = self.repo_path / "small_file_1.txt" txt_2 = self.repo_path / "small_file_2.txt" pt_1 = self.repo_path / "model.pt" txt_1.write_text(self.small_content) txt_2.write_text(self.small_content) pt_1.write_text(self.small_content) self.repo.lfs_track("small_file_1.txt") self.assertTrue(is_tracked_with_lfs(txt_1)) self.assertFalse(is_tracked_with_lfs(txt_2)) self.assertTrue(pt_1) def test_is_tracked_with_lfs_with_pattern(self): txt_small_file = self.repo_path / "small_file.txt" txt_small_file.write_text(self.small_content) txt_large_file = self.repo_path / "large_file.txt" txt_large_file.write_text(self.large_content) (self.repo_path / "dir").mkdir() txt_small_file_in_dir = self.repo_path / "dir" / "small_file.txt" txt_small_file_in_dir.write_text(self.small_content) txt_large_file_in_dir = self.repo_path / "dir" / "large_file.txt" txt_large_file_in_dir.write_text(self.large_content) self.repo.auto_track_large_files("dir") self.assertFalse(is_tracked_with_lfs(txt_large_file)) self.assertFalse(is_tracked_with_lfs(txt_small_file)) self.assertTrue(is_tracked_with_lfs(txt_large_file_in_dir)) self.assertFalse(is_tracked_with_lfs(txt_small_file_in_dir)) def test_auto_track_large_files(self): txt_small_file = self.repo_path / "small_file.txt" txt_small_file.write_text(self.small_content) txt_large_file = self.repo_path / "large_file.txt" txt_large_file.write_text(self.large_content) self.repo.auto_track_large_files() self.assertTrue(is_tracked_with_lfs(txt_large_file)) self.assertFalse(is_tracked_with_lfs(txt_small_file)) def test_auto_track_binary_files(self): non_binary_file = self.repo_path / "non_binary_file.txt" non_binary_file.write_text(self.small_content) binary_file = self.repo_path / "binary_file.txt" binary_file.write_text(self.binary_content) self.repo.auto_track_binary_files() self.assertFalse(is_tracked_with_lfs(non_binary_file)) self.assertTrue(is_tracked_with_lfs(binary_file)) def test_auto_track_large_files_ignored_with_gitignore(self): (self.repo_path / "dir").mkdir() # Test nested gitignores gitignore_file = self.repo_path / ".gitignore" gitignore_file.write_text("large_file.txt") gitignore_file_in_dir = self.repo_path / "dir" / ".gitignore" gitignore_file_in_dir.write_text("large_file_3.txt") large_file = self.repo_path / "large_file.txt" large_file.write_text(self.large_content) large_file_2 = self.repo_path / "large_file_2.txt" large_file_2.write_text(self.large_content) large_file_3 = self.repo_path / "dir" / "large_file_3.txt" large_file_3.write_text(self.large_content) large_file_4 = self.repo_path / "dir" / "large_file_4.txt" large_file_4.write_text(self.large_content) self.repo.auto_track_large_files() # Large files self.assertFalse(is_tracked_with_lfs(large_file)) self.assertTrue(is_tracked_with_lfs(large_file_2)) self.assertFalse(is_tracked_with_lfs(large_file_3)) self.assertTrue(is_tracked_with_lfs(large_file_4)) def test_auto_track_binary_files_ignored_with_gitignore(self): (self.repo_path / "dir").mkdir() # Test nested gitignores gitignore_file = 
self.repo_path / ".gitignore" gitignore_file.write_text("binary_file.txt") gitignore_file_in_dir = self.repo_path / "dir" / ".gitignore" gitignore_file_in_dir.write_text("binary_file_3.txt") binary_file = self.repo_path / "binary_file.txt" binary_file.write_text(self.binary_content) binary_file_2 = self.repo_path / "binary_file_2.txt" binary_file_2.write_text(self.binary_content) binary_file_3 = self.repo_path / "dir" / "binary_file_3.txt" binary_file_3.write_text(self.binary_content) binary_file_4 = self.repo_path / "dir" / "binary_file_4.txt" binary_file_4.write_text(self.binary_content) self.repo.auto_track_binary_files() # Binary files self.assertFalse(is_tracked_with_lfs(binary_file)) self.assertTrue(is_tracked_with_lfs(binary_file_2)) self.assertFalse(is_tracked_with_lfs(binary_file_3)) self.assertTrue(is_tracked_with_lfs(binary_file_4)) def test_auto_track_large_files_through_git_add(self): txt_small_file = self.repo_path / "small_file.txt" txt_small_file.write_text(self.small_content) txt_large_file = self.repo_path / "large_file.txt" txt_large_file.write_text(self.large_content) self.repo.git_add(auto_lfs_track=True) self.assertTrue(is_tracked_with_lfs(txt_large_file)) self.assertFalse(is_tracked_with_lfs(txt_small_file)) def test_auto_track_binary_files_through_git_add(self): non_binary_file = self.repo_path / "small_file.txt" non_binary_file.write_text(self.small_content) binary_file = self.repo_path / "binary.txt" binary_file.write_text(self.binary_content) self.repo.git_add(auto_lfs_track=True) self.assertTrue(is_tracked_with_lfs(binary_file)) self.assertFalse(is_tracked_with_lfs(non_binary_file)) def test_auto_no_track_large_files_through_git_add(self): txt_small_file = self.repo_path / "small_file.txt" txt_small_file.write_text(self.small_content) txt_large_file = self.repo_path / "large_file.txt" txt_large_file.write_text(self.large_content) self.repo.git_add(auto_lfs_track=False) self.assertFalse(is_tracked_with_lfs(txt_large_file)) self.assertFalse(is_tracked_with_lfs(txt_small_file)) def test_auto_no_track_binary_files_through_git_add(self): non_binary_file = self.repo_path / "small_file.txt" non_binary_file.write_text(self.small_content) binary_file = self.repo_path / "binary.txt" binary_file.write_text(self.binary_content) self.repo.git_add(auto_lfs_track=False) self.assertFalse(is_tracked_with_lfs(binary_file)) self.assertFalse(is_tracked_with_lfs(non_binary_file)) def test_auto_track_updates_removed_gitattributes(self): txt_small_file = self.repo_path / "small_file.txt" txt_small_file.write_text(self.small_content) txt_large_file = self.repo_path / "large_file.txt" txt_large_file.write_text(self.large_content) self.repo.git_add(auto_lfs_track=True) self.assertTrue(is_tracked_with_lfs(txt_large_file)) self.assertFalse(is_tracked_with_lfs(txt_small_file)) # Remove large file txt_large_file.unlink() # Auto track should remove the entry from .gitattributes self.repo.auto_track_large_files() # Recreate the large file with smaller contents txt_large_file.write_text(self.small_content) # Ensure the file is not LFS tracked anymore self.repo.auto_track_large_files() self.assertFalse(is_tracked_with_lfs(txt_large_file)) def test_checkout_non_existing_branch(self): self.assertRaises(EnvironmentError, self.repo.git_checkout, "brand-new-branch") def test_checkout_new_branch(self): self.repo.git_checkout("new-branch", create_branch_ok=True) self.assertEqual(self.repo.current_branch, "new-branch") def test_is_not_tracked_upstream(self): self.repo.git_checkout("new-branch", 
create_branch_ok=True) self.assertFalse(is_tracked_upstream(self.repo.local_dir)) def test_no_branch_checked_out_raises(self): head_commit_ref = run_subprocess("git show --oneline -s", folder=self.repo_path).stdout.split()[0] self.repo.git_checkout(head_commit_ref) self.assertRaises(OSError, is_tracked_upstream, self.repo.local_dir) @expect_deprecation("Repository") def test_repo_init_checkout_default_revision(self): # Instantiate repository on a given revision repo = Repository(self.repo_path, revision="new-branch") self.assertEqual(repo.current_branch, "new-branch") # The revision should be kept when re-initializing the repo repo_2 = Repository(self.repo_path) self.assertEqual(repo_2.current_branch, "new-branch") @expect_deprecation("Repository") def test_repo_init_checkout_revision(self): current_head_hash = self.repo.git_head_hash() (self.repo_path / "file.txt").write_text("hello world") self.repo.git_add() self.repo.git_commit("Add file.txt") new_head_hash = self.repo.git_head_hash() self.assertNotEqual(current_head_hash, new_head_hash) previous_head_repo = Repository(self.repo_path, revision=current_head_hash) files = os.listdir(previous_head_repo.local_dir) self.assertNotIn("file.txt", files) current_head_repo = Repository(self.repo_path, revision=new_head_hash) files = os.listdir(current_head_repo.local_dir) self.assertIn("file.txt", files) @expect_deprecation("Repository") def test_repo_user(self): _ = Repository(self.repo_path, token=TOKEN) username = run_subprocess("git config user.name", folder=self.repo_path).stdout email = run_subprocess("git config user.email", folder=self.repo_path).stdout # hardcode values to avoid another api call to whoami self.assertEqual(username.strip(), "Dummy User") self.assertEqual(email.strip(), "julien@huggingface.co") @expect_deprecation("Repository") def test_repo_passed_user(self): _ = Repository(self.repo_path, token=TOKEN, git_user="RANDOM_USER", git_email="EMAIL@EMAIL.EMAIL") username = run_subprocess("git config user.name", folder=self.repo_path).stdout email = run_subprocess("git config user.email", folder=self.repo_path).stdout self.assertEqual(username.strip(), "RANDOM_USER") self.assertEqual(email.strip(), "EMAIL@EMAIL.EMAIL") def test_add_tag(self): self.repo.add_tag("v4.6.0") self.assertTrue(self.repo.tag_exists("v4.6.0")) def test_add_annotated_tag(self): self.repo.add_tag("v4.6.0", message="This is an annotated tag") self.assertTrue(self.repo.tag_exists("v4.6.0")) result = run_subprocess("git tag -n9", folder=self.repo_path).stdout.strip() self.assertIn("This is an annotated tag", result) def test_delete_tag(self): self.repo.add_tag("v4.6.0", message="This is an annotated tag") self.assertTrue(self.repo.tag_exists("v4.6.0")) self.repo.delete_tag("v4.6.0") self.assertFalse(self.repo.tag_exists("v4.6.0")) def test_repo_clean(self): self.assertTrue(self.repo.is_repo_clean()) (self.repo_path / "file.txt").write_text("hello world") self.assertFalse(self.repo.is_repo_clean()) class TestRepositoryDataset(RepositoryTestAbstract): """Class to test that cloning from a different repo_type works fine.""" @classmethod def setUpClass(cls): super().setUpClass() cls.repo_url = cls._api.create_repo(repo_id=repo_name(), repo_type="dataset") cls.repo_id = cls.repo_url.repo_id cls._api.upload_file( path_or_fileobj=cls.binary_content.encode(), path_in_repo="file.txt", repo_id=cls.repo_id, repo_type="dataset", ) @classmethod def tearDownClass(cls): super().tearDownClass() cls._api.delete_repo(repo_id=cls.repo_id, repo_type="dataset") 
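# Each test below clones the dataset repo in a different way (repo_type passed explicitly, inferred from the URL, or combined with a plain repo_id) and checks that the uploaded file.txt is present after cloning.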
@expect_deprecation("Repository") def test_clone_dataset_with_endpoint_explicit_repo_type(self): Repository( self.repo_path, clone_from=self.repo_url, repo_type="dataset", git_user="ci", git_email="ci@dummy.com" ) self.assertTrue((self.repo_path / "file.txt").exists()) @expect_deprecation("Repository") def test_clone_dataset_with_endpoint_implicit_repo_type(self): self.assertIn("dataset", self.repo_url) # Implicit Repository(self.repo_path, clone_from=self.repo_url, git_user="ci", git_email="ci@dummy.com") self.assertTrue((self.repo_path / "file.txt").exists()) @expect_deprecation("Repository") def test_clone_dataset_with_repo_id_and_repo_type(self): Repository( self.repo_path, clone_from=self.repo_id, repo_type="dataset", git_user="ci", git_email="ci@dummy.com" ) self.assertTrue((self.repo_path / "file.txt").exists()) @expect_deprecation("Repository") def test_clone_dataset_no_ci_user_and_email(self): Repository(self.repo_path, clone_from=self.repo_id, repo_type="dataset") self.assertTrue((self.repo_path / "file.txt").exists()) @expect_deprecation("Repository") def test_clone_dataset_with_repo_name_and_repo_type_fails(self): with self.assertRaises(EnvironmentError): Repository( self.repo_path, clone_from=self.repo_id.split("/")[1], repo_type="dataset", token=TOKEN, git_user="ci", git_email="ci@dummy.com", ) huggingface_hub-0.31.1/tests/test_serialization.py000066400000000000000000000723731500667546600223640ustar00rootroot00000000000000import json import struct from pathlib import Path from typing import TYPE_CHECKING, Dict, List from unittest.mock import Mock import pytest from pytest_mock import MockerFixture from huggingface_hub import constants from huggingface_hub.serialization import ( get_tf_storage_size, get_torch_storage_size, load_state_dict_from_file, load_torch_model, save_torch_model, save_torch_state_dict, split_state_dict_into_shards_factory, split_torch_state_dict_into_shards, ) from huggingface_hub.serialization._base import parse_size_to_int from huggingface_hub.serialization._torch import _load_sharded_checkpoint from .testing_utils import requires if TYPE_CHECKING: import torch def _dummy_get_storage_id(item): return None def _dummy_get_storage_size(item): return sum(item) # util functions for checking the version for pytorch def is_wrapper_tensor_subclass_available(): try: from torch.utils._python_dispatch import is_traceable_wrapper_subclass # type: ignore[import] # noqa: F401 return True except ImportError: return False @pytest.fixture def dummy_state_dict() -> Dict[str, List[int]]: return { "layer_1": [6], "layer_2": [10], "layer_3": [30], "layer_4": [2], "layer_5": [2], } @pytest.fixture def torch_state_dict() -> Dict[str, "torch.Tensor"]: try: import torch return { "layer_1": torch.tensor([4]), "layer_2": torch.tensor([10]), "layer_3": torch.tensor([30]), "layer_4": torch.tensor([2]), "layer_5": torch.tensor([2]), } except ImportError: pytest.skip("torch is not available") @pytest.fixture def dummy_model(): try: import torch class DummyModel(torch.nn.Module): """Simple model for testing that matches the state dict `torch_state_dict` fixture.""" def __init__(self): super().__init__() self.register_parameter("layer_1", torch.nn.Parameter(torch.tensor([4.0]))) self.register_parameter("layer_2", torch.nn.Parameter(torch.tensor([10.0]))) self.register_parameter("layer_3", torch.nn.Parameter(torch.tensor([30.0]))) self.register_parameter("layer_4", torch.nn.Parameter(torch.tensor([2.0]))) self.register_parameter("layer_5", torch.nn.Parameter(torch.tensor([2.0]))) return 
DummyModel() except ImportError: pytest.skip("torch is not available") @pytest.fixture def torch_state_dict_tensor_subclass() -> Dict[str, "torch.Tensor"]: try: import torch # type: ignore[import] from torch.testing._internal.two_tensor import TwoTensor # type: ignore[import] t = torch.tensor([4]) return { "layer_1": torch.tensor([4]), "layer_2": torch.tensor([10]), "layer_3": torch.tensor([30]), "layer_4": torch.tensor([2]), "layer_5": torch.tensor([2]), "layer_6": TwoTensor(t, t), } except ImportError: pytest.skip("torch is not available") @pytest.fixture def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]: try: import torch # type: ignore[import] shared_layer = torch.tensor([4]) return { "shared_1": shared_layer, "unique_1": torch.tensor([10]), "unique_2": torch.tensor([30]), "shared_2": shared_layer, } except ImportError: pytest.skip("torch is not available") @pytest.fixture def torch_state_dict_shared_layers_tensor_subclass() -> Dict[str, "torch.Tensor"]: try: import torch # type: ignore[import] from torch.testing._internal.two_tensor import TwoTensor # type: ignore[import] t = torch.tensor([4]) tensor_subclass_tensor = TwoTensor(t, t) t = torch.tensor([4]) shared_tensor_subclass_tensor = TwoTensor(t, t) return { "layer_1": torch.tensor([4]), "layer_2": torch.tensor([10]), "layer_3": torch.tensor([30]), "layer_4": torch.tensor([2]), "layer_5": torch.tensor([2]), "layer_6": tensor_subclass_tensor, "ts_shared_1": shared_tensor_subclass_tensor, "ts_shared_2": shared_tensor_subclass_tensor, } except ImportError: pytest.skip("torch is not available") def test_single_shard(dummy_state_dict): state_dict_split = split_state_dict_into_shards_factory( dummy_state_dict, get_storage_id=_dummy_get_storage_id, get_storage_size=_dummy_get_storage_size, max_shard_size=100, # large shard size => only one shard filename_pattern="file{suffix}.dummy", ) assert not state_dict_split.is_sharded assert state_dict_split.filename_to_tensors == { # All layers fit in one shard => no suffix in filename "file.dummy": ["layer_1", "layer_2", "layer_3", "layer_4", "layer_5"], } assert state_dict_split.tensor_to_filename == { "layer_1": "file.dummy", "layer_2": "file.dummy", "layer_3": "file.dummy", "layer_4": "file.dummy", "layer_5": "file.dummy", } assert state_dict_split.metadata == {"total_size": 50} def test_multiple_shards(dummy_state_dict): state_dict_split = split_state_dict_into_shards_factory( dummy_state_dict, get_storage_id=_dummy_get_storage_id, get_storage_size=_dummy_get_storage_size, max_shard_size=10, # small shard size => multiple shards filename_pattern="file{suffix}.dummy", ) assert state_dict_split.is_sharded assert state_dict_split.filename_to_tensors == { # layer 4 and 5 could go in this one but assignment is not optimal, and it's fine "file-00001-of-00004.dummy": ["layer_1"], "file-00002-of-00004.dummy": ["layer_3"], "file-00003-of-00004.dummy": ["layer_2"], "file-00004-of-00004.dummy": ["layer_4", "layer_5"], } assert state_dict_split.tensor_to_filename == { "layer_1": "file-00001-of-00004.dummy", "layer_3": "file-00002-of-00004.dummy", "layer_2": "file-00003-of-00004.dummy", "layer_4": "file-00004-of-00004.dummy", "layer_5": "file-00004-of-00004.dummy", } assert state_dict_split.metadata == {"total_size": 50} def test_tensor_same_storage(): state_dict_split = split_state_dict_into_shards_factory( { "layer_1": [1], "layer_2": [2], "layer_3": [1], "layer_4": [2], "layer_5": [1], }, get_storage_id=lambda x: (x[0]), # dummy for test: storage id based on first element 
get_storage_size=_dummy_get_storage_size, max_shard_size=1, filename_pattern="model{suffix}.safetensors", ) assert state_dict_split.is_sharded assert state_dict_split.filename_to_tensors == { "model-00001-of-00002.safetensors": ["layer_2", "layer_4"], "model-00002-of-00002.safetensors": ["layer_1", "layer_3", "layer_5"], } assert state_dict_split.tensor_to_filename == { "layer_1": "model-00002-of-00002.safetensors", "layer_2": "model-00001-of-00002.safetensors", "layer_3": "model-00002-of-00002.safetensors", "layer_4": "model-00001-of-00002.safetensors", "layer_5": "model-00002-of-00002.safetensors", } assert state_dict_split.metadata == {"total_size": 3} # count them once @requires("tensorflow") def test_get_tf_storage_size(): import tensorflow as tf # type: ignore[import] assert get_tf_storage_size(tf.constant([1, 2, 3, 4, 5], dtype=tf.float64)) == 5 * 8 assert get_tf_storage_size(tf.constant([1, 2, 3, 4, 5], dtype=tf.float16)) == 5 * 2 @requires("torch") def test_get_torch_storage_size(): import torch # type: ignore[import] assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)) == 5 * 8 assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2 @requires("torch") @pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher") def test_get_torch_storage_size_wrapper_tensor_subclass(): import torch # type: ignore[import] from torch.testing._internal.two_tensor import TwoTensor # type: ignore[import] t = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64) assert get_torch_storage_size(TwoTensor(t, t)) == 5 * 8 * 2 t = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16) assert get_torch_storage_size(TwoTensor(t, TwoTensor(t, t))) == 5 * 2 * 3 def test_parse_size_to_int(): assert parse_size_to_int("1KB") == 1 * 10**3 assert parse_size_to_int("2MB") == 2 * 10**6 assert parse_size_to_int("3GB") == 3 * 10**9 assert parse_size_to_int(" 10 KB ") == 10 * 10**3 # ok with whitespace assert parse_size_to_int("20mb") == 20 * 10**6 # ok with lowercase with pytest.raises(ValueError, match="Unit 'IB' not supported"): parse_size_to_int("1KiB") # not a valid unit with pytest.raises(ValueError, match="Could not parse the size value"): parse_size_to_int("1ooKB") # not a float def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None: """Test `save_torch_model` is only a wrapper around `save_torch_state_dict`.""" model_mock = Mock() safe_state_dict_mock = mocker.patch("huggingface_hub.serialization._torch.save_torch_state_dict") save_torch_model( model_mock, save_directory=tmp_path, filename_pattern="my-pattern", force_contiguous=True, max_shard_size="3GB", metadata={"foo": "bar"}, safe_serialization=True, is_main_process=True, shared_tensors_to_discard=None, ) safe_state_dict_mock.assert_called_once_with( state_dict=model_mock.state_dict.return_value, save_directory=tmp_path, filename_pattern="my-pattern", force_contiguous=True, max_shard_size="3GB", metadata={"foo": "bar"}, safe_serialization=True, is_main_process=True, shared_tensors_to_discard=None, ) def test_save_torch_state_dict_not_sharded(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None: """Save as safetensors without sharding.""" save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size="1GB") assert (tmp_path / "model.safetensors").is_file() assert not (tmp_path / "model.safetensors.index.json").is_file() def test_save_torch_state_dict_sharded(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None: 
"""Save as safetensors with sharding.""" save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size=30) assert not (tmp_path / "model.safetensors").is_file() assert (tmp_path / "model.safetensors.index.json").is_file() assert (tmp_path / "model-00001-of-00002.safetensors").is_file() assert (tmp_path / "model-00001-of-00002.safetensors").is_file() assert json.loads((tmp_path / "model.safetensors.index.json").read_text("utf-8")) == { "metadata": {"total_size": 40}, "weight_map": { "layer_1": "model-00001-of-00002.safetensors", "layer_2": "model-00001-of-00002.safetensors", "layer_3": "model-00001-of-00002.safetensors", "layer_4": "model-00002-of-00002.safetensors", "layer_5": "model-00002-of-00002.safetensors", }, } def test_save_torch_state_dict_unsafe_not_sharded( tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: Dict[str, "torch.Tensor"] ) -> None: """Save as pickle without sharding.""" with caplog.at_level("WARNING"): save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size="1GB", safe_serialization=False) assert "we strongly recommend using safe serialization" in caplog.text assert (tmp_path / "pytorch_model.bin").is_file() assert not (tmp_path / "pytorch_model.bin.index.json").is_file() @pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher") def test_save_torch_state_dict_tensor_subclass_unsafe_not_sharded( tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict_tensor_subclass: Dict[str, "torch.Tensor"] ) -> None: """Save as pickle without sharding.""" with caplog.at_level("WARNING"): save_torch_state_dict( torch_state_dict_tensor_subclass, tmp_path, max_shard_size="1GB", safe_serialization=False ) assert "we strongly recommend using safe serialization" in caplog.text assert (tmp_path / "pytorch_model.bin").is_file() assert not (tmp_path / "pytorch_model.bin.index.json").is_file() @pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher") def test_save_torch_state_dict_shared_layers_tensor_subclass_unsafe_not_sharded( tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict_shared_layers_tensor_subclass: Dict[str, "torch.Tensor"], ) -> None: """Save as pickle without sharding.""" with caplog.at_level("WARNING"): save_torch_state_dict( torch_state_dict_shared_layers_tensor_subclass, tmp_path, max_shard_size="1GB", safe_serialization=False ) assert "we strongly recommend using safe serialization" in caplog.text assert (tmp_path / "pytorch_model.bin").is_file() assert not (tmp_path / "pytorch_model.bin.index.json").is_file() def test_save_torch_state_dict_unsafe_sharded( tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: Dict[str, "torch.Tensor"] ) -> None: """Save as pickle with sharding.""" # Check logs with caplog.at_level("WARNING"): save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size=30, safe_serialization=False) assert "we strongly recommend using safe serialization" in caplog.text assert not (tmp_path / "pytorch_model.bin").is_file() assert (tmp_path / "pytorch_model.bin.index.json").is_file() assert (tmp_path / "pytorch_model-00001-of-00002.bin").is_file() assert (tmp_path / "pytorch_model-00001-of-00002.bin").is_file() assert json.loads((tmp_path / "pytorch_model.bin.index.json").read_text("utf-8")) == { "metadata": {"total_size": 40}, "weight_map": { "layer_1": "pytorch_model-00001-of-00002.bin", "layer_2": "pytorch_model-00001-of-00002.bin", "layer_3": "pytorch_model-00001-of-00002.bin", "layer_4": 
"pytorch_model-00002-of-00002.bin", "layer_5": "pytorch_model-00002-of-00002.bin", }, } def test_save_torch_state_dict_shared_layers_not_sharded( tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"] ) -> None: from safetensors.torch import load_file save_torch_state_dict(torch_state_dict_shared_layers, tmp_path, safe_serialization=True) safetensors_file = tmp_path / "model.safetensors" assert safetensors_file.is_file() # Check shared layer not duplicated in file state_dict = load_file(safetensors_file) assert "shared_1" in state_dict assert "shared_2" not in state_dict # Check shared layer info in metadata file_bytes = safetensors_file.read_bytes() metadata_str = file_bytes[ 8 : struct.unpack(" None: from safetensors.torch import load_file save_torch_state_dict(torch_state_dict_shared_layers, tmp_path, max_shard_size=2, safe_serialization=True) index_file = tmp_path / "model.safetensors.index.json" assert index_file.is_file() # Check shared layer info in index metadata index = json.loads(index_file.read_text()) assert index["metadata"]["shared_2"] == "shared_1" # Check shared layer not duplicated in files for filename in index["weight_map"].values(): state_dict = load_file(tmp_path / filename) assert "shared_2" not in state_dict def test_save_torch_state_dict_discard_selected_sharded( tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"] ) -> None: from safetensors.torch import load_file save_torch_state_dict( torch_state_dict_shared_layers, tmp_path, max_shard_size=2, safe_serialization=True, shared_tensors_to_discard=["shared_1"], ) index_file = tmp_path / "model.safetensors.index.json" index = json.loads(index_file.read_text()) assert index["metadata"]["shared_1"] == "shared_2" for filename in index["weight_map"].values(): state_dict = load_file(tmp_path / filename) assert "shared_1" not in state_dict def test_save_torch_state_dict_discard_selected_not_sharded( tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"] ) -> None: from safetensors.torch import load_file save_torch_state_dict( torch_state_dict_shared_layers, tmp_path, safe_serialization=True, shared_tensors_to_discard=["shared_1"], ) safetensors_file = tmp_path / "model.safetensors" assert safetensors_file.is_file() # Check shared layer not duplicated in file state_dict = load_file(safetensors_file) assert "shared_1" not in state_dict assert "shared_2" in state_dict # Check shared layer info in metadata file_bytes = safetensors_file.read_bytes() metadata_str = file_bytes[ 8 : struct.unpack(" None: """Custom filename pattern is respected.""" # Not sharded save_torch_state_dict(torch_state_dict, tmp_path, filename_pattern="model.variant{suffix}.safetensors") assert (tmp_path / "model.variant.safetensors").is_file() # Sharded save_torch_state_dict( torch_state_dict, tmp_path, filename_pattern="model.variant{suffix}.safetensors", max_shard_size=30 ) assert (tmp_path / "model.variant.safetensors.index.json").is_file() assert (tmp_path / "model.variant-00001-of-00002.safetensors").is_file() assert (tmp_path / "model.variant-00002-of-00002.safetensors").is_file() def test_save_torch_state_dict_delete_existing_files( tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"] ) -> None: """Directory is cleaned before saving new files.""" (tmp_path / "model.safetensors").touch() (tmp_path / "model.safetensors.index.json").touch() (tmp_path / "model-00001-of-00003.safetensors").touch() (tmp_path / "model-00002-of-00003.safetensors").touch() (tmp_path / 
"model-00003-of-00003.safetensors").touch() (tmp_path / "pytorch_model.bin").touch() (tmp_path / "pytorch_model.bin.index.json").touch() (tmp_path / "pytorch_model-00001-of-00003.bin").touch() (tmp_path / "pytorch_model-00002-of-00003.bin").touch() (tmp_path / "pytorch_model-00003-of-00003.bin").touch() save_torch_state_dict(torch_state_dict, tmp_path) assert (tmp_path / "model.safetensors").stat().st_size > 0 # new file # Previous shards have been deleted assert not (tmp_path / "model.safetensors.index.json").is_file() # deleted assert not (tmp_path / "model-00001-of-00003.safetensors").is_file() # deleted assert not (tmp_path / "model-00002-of-00003.safetensors").is_file() # deleted assert not (tmp_path / "model-00003-of-00003.safetensors").is_file() # deleted # But not previous pickle files (since saving as safetensors) assert (tmp_path / "pytorch_model.bin").is_file() # not deleted assert (tmp_path / "pytorch_model.bin.index.json").is_file() assert (tmp_path / "pytorch_model-00001-of-00003.bin").is_file() assert (tmp_path / "pytorch_model-00002-of-00003.bin").is_file() assert (tmp_path / "pytorch_model-00003-of-00003.bin").is_file() def test_save_torch_state_dict_not_main_process( tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"], ) -> None: """ Test that previous files in the directory are not deleted when is_main_process=False. When is_main_process=True, previous files should be deleted, this is already tested in `test_save_torch_state_dict_delete_existing_files`. """ # Create some .safetensors files before saving a new state dict. (tmp_path / "model.safetensors").touch() (tmp_path / "model-00001-of-00002.safetensors").touch() (tmp_path / "model-00002-of-00002.safetensors").touch() (tmp_path / "model.safetensors.index.json").touch() # Save with is_main_process=False save_torch_state_dict(torch_state_dict, tmp_path, is_main_process=False) # Previous files should still exist (not deleted) assert (tmp_path / "model.safetensors").is_file() assert (tmp_path / "model-00001-of-00002.safetensors").is_file() assert (tmp_path / "model-00002-of-00002.safetensors").is_file() assert (tmp_path / "model.safetensors.index.json").is_file() @requires("torch") def test_load_state_dict_from_file(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]): """Test saving and loading a state dict with both safetensors and pickle formats.""" import torch # type: ignore[import] # Test safetensors format (default) save_torch_state_dict(torch_state_dict, tmp_path) loaded_dict = load_state_dict_from_file(tmp_path / "model.safetensors") assert isinstance(loaded_dict, dict) assert set(loaded_dict.keys()) == set(torch_state_dict.keys()) for key in torch_state_dict: assert torch.equal(loaded_dict[key], torch_state_dict[key]) # Test PyTorch pickle format save_torch_state_dict(torch_state_dict, tmp_path, safe_serialization=False) loaded_dict = load_state_dict_from_file(tmp_path / "pytorch_model.bin") assert isinstance(loaded_dict, dict) assert set(loaded_dict.keys()) == set(torch_state_dict.keys()) for key in torch_state_dict: assert torch.equal(loaded_dict[key], torch_state_dict[key]) @requires("torch") def test_load_sharded_state_dict( tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"], dummy_model: "torch.nn.Module", ): """Test saving and loading a sharded state dict.""" import torch save_torch_state_dict( torch_state_dict, save_directory=tmp_path, max_shard_size=30, # Small size to force sharding ) # Verify sharding occurred index_file = tmp_path / "model.safetensors.index.json" assert 
index_file.exists() # Load and verify content result = _load_sharded_checkpoint(dummy_model, tmp_path) assert not result.missing_keys assert not result.unexpected_keys # Verify tensor values loaded_state_dict = dummy_model.state_dict() for key in torch_state_dict: assert torch.equal(loaded_state_dict[key], torch_state_dict[key]) @requires("torch") def test_load_from_directory_not_sharded( tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"], dummy_model: "torch.nn.Module" ): import torch save_torch_state_dict(torch_state_dict, save_directory=tmp_path) # Verify no sharding occurred index_file = tmp_path / "model.safetensors.index.json" assert not index_file.exists() result = load_torch_model(dummy_model, tmp_path) assert not result.missing_keys assert not result.unexpected_keys loaded_state_dict = dummy_model.state_dict() for key in torch_state_dict: assert torch.equal(loaded_state_dict[key], torch_state_dict[key]) @pytest.mark.parametrize("safe_serialization", [True, False]) def test_load_state_dict_missing_file(safe_serialization): """Test proper error handling when file is missing.""" with pytest.raises(FileNotFoundError, match="No checkpoint file found"): load_state_dict_from_file( "nonexistent.safetensors" if safe_serialization else "nonexistent.bin", weights_only=False, ) def test_load_torch_model_directory_does_not_exist(): """Test proper error handling when directory does not contain a valid checkpoint.""" with pytest.raises(ValueError, match="Checkpoint path does_not_exist does not exist"): load_torch_model(Mock(), "does_not_exist") def test_load_torch_model_directory_does_not_contain_checkpoint(tmp_path): """Test proper error handling when directory does not contain a valid checkpoint.""" with pytest.raises(ValueError, match=r"Directory .* does not contain a valid checkpoint."): load_torch_model(Mock(), tmp_path) @pytest.mark.parametrize( "strict", [ True, False, ], ) def test_load_sharded_model_strict_mode(tmp_path, torch_state_dict, dummy_model, strict): """Test loading model with strict mode behavior for both sharded and non-sharded checkpoints.""" import torch # Add an extra key to the state dict modified_dict = {**torch_state_dict, "extra_key": torch.tensor([1.0])} # Save checkpoint save_torch_state_dict( modified_dict, save_directory=tmp_path, max_shard_size=30, ) if strict: with pytest.raises(RuntimeError, match=".*Unexpected key.*"): result = load_torch_model( model=dummy_model, checkpoint_path=tmp_path, strict=strict, ) else: result = load_torch_model( model=dummy_model, checkpoint_path=tmp_path, strict=strict, ) assert "extra_key" in result.unexpected_keys def test_load_torch_model_with_filename_pattern(tmp_path, torch_state_dict, dummy_model): """Test loading a model with a custom filename pattern.""" import torch save_torch_state_dict( torch_state_dict, save_directory=tmp_path, filename_pattern="model.variant{suffix}.safetensors", ) result = load_torch_model( dummy_model, tmp_path, filename_pattern="model.variant{suffix}.safetensors", ) assert not result.missing_keys assert not result.unexpected_keys loaded_state_dict = dummy_model.state_dict() for key in torch_state_dict: assert torch.equal(loaded_state_dict[key], torch_state_dict[key]) @pytest.mark.parametrize( "filename_pattern, safe, files_exist, expected_filename_pattern", [ ( None, True, ["model.safetensors.index.json"], constants.SAFETENSORS_WEIGHTS_FILE_PATTERN, ), # safetensors exists and safe=True -> load safetensors ( None, False, ["pytorch_model.bin.index.json"], 
constants.PYTORCH_WEIGHTS_FILE_PATTERN, ), # only pickle file exists and safe=False -> load pickle files ( None, False, ["model.safetensors.index.json", "pytorch_model.bin.index.json"], constants.SAFETENSORS_WEIGHTS_FILE_PATTERN, ), # both exist and safe=False -> load safetensors ( "model.variant{suffix}.safetensors", False, ["model.variant.safetensors.index.json", "pytorch_model.bin.index.json"], "model.variant{suffix}.safetensors", ), # both exist and safe=False -> load safetensors # `filename_pattern` takes precedence over `safe` parameter ( "model.variant{suffix}.bin", False, ["model.variant.safetensors.index.json", "model.variant.bin.index.json"], "model.variant{suffix}.bin", ), # custom filename pattern and safe=False -> load custom file index ( "model.variant{suffix}.bin", True, ["model.variant.safetensors.index.json", "model.variant.bin.index.json"], "model.variant{suffix}.bin", ), # custom filename pattern and safe=True -> load custom file index ], ) @requires("torch") def test_load_torch_model_index_selection( tmp_path: Path, filename_pattern, safe, files_exist, expected_filename_pattern, mocker, ): """Test the logic for selecting between safetensors and pytorch index files.""" import torch class SimpleModel(torch.nn.Module): def __init__(self): super().__init__() self.layer_1 = torch.nn.Parameter(torch.tensor([0.0])) model = SimpleModel() # Create specified index files for filename in files_exist: (tmp_path / filename).touch() # Mock _load_sharded_checkpoint to capture the selected filename pattern mock_load = mocker.patch("huggingface_hub.serialization._torch._load_sharded_checkpoint") load_torch_model(model, tmp_path, safe=safe, filename_pattern=filename_pattern) mock_load.assert_called_once() assert mock_load.call_args.kwargs["filename_pattern"] == expected_filename_pattern huggingface_hub-0.31.1/tests/test_snapshot_download.py000066400000000000000000000311631500667546600232250ustar00rootroot00000000000000import os import unittest from pathlib import Path from unittest.mock import patch from huggingface_hub import CommitOperationAdd, HfApi, snapshot_download from huggingface_hub.errors import LocalEntryNotFoundError, RepositoryNotFoundError from huggingface_hub.utils import SoftTemporaryDirectory from .testing_constants import TOKEN from .testing_utils import OfflineSimulationMode, offline, repo_name class SnapshotDownloadTests(unittest.TestCase): @classmethod def setUpClass(cls): """ Share this valid token in all tests below.
""" cls.api = HfApi(token=TOKEN) cls.repo_id = cls.api.create_repo(repo_name("snapshot-download")).repo_id # First commit on `main` cls.first_commit_hash = cls.api.create_commit( repo_id=cls.repo_id, operations=[ CommitOperationAdd(path_in_repo="dummy_file.txt", path_or_fileobj=b"v1"), CommitOperationAdd(path_in_repo="subpath/file.txt", path_or_fileobj=b"content in subpath"), ], commit_message="Add file to main branch", ).oid # Second commit on `main` cls.second_commit_hash = cls.api.create_commit( repo_id=cls.repo_id, operations=[ CommitOperationAdd(path_in_repo="dummy_file.txt", path_or_fileobj=b"v2"), CommitOperationAdd(path_in_repo="dummy_file_2.txt", path_or_fileobj=b"v3"), ], commit_message="Add file to main branch", ).oid # Third commit on `other` cls.api.create_branch(repo_id=cls.repo_id, branch="other") cls.third_commit_hash = cls.api.create_commit( repo_id=cls.repo_id, operations=[ CommitOperationAdd(path_in_repo="dummy_file_2.txt", path_or_fileobj=b"v4"), ], commit_message="Add file to other branch", revision="other", ).oid @classmethod def tearDownClass(cls) -> None: cls.api.delete_repo(repo_id=cls.repo_id) def test_download_model(self): # Test `main` branch with SoftTemporaryDirectory() as tmpdir: storage_folder = snapshot_download(self.repo_id, revision="main", cache_dir=tmpdir) # folder contains the two files contributed and the .gitattributes folder_contents = os.listdir(storage_folder) self.assertEqual(len(folder_contents), 4) self.assertTrue("dummy_file.txt" in folder_contents) self.assertTrue("dummy_file_2.txt" in folder_contents) self.assertTrue(".gitattributes" in folder_contents) with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: contents = f.read() self.assertEqual(contents, "v2") # folder name contains the revision's commit sha. self.assertTrue(self.second_commit_hash in storage_folder) # Test with specific revision with SoftTemporaryDirectory() as tmpdir: storage_folder = snapshot_download( self.repo_id, revision=self.first_commit_hash, cache_dir=tmpdir, ) # folder contains the two files contributed and the .gitattributes folder_contents = os.listdir(storage_folder) self.assertEqual(len(folder_contents), 3) self.assertTrue("dummy_file.txt" in folder_contents) self.assertTrue(".gitattributes" in folder_contents) with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: contents = f.read() self.assertEqual(contents, "v1") # folder name contains the revision's commit sha. 
self.assertTrue(self.first_commit_hash in storage_folder) def test_download_private_model(self): self.api.update_repo_settings(repo_id=self.repo_id, private=True) # Test download fails without token with SoftTemporaryDirectory() as tmpdir: with self.assertRaises(RepositoryNotFoundError): _ = snapshot_download(self.repo_id, revision="main", cache_dir=tmpdir) # Test we can download with token from cache with patch("huggingface_hub.utils._headers.get_token", return_value=TOKEN): with SoftTemporaryDirectory() as tmpdir: storage_folder = snapshot_download(self.repo_id, revision="main", cache_dir=tmpdir) self.assertTrue(self.second_commit_hash in storage_folder) # Test we can download with explicit token with SoftTemporaryDirectory() as tmpdir: storage_folder = snapshot_download(self.repo_id, revision="main", cache_dir=tmpdir, token=TOKEN) self.assertTrue(self.second_commit_hash in storage_folder) self.api.update_repo_settings(repo_id=self.repo_id, private=False) def test_download_model_local_only(self): # Test no branch specified with SoftTemporaryDirectory() as tmpdir: # first download folder to cache it snapshot_download(self.repo_id, cache_dir=tmpdir) # now load from cache storage_folder = snapshot_download(self.repo_id, cache_dir=tmpdir, local_files_only=True) self.assertTrue(self.second_commit_hash in storage_folder) # has expected revision # Test with specific revision branch with SoftTemporaryDirectory() as tmpdir: # first download folder to cache it snapshot_download(self.repo_id, revision="other", cache_dir=tmpdir) # now load from cache storage_folder = snapshot_download(self.repo_id, revision="other", cache_dir=tmpdir, local_files_only=True) self.assertTrue(self.third_commit_hash in storage_folder) # has expected revision # Test with specific revision hash with SoftTemporaryDirectory() as tmpdir: # first download folder to cache it snapshot_download(self.repo_id, revision=self.first_commit_hash, cache_dir=tmpdir) # now load from cache storage_folder = snapshot_download( self.repo_id, revision=self.first_commit_hash, cache_dir=tmpdir, local_files_only=True ) self.assertTrue(self.first_commit_hash in storage_folder) # has expected revision # Test with local_dir with SoftTemporaryDirectory() as tmpdir: # first download folder to local_dir snapshot_download(self.repo_id, local_dir=tmpdir) # now load from local_dir storage_folder = snapshot_download(self.repo_id, local_dir=tmpdir, local_files_only=True) self.assertEqual(str(tmpdir), storage_folder) def test_download_model_to_local_dir_with_offline_mode(self): """Test that an already downloaded folder is returned when there is a connection error""" # first download folder to local_dir with SoftTemporaryDirectory() as tmpdir: snapshot_download(self.repo_id, local_dir=tmpdir) # Check that the folder is returned when there is a connection error for offline_mode in OfflineSimulationMode: with offline(mode=offline_mode): storage_folder = snapshot_download(self.repo_id, local_dir=tmpdir) self.assertEqual(str(tmpdir), storage_folder) def test_offline_mode_with_cache_and_empty_local_dir(self): """Test that when cache exists but an empty local_dir is specified in offline mode, we raise an error.""" with SoftTemporaryDirectory() as tmpdir_cache: snapshot_download(self.repo_id, cache_dir=tmpdir_cache) for offline_mode in OfflineSimulationMode: with offline(mode=offline_mode): with self.assertRaises(LocalEntryNotFoundError): with SoftTemporaryDirectory() as tmpdir: snapshot_download(self.repo_id, cache_dir=tmpdir_cache, local_dir=tmpdir) def 
test_download_model_offline_mode_not_in_local_dir(self): """Test when connection error but local_dir is empty.""" with SoftTemporaryDirectory() as tmpdir: with self.assertRaises(LocalEntryNotFoundError): snapshot_download(self.repo_id, local_dir=tmpdir, local_files_only=True) for offline_mode in OfflineSimulationMode: with offline(mode=offline_mode): with SoftTemporaryDirectory() as tmpdir: with self.assertRaises(LocalEntryNotFoundError): snapshot_download(self.repo_id, local_dir=tmpdir) def test_download_model_offline_mode_not_cached(self): """Test when connection error but cache is empty.""" with SoftTemporaryDirectory() as tmpdir: with self.assertRaises(LocalEntryNotFoundError): snapshot_download(self.repo_id, cache_dir=tmpdir, local_files_only=True) for offline_mode in OfflineSimulationMode: with offline(mode=offline_mode): with SoftTemporaryDirectory() as tmpdir: with self.assertRaises(LocalEntryNotFoundError): snapshot_download(self.repo_id, cache_dir=tmpdir) def test_download_model_local_only_multiple(self): # cache multiple commits and make sure correct commit is taken with SoftTemporaryDirectory() as tmpdir: # download folder from main and other to cache it snapshot_download(self.repo_id, cache_dir=tmpdir) snapshot_download(self.repo_id, revision="other", cache_dir=tmpdir) # now make sure that loading "main" branch gives correct branch # folder name contains the 2nd commit sha and not the 3rd storage_folder = snapshot_download(self.repo_id, cache_dir=tmpdir, local_files_only=True) self.assertTrue(self.second_commit_hash in storage_folder) def check_download_model_with_pattern(self, pattern, allow=True): # Test `main` branch allow_patterns = pattern if allow else None ignore_patterns = pattern if not allow else None with SoftTemporaryDirectory() as tmpdir: storage_folder = snapshot_download( self.repo_id, revision="main", cache_dir=tmpdir, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, ) # folder contains the three text files but not the .gitattributes folder_contents = os.listdir(storage_folder) self.assertEqual(len(folder_contents), 3) self.assertTrue("dummy_file.txt" in folder_contents) self.assertTrue("dummy_file_2.txt" in folder_contents) self.assertTrue(".gitattributes" not in folder_contents) def test_download_model_with_allow_pattern(self): self.check_download_model_with_pattern("*.txt") def test_download_model_with_allow_pattern_list(self): self.check_download_model_with_pattern(["dummy_file.txt", "dummy_file_2.txt", "subpath/*"]) def test_download_model_with_ignore_pattern(self): self.check_download_model_with_pattern(".gitattributes", allow=False) def test_download_model_with_ignore_pattern_list(self): self.check_download_model_with_pattern(["*.git*", "*.pt"], allow=False) def test_download_to_local_dir(self) -> None: """Download a repository to local dir. Cache dir is not used. Symlinks are not used. This test is here to check once the normal behavior with snapshot_download. More individual tests exists in `test_file_download.py`. 
""" with SoftTemporaryDirectory() as cache_dir: with SoftTemporaryDirectory() as local_dir: returned_path = snapshot_download(self.repo_id, cache_dir=cache_dir, local_dir=local_dir) # Files have been downloaded in correct structure assert (Path(local_dir) / "dummy_file.txt").is_file() assert (Path(local_dir) / "dummy_file_2.txt").is_file() assert (Path(local_dir) / "subpath" / "file.txt").is_file() # Symlinks are not used anymore assert not (Path(local_dir) / "dummy_file.txt").is_symlink() assert not (Path(local_dir) / "dummy_file_2.txt").is_symlink() assert not (Path(local_dir) / "subpath" / "file.txt").is_symlink() # Check returns local dir and not cache dir assert Path(returned_path).resolve() == Path(local_dir).resolve() # Nothing has been added to cache dir (except some subfolders created) for path in cache_dir.glob("*"): assert path.is_dir() huggingface_hub-0.31.1/tests/test_testing_configuration.py000066400000000000000000000002541500667546600241000ustar00rootroot00000000000000from huggingface_hub import get_token def test_no_token_in_staging_environment(): """Make sure no token is set in test environment.""" assert get_token() is None huggingface_hub-0.31.1/tests/test_tf_import.py000066400000000000000000000012661500667546600215030ustar00rootroot00000000000000import sys import unittest from huggingface_hub.utils import is_tf_available def require_tf(test_case): """ Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow is not installed. """ if not is_tf_available(): return unittest.skip("test requires Tensorflow")(test_case) else: return test_case @require_tf def test_import_huggingface_hub_does_not_import_tensorflow(): # `import huggingface_hub` is not necessary since huggingface_hub is already imported at the top of this file, # but let's keep it here anyway just in case import huggingface_hub # noqa assert "tensorflow" not in sys.modules huggingface_hub-0.31.1/tests/test_upload_large_folder.py000066400000000000000000000023101500667546600234600ustar00rootroot00000000000000# tests/test_upload_large_folder.py import pytest from huggingface_hub._upload_large_folder import COMMIT_SIZE_SCALE, LargeUploadStatus @pytest.fixture def status(): return LargeUploadStatus(items=[]) def test_target_chunk_default(status): assert status.target_chunk() == COMMIT_SIZE_SCALE[1] @pytest.mark.parametrize( "start_idx, success, delta_items, duration, expected_idx", [ (2, False, 0, 10, 1), # drop by one on failure (0, False, 0, 10, 0), # never go below zero (1, True, 0, 50, 1), # duration >= 40 --> no bump (1, True, -1, 30, 1), # nb_items < threshold --> no bump (1, True, 0, 30, 2), # fast enough and enough items (len(COMMIT_SIZE_SCALE) - 1, True, 0, 10, len(COMMIT_SIZE_SCALE) - 1), # never exceed last index ], ) def test_update_chunk_transitions(status, start_idx, success, delta_items, duration, expected_idx): status._chunk_idx = start_idx threshold = COMMIT_SIZE_SCALE[start_idx] nb_items = threshold + delta_items status.update_chunk(success=success, nb_items=nb_items, duration=duration) assert status._chunk_idx == expected_idx assert status.target_chunk() == COMMIT_SIZE_SCALE[expected_idx] huggingface_hub-0.31.1/tests/test_utils_assets.py000066400000000000000000000063671500667546600222310ustar00rootroot00000000000000import unittest from pathlib import Path from unittest.mock import patch import pytest from huggingface_hub import cached_assets_path @pytest.mark.usefixtures("fx_cache_dir") class CacheAssetsTest(unittest.TestCase): cache_dir: Path def 
test_cached_assets_path_with_namespace_and_subfolder(self) -> None: expected_path = self.cache_dir / "datasets" / "SQuAD" / "download" self.assertFalse(expected_path.is_dir()) path = cached_assets_path( library_name="datasets", namespace="SQuAD", subfolder="download", assets_dir=self.cache_dir, ) self.assertEqual(path, expected_path) # Path is generated self.assertTrue(path.is_dir()) # And dir is created def test_cached_assets_path_without_subfolder(self) -> None: path = cached_assets_path(library_name="datasets", namespace="SQuAD", assets_dir=self.cache_dir) self.assertEqual(path, self.cache_dir / "datasets" / "SQuAD" / "default") self.assertTrue(path.is_dir()) def test_cached_assets_path_without_namespace(self) -> None: path = cached_assets_path(library_name="datasets", subfolder="download", assets_dir=self.cache_dir) self.assertEqual(path, self.cache_dir / "datasets" / "default" / "download") self.assertTrue(path.is_dir()) def test_cached_assets_path_without_namespace_and_subfolder(self) -> None: path = cached_assets_path(library_name="datasets", assets_dir=self.cache_dir) self.assertEqual(path, self.cache_dir / "datasets" / "default" / "default") self.assertTrue(path.is_dir()) def test_cached_assets_path_forbidden_symbols(self) -> None: path = cached_assets_path( library_name="ReAlLy dumb", namespace="user/repo_name", subfolder="this is/not\\clever", assets_dir=self.cache_dir, ) self.assertEqual( path, self.cache_dir / "ReAlLy--dumb" / "user--repo_name" / "this--is--not--clever", ) self.assertTrue(path.is_dir()) def test_cached_assets_path_default_assets_dir(self) -> None: with patch( "huggingface_hub.utils._cache_assets.HF_ASSETS_CACHE", self.cache_dir, ): # Uses environment variable from HF_ASSETS_CACHE self.assertEqual( cached_assets_path(library_name="datasets"), self.cache_dir / "datasets" / "default" / "default", ) def test_cached_assets_path_is_a_file(self) -> None: expected_path = self.cache_dir / "datasets" / "default" / "default" expected_path.parent.mkdir(parents=True) expected_path.touch() # this should be the generated folder but is a file ! with self.assertRaises(ValueError): cached_assets_path(library_name="datasets", assets_dir=self.cache_dir) def test_cached_assets_path_parent_is_a_file(self) -> None: expected_path = self.cache_dir / "datasets" / "default" / "default" expected_path.parent.parent.mkdir(parents=True) expected_path.parent.touch() # cannot create folder as a parent is a file ! 
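# Note: as exercised by this test and the one above, `cached_assets_path` is expected to refuse to reuse a path
# when the target folder (or one of its parents) already exists as a regular file, raising a ValueError instead
# of silently overwriting it.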
with self.assertRaises(ValueError): cached_assets_path(library_name="datasets", assets_dir=self.cache_dir) huggingface_hub-0.31.1/tests/test_utils_cache.py000066400000000000000000001052321500667546600217610ustar00rootroot00000000000000import os import tempfile import time import unittest from pathlib import Path from unittest.mock import Mock import pytest from huggingface_hub._snapshot_download import snapshot_download from huggingface_hub.commands.scan_cache import ScanCacheCommand from huggingface_hub.utils import DeleteCacheStrategy, HFCacheInfo, capture_output, scan_cache_dir from huggingface_hub.utils._cache_manager import ( CacheNotFound, _format_size, _format_timesince, _try_delete_path, ) from .testing_utils import ( rmtree_with_retry, with_production_testing, xfail_on_windows, ) # On production server to avoid recreating them all the time MODEL_ID = "hf-internal-testing/hfh_ci_scan_repo_a" MODEL_PATH = "models--hf-internal-testing--hfh_ci_scan_repo_a" DATASET_ID = "hf-internal-testing/hfh_ci_scan_dataset_b" DATASET_PATH = "datasets--hf-internal-testing--hfh_ci_scan_dataset_b" REPO_A_MAIN_HASH = "c0d57e03d9f128062eadb6665618982db612b2e3" REPO_A_PR_1_HASH = "1a665a9d28a66b1d0f8edd9359fc824aacc63234" REPO_A_OTHER_HASH = "f95875cd910793299a545417cc4b3c9055202883" REPO_A_MAIN_README_BLOB_HASH = "fffc22b462ba2368b09b4d38527760051c9090a9" REPO_B_MAIN_HASH = "f1cdcd4641b3ea2dfa8d4333dba1ea3b532735e1" REF_1_NAME = "refs/pr/1" @pytest.mark.usefixtures("fx_cache_dir") class TestMissingCacheUtils(unittest.TestCase): cache_dir: Path def test_cache_dir_is_missing(self) -> None: """Directory to scan does not exist raises CacheNotFound.""" self.assertRaises(CacheNotFound, scan_cache_dir, self.cache_dir / "does_not_exist") def test_cache_dir_is_a_file(self) -> None: """Directory to scan is a file raises ValueError.""" file_path = self.cache_dir / "file.txt" file_path.touch() self.assertRaises(ValueError, scan_cache_dir, file_path) @pytest.mark.usefixtures("fx_cache_dir") class TestValidCacheUtils(unittest.TestCase): cache_dir: Path @with_production_testing def setUp(self) -> None: """Setup a clean cache for tests that will remain valid in all tests.""" # Download latest main snapshot_download(repo_id=MODEL_ID, repo_type="model", cache_dir=self.cache_dir) # Download latest commit which is same as `main` snapshot_download(repo_id=MODEL_ID, revision=REPO_A_MAIN_HASH, repo_type="model", cache_dir=self.cache_dir) # Download the first commit snapshot_download(repo_id=MODEL_ID, revision=REPO_A_OTHER_HASH, repo_type="model", cache_dir=self.cache_dir) # Download from a PR snapshot_download(repo_id=MODEL_ID, revision="refs/pr/1", repo_type="model", cache_dir=self.cache_dir) # Download a Dataset repo from "main" snapshot_download(repo_id=DATASET_ID, revision="main", repo_type="dataset", cache_dir=self.cache_dir) @unittest.skipIf(os.name == "nt", "Windows cache is tested separately") def test_scan_cache_on_valid_cache_unix(self) -> None: """Scan the cache dir without warnings (on unix-based platform). This test is duplicated and adapted for Windows in `test_scan_cache_on_valid_cache_windows`. Note: Please make sure to updated both if any change is made. 
""" report = scan_cache_dir(self.cache_dir) # Check general information about downloaded snapshots self.assertEqual(report.size_on_disk, 3766) self.assertEqual(len(report.repos), 2) # Model and dataset self.assertEqual(len(report.warnings), 0) # Repos are valid repo_a = [repo for repo in report.repos if repo.repo_id == MODEL_ID][0] # Check repo A general information repo_a_path = self.cache_dir / MODEL_PATH self.assertEqual(repo_a.repo_id, MODEL_ID) self.assertEqual(repo_a.repo_type, "model") self.assertEqual(repo_a.repo_path, repo_a_path) # 4 downloads but 3 revisions because "main" and REPO_A_MAIN_HASH are the same self.assertEqual(len(repo_a.revisions), 3) self.assertEqual( {rev.commit_hash for rev in repo_a.revisions}, {REPO_A_MAIN_HASH, REPO_A_PR_1_HASH, REPO_A_OTHER_HASH}, ) # Repo size on disk is less than sum of revisions ! self.assertEqual(repo_a.size_on_disk, 1501) self.assertEqual(sum(rev.size_on_disk for rev in repo_a.revisions), 4463) # Repo nb files is less than sum of revisions ! self.assertEqual(repo_a.nb_files, 3) self.assertEqual(sum(rev.nb_files for rev in repo_a.revisions), 6) # 2 REFS in the repo: "main" and "refs/pr/1" # We could have add a tag as well self.assertEqual(set(repo_a.refs.keys()), {"main", REF_1_NAME}) self.assertEqual(repo_a.refs["main"].commit_hash, REPO_A_MAIN_HASH) self.assertEqual(repo_a.refs[REF_1_NAME].commit_hash, REPO_A_PR_1_HASH) # Check "main" revision information main_revision = repo_a.refs["main"] main_revision_path = repo_a_path / "snapshots" / REPO_A_MAIN_HASH self.assertEqual(main_revision.commit_hash, REPO_A_MAIN_HASH) self.assertEqual(main_revision.snapshot_path, main_revision_path) self.assertEqual(main_revision.refs, {"main"}) # Same nb of files and size on disk that the sum self.assertEqual(main_revision.nb_files, len(main_revision.files)) self.assertEqual( main_revision.size_on_disk, sum(file.size_on_disk for file in main_revision.files), ) # Check readme file from "main" revision main_readme_file = [file for file in main_revision.files if file.file_name == "README.md"][0] main_readme_file_path = main_revision_path / "README.md" main_readme_blob_path = repo_a_path / "blobs" / REPO_A_MAIN_README_BLOB_HASH self.assertEqual(main_readme_file.file_name, "README.md") self.assertEqual(main_readme_file.file_path, main_readme_file_path) self.assertEqual(main_readme_file.blob_path, main_readme_blob_path) # Check readme file from "refs/pr/1" revision pr_1_revision = repo_a.refs[REF_1_NAME] pr_1_revision_path = repo_a_path / "snapshots" / REPO_A_PR_1_HASH pr_1_readme_file = [file for file in pr_1_revision.files if file.file_name == "README.md"][0] pr_1_readme_file_path = pr_1_revision_path / "README.md" # file_path in "refs/pr/1" revision is different than "main" but same blob path self.assertEqual(pr_1_readme_file.file_path, pr_1_readme_file_path) # different self.assertEqual(pr_1_readme_file.blob_path, main_readme_blob_path) # same @unittest.skipIf(os.name != "nt", "Windows cache is tested separately") def test_scan_cache_on_valid_cache_windows(self) -> None: """Scan the cache dir without warnings (on Windows). Windows tests do not use symlinks which leads to duplication in the cache. This test is duplicated from `test_scan_cache_on_valid_cache_unix` with a few tweaks specific to windows. Note: Please make sure to updated both if any change is made. 
""" report = scan_cache_dir(self.cache_dir) # Check general information about downloaded snapshots self.assertEqual(report.size_on_disk, 6728) self.assertEqual(len(report.repos), 2) # Model and dataset self.assertEqual(len(report.warnings), 0) # Repos are valid repo_a = [repo for repo in report.repos if repo.repo_id == MODEL_ID][0] # Check repo A general information repo_a_path = self.cache_dir / MODEL_PATH self.assertEqual(repo_a.repo_id, MODEL_ID) self.assertEqual(repo_a.repo_type, "model") self.assertEqual(repo_a.repo_path, repo_a_path) # 4 downloads but 3 revisions because "main" and REPO_A_MAIN_HASH are the same self.assertEqual(len(repo_a.revisions), 3) self.assertEqual( {rev.commit_hash for rev in repo_a.revisions}, {REPO_A_MAIN_HASH, REPO_A_PR_1_HASH, REPO_A_OTHER_HASH}, ) # Repo size on disk is equal to the sum of revisions (no symlinks) self.assertEqual(repo_a.size_on_disk, 4463) # Windows-specific self.assertEqual(sum(rev.size_on_disk for rev in repo_a.revisions), 4463) # Repo nb files is equal to the sum of revisions ! self.assertEqual(repo_a.nb_files, 6) # Windows-specific self.assertEqual(sum(rev.nb_files for rev in repo_a.revisions), 6) # 2 REFS in the repo: "main" and "refs/pr/1" # We could have add a tag as well REF_1_NAME = "refs\\pr\\1" # Windows-specific self.assertEqual(set(repo_a.refs.keys()), {"main", REF_1_NAME}) self.assertEqual(repo_a.refs["main"].commit_hash, REPO_A_MAIN_HASH) self.assertEqual(repo_a.refs[REF_1_NAME].commit_hash, REPO_A_PR_1_HASH) # Check "main" revision information main_revision = repo_a.refs["main"] main_revision_path = repo_a_path / "snapshots" / REPO_A_MAIN_HASH self.assertEqual(main_revision.commit_hash, REPO_A_MAIN_HASH) self.assertEqual(main_revision.snapshot_path, main_revision_path) self.assertEqual(main_revision.refs, {"main"}) # Same nb of files and size on disk that the sum self.assertEqual(main_revision.nb_files, len(main_revision.files)) self.assertEqual( main_revision.size_on_disk, sum(file.size_on_disk for file in main_revision.files), ) # Check readme file from "main" revision main_readme_file = [file for file in main_revision.files if file.file_name == "README.md"][0] main_readme_file_path = main_revision_path / "README.md" main_readme_blob_path = repo_a_path / "blobs" / REPO_A_MAIN_README_BLOB_HASH self.assertEqual(main_readme_file.file_name, "README.md") self.assertEqual(main_readme_file.file_path, main_readme_file_path) self.assertEqual(main_readme_file.blob_path, main_readme_file_path) # Windows-specific: no blob file self.assertFalse(main_readme_blob_path.exists()) # Windows-specific # Check readme file from "refs/pr/1" revision pr_1_revision = repo_a.refs[REF_1_NAME] pr_1_revision_path = repo_a_path / "snapshots" / REPO_A_PR_1_HASH pr_1_readme_file = [file for file in pr_1_revision.files if file.file_name == "README.md"][0] pr_1_readme_file_path = pr_1_revision_path / "README.md" # file_path in "refs/pr/1" revision is different than "main" # Windows-specific: even blob path is different self.assertEqual(pr_1_readme_file.file_path, pr_1_readme_file_path) self.assertNotEqual( # Windows-specific: different as well pr_1_readme_file.blob_path, main_readme_file.blob_path ) @xfail_on_windows("Size on disk and paths differ on Windows. Not useful to test.") def test_cli_scan_cache_quiet(self) -> None: """Test output from CLI scan cache with non verbose output. End-to-end test just to see if output is in expected format. 
""" args = Mock() args.verbose = 0 args.dir = self.cache_dir with capture_output() as output: ScanCacheCommand(args).run() expected_output = f""" REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH ----------------------------- --------- ------------ -------- ----------------- ----------------- --------------- --------------------------------------------------------- {DATASET_ID} dataset 2.3K 1 a few seconds ago a few seconds ago main {self.cache_dir}/{DATASET_PATH} {MODEL_ID} model 1.5K 3 a few seconds ago a few seconds ago main, refs/pr/1 {self.cache_dir}/{MODEL_PATH} Done in 0.0s. Scanned 2 repo(s) for a total of \x1b[1m\x1b[31m3.8K\x1b[0m. """ self.assertListEqual( output.getvalue().replace("-", "").split(), expected_output.replace("-", "").split(), ) @xfail_on_windows("Size on disk and paths differ on Windows. Not useful to test.") def test_cli_scan_cache_verbose(self) -> None: """Test output from CLI scan cache with verbose output. End-to-end test just to see if output is in expected format. """ args = Mock() args.verbose = 1 args.dir = self.cache_dir with capture_output() as output: ScanCacheCommand(args).run() expected_output = f""" REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH ----------------------------- --------- ---------------------------------------- ------------ -------- ----------------- --------- ------------------------------------------------------------------------------------------------------------ {DATASET_ID} dataset {REPO_B_MAIN_HASH} 2.3K 1 a few seconds ago main {self.cache_dir}/{DATASET_PATH}/snapshots/{REPO_B_MAIN_HASH} {MODEL_ID} model {REPO_A_PR_1_HASH} 1.5K 3 a few seconds ago refs/pr/1 {self.cache_dir}/{MODEL_PATH}/snapshots/{REPO_A_PR_1_HASH} {MODEL_ID} model {REPO_A_MAIN_HASH} 1.5K 2 a few seconds ago main {self.cache_dir}/{MODEL_PATH}/snapshots/{REPO_A_MAIN_HASH} {MODEL_ID} model {REPO_A_OTHER_HASH} 1.5K 1 a few seconds ago {self.cache_dir}/{MODEL_PATH}/snapshots/{REPO_A_OTHER_HASH} Done in 0.0s. Scanned 2 repo(s) for a total of \x1b[1m\x1b[31m3.8K\x1b[0m. """ self.assertListEqual( output.getvalue().replace("-", "").split(), expected_output.replace("-", "").split(), ) def test_cli_scan_missing_cache(self) -> None: """Test output from CLI scan cache when cache does not exist. End-to-end test just to see if output is in expected format. """ tmp_dir = tempfile.mkdtemp() os.rmdir(tmp_dir) args = Mock() args.verbose = 0 args.dir = tmp_dir with capture_output() as output: ScanCacheCommand(args).run() expected_output = f""" Cache directory not found: {Path(tmp_dir).resolve()} """ self.assertListEqual(output.getvalue().split(), expected_output.split()) @pytest.mark.usefixtures("fx_cache_dir") class TestCorruptedCacheUtils(unittest.TestCase): cache_dir: Path repo_path: Path refs_path: Path snapshots_path: Path @with_production_testing def setUp(self) -> None: """Setup a clean cache for tests that will get corrupted/modified in tests.""" # Download latest main snapshot_download(repo_id=MODEL_ID, repo_type="model", cache_dir=self.cache_dir) self.repo_path = self.cache_dir / MODEL_PATH self.refs_path = self.repo_path / "refs" self.snapshots_path = self.repo_path / "snapshots" def test_repo_path_not_valid_dir(self) -> None: """Test if found a not valid path in cache dir.""" # Case 1: a file repo_path = self.cache_dir / "a_file_that_should_not_be_there.txt" repo_path.touch() report = scan_cache_dir(self.cache_dir) self.assertEqual(len(report.repos), 1) # Scan still worked ! 
self.assertEqual(len(report.warnings), 1) self.assertEqual(str(report.warnings[0]), f"Repo path is not a directory: {repo_path}") # Case 2: a folder with wrong naming os.remove(repo_path) repo_path = self.cache_dir / "a_folder_that_should_not_be_there" repo_path.mkdir() report = scan_cache_dir(self.cache_dir) self.assertEqual(len(report.repos), 1) # Scan still worked ! self.assertEqual(len(report.warnings), 1) self.assertEqual( str(report.warnings[0]), f"Repo path is not a valid HuggingFace cache directory: {repo_path}", ) # Case 3: good naming but not a dataset/model/space rmtree_with_retry(repo_path) repo_path = self.cache_dir / "not-models--t5-small" repo_path.mkdir() report = scan_cache_dir(self.cache_dir) self.assertEqual(len(report.repos), 1) # Scan still worked ! self.assertEqual(len(report.warnings), 1) self.assertEqual( str(report.warnings[0]), f"Repo type must be `dataset`, `model` or `space`, found `not-model` ({repo_path}).", ) def test_snapshots_path_not_found(self) -> None: """Test if snapshots directory is missing in cached repo.""" rmtree_with_retry(self.snapshots_path) report = scan_cache_dir(self.cache_dir) self.assertEqual(len(report.repos), 0) # Failed self.assertEqual(len(report.warnings), 1) self.assertEqual( str(report.warnings[0]), f"Snapshots dir doesn't exist in cached repo: {self.snapshots_path}", ) def test_file_in_snapshots_dir(self) -> None: """Test if snapshots directory contains a file.""" wrong_file_path = self.snapshots_path / "should_not_be_there" wrong_file_path.touch() report = scan_cache_dir(self.cache_dir) self.assertEqual(len(report.repos), 0) # Failed self.assertEqual(len(report.warnings), 1) self.assertEqual( str(report.warnings[0]), f"Snapshots folder corrupted. Found a file: {wrong_file_path}", ) def test_snapshot_with_no_blob_files(self) -> None: """Test if a snapshot directory (e.g. 
a cached revision) is empty.""" for revision_path in self.snapshots_path.glob("*"): # Delete content of the revision rmtree_with_retry(revision_path) revision_path.mkdir() # Scan report = scan_cache_dir(self.cache_dir) # Get single repo self.assertEqual(len(report.warnings), 0) # Did not fail self.assertEqual(len(report.repos), 1) repo_report = list(report.repos)[0] # Repo report is empty self.assertEqual(repo_report.size_on_disk, 0) self.assertEqual(len(repo_report.revisions), 1) revision_report = list(repo_report.revisions)[0] # No files in revision so last_modified is the one from the revision folder self.assertEqual(revision_report.nb_files, 0) self.assertEqual(revision_report.last_modified, revision_path.stat().st_mtime) def test_repo_with_no_snapshots(self) -> None: """Test if the snapshot directory exists but is empty.""" rmtree_with_retry(self.refs_path) rmtree_with_retry(self.snapshots_path) self.snapshots_path.mkdir() # Scan report = scan_cache_dir(self.cache_dir) # Get single repo self.assertEqual(len(report.warnings), 0) # Did not fail self.assertEqual(len(report.repos), 1) repo_report = list(report.repos)[0] # No revisions in repos so last_modified is the one from the repo folder self.assertEqual(repo_report.size_on_disk, 0) self.assertEqual(len(repo_report.revisions), 0) self.assertEqual(repo_report.last_modified, self.repo_path.stat().st_mtime) self.assertEqual(repo_report.last_accessed, self.repo_path.stat().st_atime) def test_ref_to_missing_revision(self) -> None: """Test if a `refs` points to a missing revision.""" new_ref = self.repo_path / "refs" / "not_main" with new_ref.open("w") as f: f.write("revision_hash_that_does_not_exist") report = scan_cache_dir(self.cache_dir) self.assertEqual(len(report.repos), 0) # Failed self.assertEqual(len(report.warnings), 1) self.assertEqual( str(report.warnings[0]), "Reference(s) refer to missing commit hashes: {'revision_hash_that_does_not_exist': {'not_main'}} " + f"({self.repo_path}).", ) @xfail_on_windows("Last modified/last accessed work a bit differently on Windows.") def test_scan_cache_last_modified_and_last_accessed(self) -> None: """Scan the last_modified and last_accessed properties when scanning.""" TIME_GAP = 0.1 # Make a first scan report_1 = scan_cache_dir(self.cache_dir) # Values from first report repo_1 = list(report_1.repos)[0] revision_1 = list(repo_1.revisions)[0] readme_file_1 = [file for file in revision_1.files if file.file_name == "README.md"][0] another_file_1 = [file for file in revision_1.files if file.file_name == ".gitattributes"][0] # Comparison of last_accessed/last_modified between file and repo self.assertLessEqual(readme_file_1.blob_last_accessed, repo_1.last_accessed) self.assertLessEqual(readme_file_1.blob_last_modified, repo_1.last_modified) self.assertEqual(revision_1.last_modified, repo_1.last_modified) # Sleep and write new readme time.sleep(TIME_GAP) readme_file_1.file_path.write_text("modified readme") # Sleep and read content from readme time.sleep(TIME_GAP) with readme_file_1.file_path.open("r") as f: _ = f.read() # Sleep and re-scan time.sleep(TIME_GAP) report_2 = scan_cache_dir(self.cache_dir) # Values from second report repo_2 = list(report_2.repos)[0] revision_2 = list(repo_2.revisions)[0] readme_file_2 = [file for file in revision_2.files if file.file_name == "README.md"][0] another_file_2 = [file for file in revision_1.files if file.file_name == ".gitattributes"][0] # Report 1 is not updated when cache changes self.assertLess(repo_1.last_accessed, repo_2.last_accessed) 
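# (report_1 was taken before the README was rewritten and re-read, so it keeps its original timestamps;
# only the fresh scan in report_2 reflects the writes/reads above)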
self.assertLess(repo_1.last_modified, repo_2.last_modified) # "Another_file.md" did not change self.assertEqual(another_file_1, another_file_2) # Readme.md has been modified and then accessed more recently self.assertGreaterEqual( readme_file_2.blob_last_modified - readme_file_1.blob_last_modified, TIME_GAP * 0.9, # 0.9 factor because not exactly precise ) self.assertGreaterEqual( readme_file_2.blob_last_accessed - readme_file_1.blob_last_accessed, 2 * TIME_GAP * 0.9, # 0.9 factor because not exactly precise ) self.assertGreaterEqual( readme_file_2.blob_last_accessed - readme_file_2.blob_last_modified, TIME_GAP * 0.9, # 0.9 factor because not exactly precise ) # Comparison of last_accessed/last_modified between file and repo self.assertEqual(readme_file_2.blob_last_accessed, repo_2.last_accessed) self.assertEqual(readme_file_2.blob_last_modified, repo_2.last_modified) self.assertEqual(revision_2.last_modified, repo_2.last_modified) class TestDeleteRevisionsDryRun(unittest.TestCase): cache_info: Mock # Mocked HFCacheInfo def setUp(self) -> None: """Set up fake cache scan report.""" repo_A_path = Path("repo_A") blobs_path = repo_A_path / "blobs" snapshots_path = repo_A_path / "snapshots_path" # Define blob files main_only_file = Mock() main_only_file.blob_path = blobs_path / "main_only_hash" main_only_file.size_on_disk = 1 detached_only_file = Mock() detached_only_file.blob_path = blobs_path / "detached_only_hash" detached_only_file.size_on_disk = 10 pr_1_only_file = Mock() pr_1_only_file.blob_path = blobs_path / "pr_1_only_hash" pr_1_only_file.size_on_disk = 100 detached_and_pr_1_only_file = Mock() detached_and_pr_1_only_file.blob_path = blobs_path / "detached_and_pr_1_only_hash" detached_and_pr_1_only_file.size_on_disk = 1000 shared_file = Mock() shared_file.blob_path = blobs_path / "shared_file_hash" shared_file.size_on_disk = 10000 # Define revisions repo_A_rev_main = Mock() repo_A_rev_main.commit_hash = "repo_A_rev_main" repo_A_rev_main.snapshot_path = snapshots_path / "repo_A_rev_main" repo_A_rev_main.files = {main_only_file, shared_file} repo_A_rev_main.refs = {"main"} repo_A_rev_detached = Mock() repo_A_rev_detached.commit_hash = "repo_A_rev_detached" repo_A_rev_detached.snapshot_path = snapshots_path / "repo_A_rev_detached" repo_A_rev_detached.files = { detached_only_file, detached_and_pr_1_only_file, shared_file, } repo_A_rev_detached.refs = {} repo_A_rev_pr_1 = Mock() repo_A_rev_pr_1.commit_hash = "repo_A_rev_pr_1" repo_A_rev_pr_1.snapshot_path = snapshots_path / "repo_A_rev_pr_1" repo_A_rev_pr_1.files = { pr_1_only_file, detached_and_pr_1_only_file, shared_file, } repo_A_rev_pr_1.refs = {"refs/pr/1"} # Define repo repo_A = Mock() repo_A.repo_path = Path("repo_A") repo_A.size_on_disk = 4444 repo_A.revisions = {repo_A_rev_main, repo_A_rev_detached, repo_A_rev_pr_1} # Define cache cache_info = Mock() cache_info.repos = [repo_A] self.cache_info = cache_info def test_delete_detached_revision(self) -> None: strategy = HFCacheInfo.delete_revisions(self.cache_info, "repo_A_rev_detached") expected = DeleteCacheStrategy( expected_freed_size=10, blobs={ # "shared_file_hash" and "detached_and_pr_1_only_hash" are not deleted Path("repo_A/blobs/detached_only_hash"), }, refs=set(), # No ref deleted since detached repos=set(), # No repo deleted as other revisions exist snapshots={Path("repo_A/snapshots_path/repo_A_rev_detached")}, ) self.assertEqual(strategy, expected) def test_delete_pr_1_revision(self) -> None: strategy = HFCacheInfo.delete_revisions(self.cache_info, "repo_A_rev_pr_1") 
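# Expectation: deleting only the PR revision frees just the blob that no other revision references
# (pr_1_only_hash, 100 bytes). Blobs shared with `main` or the detached revision are kept, the `refs/pr/1`
# ref file is removed since its target snapshot disappears, and the repo itself survives.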
expected = DeleteCacheStrategy( expected_freed_size=100, blobs={ # "shared_file_hash" and "detached_and_pr_1_only_hash" are not deleted Path("repo_A/blobs/pr_1_only_hash") }, refs={Path("repo_A/refs/refs/pr/1")}, # Ref is deleted ! repos=set(), # No repo deleted as other revisions exist snapshots={Path("repo_A/snapshots_path/repo_A_rev_pr_1")}, ) self.assertEqual(strategy, expected) def test_delete_pr_1_and_detached(self) -> None: strategy = HFCacheInfo.delete_revisions(self.cache_info, "repo_A_rev_detached", "repo_A_rev_pr_1") expected = DeleteCacheStrategy( expected_freed_size=1110, blobs={ Path("repo_A/blobs/detached_only_hash"), Path("repo_A/blobs/pr_1_only_hash"), # blob shared in both revisions and only those two Path("repo_A/blobs/detached_and_pr_1_only_hash"), }, refs={Path("repo_A/refs/refs/pr/1")}, repos=set(), snapshots={ Path("repo_A/snapshots_path/repo_A_rev_detached"), Path("repo_A/snapshots_path/repo_A_rev_pr_1"), }, ) self.assertEqual(strategy, expected) def test_delete_all_revisions(self) -> None: strategy = HFCacheInfo.delete_revisions( self.cache_info, "repo_A_rev_detached", "repo_A_rev_pr_1", "repo_A_rev_main" ) expected = DeleteCacheStrategy( expected_freed_size=4444, blobs=set(), refs=set(), repos={Path("repo_A")}, # No remaining revisions: full repo is deleted snapshots=set(), ) self.assertEqual(strategy, expected) def test_delete_unknown_revision(self) -> None: with self.assertLogs() as captured: strategy = HFCacheInfo.delete_revisions(self.cache_info, "repo_A_rev_detached", "abcdef123456789") # Expected is same strategy as without "abcdef123456789" expected = HFCacheInfo.delete_revisions(self.cache_info, "repo_A_rev_detached") self.assertEqual(strategy, expected) # Expect a warning message self.assertEqual(len(captured.records), 1) self.assertEqual(captured.records[0].levelname, "WARNING") self.assertEqual( captured.records[0].message, "Revision(s) not found - cannot delete them: abcdef123456789", ) @pytest.mark.usefixtures("fx_cache_dir") class TestDeleteStrategyExecute(unittest.TestCase): cache_dir: Path def test_execute(self) -> None: # Repo folders repo_A_path = self.cache_dir / "repo_A" repo_A_path.mkdir() repo_B_path = self.cache_dir / "repo_B" repo_B_path.mkdir() # Refs files in repo_B refs_main_path = repo_B_path / "refs" / "main" refs_main_path.parent.mkdir(parents=True) refs_main_path.touch() refs_pr_1_path = repo_B_path / "refs" / "refs" / "pr" / "1" refs_pr_1_path.parent.mkdir(parents=True) refs_pr_1_path.touch() # Blobs files in repo_B (repo_B_path / "blobs").mkdir() blob_1 = repo_B_path / "blobs" / "blob_1" blob_2 = repo_B_path / "blobs" / "blob_2" blob_3 = repo_B_path / "blobs" / "blob_3" blob_1.touch() blob_2.touch() blob_3.touch() # Snapshot folders in repo_B snapshot_1 = repo_B_path / "snapshots" / "snapshot_1" snapshot_2 = repo_B_path / "snapshots" / "snapshot_2" snapshot_1.mkdir(parents=True) snapshot_2.mkdir() # Execute deletion # Delete repo_A + keep only blob_1, main ref and snapshot_1 in repo_B. 
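# Note: `execute()` simply removes the paths listed in the strategy; the `expected_freed_size` passed below is
# a dummy value, which suggests it is informational only and not re-checked against the files actually deleted.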
DeleteCacheStrategy( expected_freed_size=123456, blobs={blob_2, blob_3}, refs={refs_pr_1_path}, repos={repo_A_path}, snapshots={snapshot_2}, ).execute() # Repo A deleted self.assertFalse(repo_A_path.exists()) self.assertTrue(repo_B_path.exists()) # Only `blob` 1 remains self.assertTrue(blob_1.exists()) self.assertFalse(blob_2.exists()) self.assertFalse(blob_3.exists()) # Only ref `main` remains self.assertTrue(refs_main_path.exists()) self.assertFalse(refs_pr_1_path.exists()) # Only `snapshot_1` remains self.assertTrue(snapshot_1.exists()) self.assertFalse(snapshot_2.exists()) @pytest.mark.usefixtures("fx_cache_dir") class TestTryDeletePath(unittest.TestCase): cache_dir: Path def test_delete_path_on_file_success(self) -> None: """Successfully delete a local file.""" file_path = self.cache_dir / "file.txt" file_path.touch() _try_delete_path(file_path, path_type="TYPE") self.assertFalse(file_path.exists()) def test_delete_path_on_folder_success(self) -> None: """Successfully delete a local folder.""" dir_path = self.cache_dir / "something" subdir_path = dir_path / "bar" subdir_path.mkdir(parents=True) # subfolder file_path_1 = dir_path / "file.txt" # file at root file_path_1.touch() file_path_2 = subdir_path / "config.json" # file in subfolder file_path_2.touch() _try_delete_path(dir_path, path_type="TYPE") self.assertFalse(dir_path.exists()) self.assertFalse(subdir_path.exists()) self.assertFalse(file_path_1.exists()) self.assertFalse(file_path_2.exists()) def test_delete_path_on_missing_file(self) -> None: """Try delete a missing file.""" file_path = self.cache_dir / "file.txt" with self.assertLogs() as captured: _try_delete_path(file_path, path_type="TYPE") # Assert warning message with traceback for debug purposes self.assertEqual(len(captured.output), 1) self.assertTrue( captured.output[0].startswith( "WARNING:huggingface_hub.utils._cache_manager:Couldn't delete TYPE:" f" file not found ({file_path})\nTraceback (most recent call last):" ) ) def test_delete_path_on_missing_folder(self) -> None: """Try delete a missing folder.""" dir_path = self.cache_dir / "folder" with self.assertLogs() as captured: _try_delete_path(dir_path, path_type="TYPE") # Assert warning message with traceback for debug purposes self.assertEqual(len(captured.output), 1) self.assertTrue( captured.output[0].startswith( "WARNING:huggingface_hub.utils._cache_manager:Couldn't delete TYPE:" f" file not found ({dir_path})\nTraceback (most recent call last):" ) ) @xfail_on_windows(reason="Permissions are handled differently on Windows.") def test_delete_path_on_local_folder_with_wrong_permission(self) -> None: """Try delete a local folder that is protected.""" dir_path = self.cache_dir / "something" dir_path.mkdir() file_path_1 = dir_path / "file.txt" # file at root file_path_1.touch() dir_path.chmod(444) # Read-only folder with self.assertLogs() as captured: _try_delete_path(dir_path, path_type="TYPE") # Folder still exists (couldn't be deleted) self.assertTrue(dir_path.is_dir()) # Assert warning message with traceback for debug purposes self.assertEqual(len(captured.output), 1) self.assertTrue( captured.output[0].startswith( "WARNING:huggingface_hub.utils._cache_manager:Couldn't delete TYPE:" f" permission denied ({dir_path})\nTraceback (most recent call last):" ) ) # For proper cleanup dir_path.chmod(509) class TestStringFormatters(unittest.TestCase): SIZES = { 16.0: "16.0", 1000.0: "1.0K", 1024 * 1024 * 1024: "1.1G", # not 1.0GiB } SINCE = { 1: "a few seconds ago", 15: "a few seconds ago", 25: "25 seconds ago", 80: "1 
minute ago", 1000: "17 minutes ago", 4000: "1 hour ago", 8000: "2 hours ago", 3600 * 24 * 13: "2 weeks ago", 3600 * 24 * 30 * 8.2: "8 months ago", 3600 * 24 * 365: "1 year ago", 3600 * 24 * 365 * 9.6: "10 years ago", } def test_format_size(self) -> None: """Test `_format_size` formatter.""" for size, expected in self.SIZES.items(): self.assertEqual( _format_size(size), expected, msg=f"Wrong formatting for {size} == '{expected}'", ) def test_format_timesince(self) -> None: """Test `_format_timesince` formatter.""" for ts, expected in self.SINCE.items(): self.assertEqual( _format_timesince(time.time() - ts), expected, msg=f"Wrong formatting for {ts} == '{expected}'", ) huggingface_hub-0.31.1/tests/test_utils_chunks.py000066400000000000000000000025301500667546600222060ustar00rootroot00000000000000import unittest from huggingface_hub.utils._chunk_utils import chunk_iterable class TestUtilsCommon(unittest.TestCase): def test_chunk_iterable_non_truncated(self): # Can iterable over any iterable (iterator, list, tuple,...) for iterable in (range(12), list(range(12)), tuple(range(12))): # 12 is a multiple of 4 -> last chunk is not truncated for chunk, expected_chunk in zip( chunk_iterable(iterable, chunk_size=4), [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], ): self.assertListEqual(list(chunk), expected_chunk) def test_chunk_iterable_last_chunk_truncated(self): # Can iterable over any iterable (iterator, list, tuple,...) for iterable in (range(12), list(range(12)), tuple(range(12))): # 12 is NOT a multiple of 5 -> last chunk is truncated for chunk, expected_chunk in zip( chunk_iterable(iterable, chunk_size=5), [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11]], ): self.assertListEqual(list(chunk), expected_chunk) def test_chunk_iterable_validation(self): with self.assertRaises(ValueError): next(chunk_iterable(range(128), 0)) with self.assertRaises(ValueError): next(chunk_iterable(range(128), -1)) huggingface_hub-0.31.1/tests/test_utils_cli.py000066400000000000000000000045471500667546600214740ustar00rootroot00000000000000import os import unittest from unittest import mock from huggingface_hub.commands._cli_utils import ANSI, tabulate class TestCLIUtils(unittest.TestCase): @mock.patch.dict(os.environ, {}, clear=True) def test_ansi_utils(self) -> None: """Test `ANSI` works as expected.""" self.assertEqual( ANSI.bold("this is bold"), "\x1b[1mthis is bold\x1b[0m", ) self.assertEqual( ANSI.gray("this is gray"), "\x1b[90mthis is gray\x1b[0m", ) self.assertEqual( ANSI.red("this is red"), "\x1b[1m\x1b[31mthis is red\x1b[0m", ) self.assertEqual( ANSI.gray(ANSI.bold("this is bold and grey")), "\x1b[90m\x1b[1mthis is bold and grey\x1b[0m\x1b[0m", ) @mock.patch.dict(os.environ, {"NO_COLOR": "1"}, clear=True) def test_ansi_no_color(self) -> None: """Test `ANSI` respects `NO_COLOR` env var.""" self.assertEqual( ANSI.bold("this is bold"), "this is bold", ) self.assertEqual( ANSI.gray("this is gray"), "this is gray", ) self.assertEqual( ANSI.red("this is red"), "this is red", ) self.assertEqual( ANSI.gray(ANSI.bold("this is bold and grey")), "this is bold and grey", ) def test_tabulate_utility(self) -> None: """Test `tabulate` works as expected.""" rows = [[1, 2, 3], ["a very long value", "foo", "bar"], ["", 123, 456]] headers = ["Header 1", "something else", "a third column"] self.assertEqual( tabulate(rows=rows, headers=headers), "Header 1 something else a third column \n" "----------------- -------------- -------------- \n" " 1 2 3 \n" "a very long value foo bar \n" " 123 456 ", ) def 
test_tabulate_utility_with_too_short_row(self) -> None: """ Test `tabulate` throw IndexError when a row has less values than the header list. """ self.assertRaises( IndexError, tabulate, rows=[[1]], headers=["Header 1", "Header 2"], ) huggingface_hub-0.31.1/tests/test_utils_datetime.py000066400000000000000000000027051500667546600225130ustar00rootroot00000000000000import unittest from datetime import datetime, timezone import pytest from huggingface_hub.utils import parse_datetime class TestDatetimeUtils(unittest.TestCase): def test_parse_datetime(self): """Test `parse_datetime` works correctly on datetimes returned by server.""" self.assertEqual( parse_datetime("2022-08-19T07:19:38.123Z"), datetime(2022, 8, 19, 7, 19, 38, 123000, tzinfo=timezone.utc), ) # Test nanoseconds precision (should be truncated to microseconds) self.assertEqual( parse_datetime("2022-08-19T07:19:38.123456789Z"), datetime(2022, 8, 19, 7, 19, 38, 123456, tzinfo=timezone.utc), ) # Test without milliseconds (should add .000) self.assertEqual( parse_datetime("2024-11-16T00:27:02Z"), datetime(2024, 11, 16, 0, 27, 2, 0, tzinfo=timezone.utc), ) with pytest.raises(ValueError, match=r".*Cannot parse '2022-08-19T07:19:38' as a datetime.*"): parse_datetime("2022-08-19T07:19:38") with pytest.raises( ValueError, match=r".*Cannot parse '2022-08-19T07:19:38.123' as a datetime.*", ): parse_datetime("2022-08-19T07:19:38.123") with pytest.raises( ValueError, match=r".*Cannot parse '2022-08-19 07:19:38.123Z\+6:00' as a datetime.*", ): parse_datetime("2022-08-19 07:19:38.123Z+6:00") huggingface_hub-0.31.1/tests/test_utils_deprecation.py000066400000000000000000000110631500667546600232110ustar00rootroot00000000000000import unittest import warnings import pytest from huggingface_hub.utils._deprecation import ( _deprecate_arguments, _deprecate_method, _deprecate_positional_args, ) class TestDeprecationUtils(unittest.TestCase): def test_deprecate_positional_args(self): """Test warnings are triggered when using deprecated positional args.""" @_deprecate_positional_args(version="xxx") def dummy_position_deprecated(a, *, b="b", c="c"): pass with warnings.catch_warnings(): # Assert no warnings when used correctly. # Taken from https://docs.pytest.org/en/latest/how-to/capture-warnings.html#additional-use-cases-of-warnings-in-tests warnings.simplefilter("error") dummy_position_deprecated(a="A", b="B", c="C") dummy_position_deprecated("A", b="B", c="C") with pytest.warns(FutureWarning): dummy_position_deprecated("A", "B", c="C") with pytest.warns(FutureWarning): dummy_position_deprecated("A", "B", "C") def test_deprecate_arguments(self): """Test warnings are triggered when using deprecated arguments.""" @_deprecate_arguments(version="xxx", deprecated_args={"c"}) def dummy_c_deprecated(a, b="b", c="c"): pass @_deprecate_arguments(version="xxx", deprecated_args={"b", "c"}) def dummy_b_c_deprecated(a, b="b", c="c"): pass with warnings.catch_warnings(): # Assert no warnings when used correctly. 
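# (`warnings.simplefilter("error")` inside `catch_warnings()` turns any warning into an exception,
# so the calls below would fail loudly if a FutureWarning were emitted.)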
# Taken from https://docs.pytest.org/en/latest/how-to/capture-warnings.html#additional-use-cases-of-warnings-in-tests warnings.simplefilter("error") dummy_c_deprecated("A") dummy_c_deprecated("A", "B") dummy_c_deprecated("A", b="B") dummy_b_c_deprecated("A") dummy_b_c_deprecated("A", b="b") dummy_b_c_deprecated("A", b="b", c="c") with pytest.warns(FutureWarning): dummy_c_deprecated("A", "B", "C") with pytest.warns(FutureWarning): dummy_c_deprecated("A", c="C") with pytest.warns(FutureWarning): dummy_c_deprecated("A", b="B", c="C") with pytest.warns(FutureWarning): dummy_b_c_deprecated("A", b="B") with pytest.warns(FutureWarning): dummy_b_c_deprecated("A", c="C") with pytest.warns(FutureWarning): dummy_b_c_deprecated("A", b="B", c="C") def test_deprecate_arguments_with_default_warning_message(self) -> None: """Test default warning message when deprecating arguments.""" @_deprecate_arguments(version="xxx", deprecated_args={"a"}) def dummy_deprecated_default_message(a: str = "a") -> None: pass # Default message with pytest.warns(FutureWarning) as record: dummy_deprecated_default_message(a="A") self.assertEqual(len(record), 1) self.assertEqual( record[0].message.args[0], "Deprecated argument(s) used in 'dummy_deprecated_default_message': a." " Will not be supported from version 'xxx'.", ) def test_deprecate_arguments_with_custom_warning_message(self) -> None: """Test custom warning message when deprecating arguments.""" @_deprecate_arguments( version="xxx", deprecated_args={"a"}, custom_message="This is a custom message.", ) def dummy_deprecated_custom_message(a: str = "a") -> None: pass # Custom message with pytest.warns(FutureWarning) as record: dummy_deprecated_custom_message(a="A") self.assertEqual(len(record), 1) self.assertEqual( record[0].message.args[0], "Deprecated argument(s) used in 'dummy_deprecated_custom_message': a." " Will not be supported from version 'xxx'.\n\nThis is a custom" " message.", ) def test_deprecated_method(self) -> None: """Test deprecate method throw warning.""" @_deprecate_method(version="xxx", message="This is a custom message.") def dummy_deprecated() -> None: pass # Custom message with pytest.warns(FutureWarning) as record: dummy_deprecated() self.assertEqual(len(record), 1) self.assertEqual( record[0].message.args[0], "'dummy_deprecated' (from 'tests.test_utils_deprecation') is deprecated" " and will be removed from version 'xxx'. 
This is a custom message.", ) huggingface_hub-0.31.1/tests/test_utils_errors.py000066400000000000000000000374741500667546600222460ustar00rootroot00000000000000import unittest import pytest from requests.models import PreparedRequest, Response from huggingface_hub.errors import ( BadRequestError, DisabledRepoError, EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError, ) from huggingface_hub.utils._http import REPO_API_REGEX, X_AMZN_TRACE_ID, X_REQUEST_ID, _format, hf_raise_for_status class TestErrorUtils(unittest.TestCase): def test_hf_raise_for_status_repo_not_found(self) -> None: response = Response() response.headers = {"X-Error-Code": "RepoNotFound", X_REQUEST_ID: 123} response.status_code = 404 with self.assertRaisesRegex(RepositoryNotFoundError, "Repository Not Found") as context: hf_raise_for_status(response) assert context.exception.response.status_code == 404 assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_disabled_repo(self) -> None: response = Response() response.headers = {"X-Error-Message": "Access to this resource is disabled.", X_REQUEST_ID: 123} response.status_code = 403 with self.assertRaises(DisabledRepoError) as context: hf_raise_for_status(response) assert context.exception.response.status_code == 403 assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_401_repo_url_not_invalid_token(self) -> None: response = Response() response.headers = {X_REQUEST_ID: 123} response.status_code = 401 response.request = PreparedRequest() response.request.url = "https://huggingface.co/api/models/username/reponame" with self.assertRaisesRegex(RepositoryNotFoundError, "Repository Not Found") as context: hf_raise_for_status(response) assert context.exception.response.status_code == 401 assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_401_repo_url_invalid_token(self) -> None: response = Response() response.headers = {X_REQUEST_ID: 123, "X-Error-Message": "Invalid credentials in Authorization header"} response.status_code = 401 response.request = PreparedRequest() response.request.url = "https://huggingface.co/api/models/username/reponame" with self.assertRaisesRegex(HfHubHTTPError, "Invalid credentials in Authorization header") as context: hf_raise_for_status(response) assert context.exception.response.status_code == 401 assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_403_wrong_token_scope(self) -> None: response = Response() response.headers = {X_REQUEST_ID: 123, "X-Error-Message": "specific error message"} response.status_code = 403 response.request = PreparedRequest() response.request.url = "https://huggingface.co/api/repos/create" expected_message_part = "403 Forbidden: specific error message" with self.assertRaisesRegex(HfHubHTTPError, expected_message_part) as context: hf_raise_for_status(response) assert context.exception.response.status_code == 403 assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_401_not_repo_url(self) -> None: response = Response() response.headers = {X_REQUEST_ID: 123} response.status_code = 401 response.request = PreparedRequest() response.request.url = "https://huggingface.co/api/collections" with self.assertRaises(HfHubHTTPError) as context: hf_raise_for_status(response) assert context.exception.response.status_code == 401 assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_revision_not_found(self) -> None: response = Response() response.headers = 
{"X-Error-Code": "RevisionNotFound", X_REQUEST_ID: 123} response.status_code = 404 with self.assertRaisesRegex(RevisionNotFoundError, "Revision Not Found") as context: hf_raise_for_status(response) assert context.exception.response.status_code == 404 assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_entry_not_found(self) -> None: response = Response() response.headers = {"X-Error-Code": "EntryNotFound", X_REQUEST_ID: 123} response.status_code = 404 with self.assertRaisesRegex(EntryNotFoundError, "Entry Not Found") as context: hf_raise_for_status(response) assert context.exception.response.status_code == 404 assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_bad_request_no_endpoint_name(self) -> None: """Test HTTPError converted to BadRequestError if error 400.""" response = Response() response.status_code = 400 with self.assertRaisesRegex(BadRequestError, "Bad request:") as context: hf_raise_for_status(response) assert context.exception.response.status_code == 400 def test_hf_raise_for_status_bad_request_with_endpoint_name(self) -> None: """Test endpoint name is added to BadRequestError message.""" response = Response() response.status_code = 400 with self.assertRaisesRegex(BadRequestError, "Bad request for preupload endpoint:") as context: hf_raise_for_status(response, endpoint_name="preupload") assert context.exception.response.status_code == 400 def test_hf_raise_for_status_fallback(self) -> None: """Test HTTPError is converted to HfHubHTTPError.""" response = Response() response.status_code = 404 response.headers = { X_REQUEST_ID: "test-id", } response.url = "test_URL" with self.assertRaisesRegex(HfHubHTTPError, "Request ID: test-id") as context: hf_raise_for_status(response) assert context.exception.response.status_code == 404 assert context.exception.response.url == "test_URL" class TestHfHubHTTPError(unittest.TestCase): response: Response def setUp(self) -> None: """Setup with a default response.""" self.response = Response() self.response.status_code = 404 self.response.url = "test_URL" def test_hf_hub_http_error_initialization(self) -> None: """Test HfHubHTTPError is initialized properly.""" error = HfHubHTTPError("this is a message", response=self.response) assert str(error) == "this is a message" assert error.response == self.response assert error.request_id is None assert error.server_message is None def test_hf_hub_http_error_init_with_request_id(self) -> None: """Test request id is added to the message.""" self.response.headers = {X_REQUEST_ID: "test-id"} error = _format(HfHubHTTPError, "this is a message", response=self.response) assert str(error) == "this is a message (Request ID: test-id)" assert error.request_id == "test-id" def test_hf_hub_http_error_init_with_request_id_and_multiline_message(self) -> None: """Test request id is added to the end of the first line.""" self.response.headers = {X_REQUEST_ID: "test-id"} error = _format(HfHubHTTPError, "this is a message\nthis is more details", response=self.response) assert str(error) == "this is a message (Request ID: test-id)\nthis is more details" error = _format(HfHubHTTPError, "this is a message\n\nthis is more details", response=self.response) assert str(error) == "this is a message (Request ID: test-id)\n\nthis is more details" def test_hf_hub_http_error_init_with_request_id_already_in_message(self) -> None: """Test request id is not duplicated in error message (case insensitive)""" self.response.headers = {X_REQUEST_ID: "test-id"} error = 
_format(HfHubHTTPError, "this is a message on request TEST-ID", response=self.response) assert str(error) == "this is a message on request TEST-ID" assert error.request_id == "test-id" def test_hf_hub_http_error_init_with_server_error(self) -> None: """Test server error is added to the error message.""" self.response._content = b'{"error": "This is a message returned by the server"}' error = _format(HfHubHTTPError, "this is a message", response=self.response) assert str(error) == "this is a message\n\nThis is a message returned by the server" assert error.server_message == "This is a message returned by the server" def test_hf_hub_http_error_init_with_server_error_and_multiline_message( self, ) -> None: """Test server error is added to the error message after the details.""" self.response._content = b'{"error": "This is a message returned by the server"}' error = _format(HfHubHTTPError, "this is a message\n\nSome details.", response=self.response) assert str(error) == "this is a message\n\nSome details.\nThis is a message returned by the server" def test_hf_hub_http_error_init_with_multiple_server_errors( self, ) -> None: """Test server errors are added to the error message after the details. Regression test for https://github.com/huggingface/huggingface_hub/issues/1114. """ self.response._content = ( b'{"httpStatusCode": 400, "errors": [{"message": "this is error 1", "type":' b' "error"}, {"message": "this is error 2", "type": "error"}]}' ) error = _format(HfHubHTTPError, "this is a message\n\nSome details.", response=self.response) assert str(error) == "this is a message\n\nSome details.\nthis is error 1\nthis is error 2" def test_hf_hub_http_error_init_with_server_error_already_in_message( self, ) -> None: """Test server error is not duplicated if already in details. Case insensitive. """ self.response._content = b'{"error": "repo NOT found"}' error = _format( HfHubHTTPError, "this is a message\n\nRepo Not Found. and more\nand more", response=self.response, ) assert str(error) == "this is a message\n\nRepo Not Found. and more\nand more" def test_hf_hub_http_error_init_with_unparsable_server_error( self, ) -> None: """Server returned a text message (not as JSON) => should be added to the exception.""" self.response._content = b"this is not a json-formatted string" error = _format(HfHubHTTPError, "this is a message", response=self.response) assert str(error) == "this is a message\n\nthis is not a json-formatted string" assert error.server_message == "this is not a json-formatted string" def test_hf_hub_http_error_append_to_message(self) -> None: """Test add extra information to existing HfHubHTTPError.""" error = _format(HfHubHTTPError, "this is a message", response=self.response) error.args = error.args + (1, 2, 3) # faking some extra args error.append_to_message("\nthis is an additional message") assert error.args == ("this is a message\nthis is an additional message", 1, 2, 3) assert error.server_message is None # added message is not from server def test_hf_hub_http_error_init_with_error_message_in_header(self) -> None: """Test server error from header is added to the error message.""" self.response.headers = {"X-Error-Message": "Error message from headers."} error = _format(HfHubHTTPError, "this is a message", response=self.response) assert str(error) == "this is a message\n\nError message from headers." assert error.server_message == "Error message from headers." 
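# The next tests cover the case where the server reports an error both in the `X-Error-Message` header and in
# the JSON body: both messages should end up in the exception, unless they are identical, in which case the
# message appears only once.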
def test_hf_hub_http_error_init_with_error_message_from_header_and_body( self, ) -> None: """Test server error from header and from body are added to the error message.""" self.response._content = b'{"error": "Error message from body."}' self.response.headers = {"X-Error-Message": "Error message from headers."} error = _format(HfHubHTTPError, "this is a message", response=self.response) assert str(error) == "this is a message\n\nError message from headers.\nError message from body." assert error.server_message == "Error message from headers.\nError message from body." def test_hf_hub_http_error_init_with_error_message_duplicated_in_header_and_body( self, ) -> None: """Test server error from header and from body are the same. Should not duplicate it in the raised `HfHubHTTPError`. """ self.response._content = b'{"error": "Error message duplicated in headers and body."}' self.response.headers = {"X-Error-Message": "Error message duplicated in headers and body."} error = _format(HfHubHTTPError, "this is a message", response=self.response) assert str(error) == "this is a message\n\nError message duplicated in headers and body." assert error.server_message == "Error message duplicated in headers and body." def test_hf_hub_http_error_without_request_id_with_amzn_trace_id(self) -> None: """Test request id is not duplicated in error message (case insensitive)""" self.response.headers = {X_AMZN_TRACE_ID: "test-trace-id"} error = _format(HfHubHTTPError, "this is a message", response=self.response) assert str(error) == "this is a message (Amzn Trace ID: test-trace-id)" assert error.request_id == "test-trace-id" def test_hf_hub_http_error_with_request_id_and_amzn_trace_id(self) -> None: """Test request id is not duplicated in error message (case insensitive)""" self.response.headers = {X_AMZN_TRACE_ID: "test-trace-id", X_REQUEST_ID: "test-id"} error = _format(HfHubHTTPError, "this is a message", response=self.response) assert str(error) == "this is a message (Request ID: test-id)" assert error.request_id == "test-id" @pytest.mark.parametrize( ("url", "should_match"), [ # Listing endpoints => False ("https://huggingface.co/api/models", False), ("https://huggingface.co/api/datasets", False), ("https://huggingface.co/api/spaces", False), # Create repo endpoint => False ("https://huggingface.co/api/repos/create", False), # Collection endpoints => False ("https://huggingface.co/api/collections", False), ("https://huggingface.co/api/collections/foo/bar", False), # Repo endpoints => True ("https://huggingface.co/api/models/repo_id", True), ("https://huggingface.co/api/datasets/repo_id", True), ("https://huggingface.co/api/spaces/repo_id", True), ("https://huggingface.co/api/models/username/repo_name/refs/main", True), ("https://huggingface.co/api/datasets/username/repo_name/refs/main", True), ("https://huggingface.co/api/spaces/username/repo_name/refs/main", True), # Inference Endpoint => False ("https://api.endpoints.huggingface.cloud/v2/endpoint/namespace", False), # Staging Endpoint => True ("https://hub-ci.huggingface.co/api/models/repo_id", True), ("https://hub-ci.huggingface.co/api/datasets/repo_id", True), ("https://hub-ci.huggingface.co/api/spaces/repo_id", True), # /resolve Endpoint => True ("https://huggingface.co/gpt2/resolve/main/README.md", True), ("https://huggingface.co/datasets/google/fleurs/resolve/revision/README.md", True), # Regression tests ("https://huggingface.co/bert-base/resolve/main/pytorch_model.bin", True), 
("https://hub-ci.huggingface.co/__DUMMY_USER__/repo-1470b5/resolve/main/file.txt", True), ], ) def test_repo_api_regex(url: str, should_match: bool) -> None: """Test the regex used to match repo API URLs.""" if should_match: assert REPO_API_REGEX.match(url) else: assert REPO_API_REGEX.match(url) is None huggingface_hub-0.31.1/tests/test_utils_experimental.py000066400000000000000000000016151500667546600234130ustar00rootroot00000000000000import unittest import warnings from unittest.mock import patch from huggingface_hub.utils import experimental @experimental def dummy_function(): return "success" class TestExperimentalFlag(unittest.TestCase): def test_experimental_warning(self): with patch("huggingface_hub.constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING", False): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") self.assertEqual(dummy_function(), "success") self.assertEqual(len(w), 1) def test_experimental_no_warning(self): with patch("huggingface_hub.constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING", True): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") self.assertEqual(dummy_function(), "success") self.assertEqual(len(w), 0) huggingface_hub-0.31.1/tests/test_utils_fixes.py000066400000000000000000000035321500667546600220340ustar00rootroot00000000000000import logging import unittest from pathlib import Path import filelock import pytest from huggingface_hub.utils import SoftTemporaryDirectory, WeakFileLock, yaml_dump class TestYamlDump(unittest.TestCase): def test_yaml_dump_emoji(self) -> None: self.assertEqual(yaml_dump({"emoji": "👀"}), "emoji: 👀\n") def test_yaml_dump_japanese_characters(self) -> None: self.assertEqual(yaml_dump({"some unicode": "日本か"}), "some unicode: 日本か\n") def test_yaml_dump_explicit_no_unicode(self) -> None: self.assertEqual(yaml_dump({"emoji": "👀"}, allow_unicode=False), 'emoji: "\\U0001F440"\n') class TestTemporaryDirectory(unittest.TestCase): def test_temporary_directory(self) -> None: with SoftTemporaryDirectory(prefix="prefix", suffix="suffix") as path: self.assertIsInstance(path, Path) self.assertTrue(path.name.startswith("prefix")) self.assertTrue(path.name.endswith("suffix")) self.assertTrue(path.is_dir()) # Tmpdir is deleted self.assertFalse(path.is_dir()) class TestWeakFileLock: def test_lock_log_every( self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture ) -> None: monkeypatch.setattr("huggingface_hub.constants.FILELOCK_LOG_EVERY_SECONDS", 0.1) lock_file = tmp_path / ".lock" with caplog.at_level(logging.INFO, logger="huggingface_hub.utils._fixes"): with WeakFileLock(lock_file): with pytest.raises(filelock.Timeout) as exc_info: with WeakFileLock(lock_file, timeout=0.3): pass assert exc_info.value.lock_file == str(lock_file) assert len(caplog.records) >= 3 assert caplog.records[0].message.startswith(f"Still waiting to acquire lock on {lock_file}") huggingface_hub-0.31.1/tests/test_utils_git_credentials.py000066400000000000000000000047621500667546600240640ustar00rootroot00000000000000import time import unittest from pathlib import Path import pytest from huggingface_hub.constants import ENDPOINT from huggingface_hub.utils import run_interactive_subprocess, run_subprocess from huggingface_hub.utils._git_credential import ( _parse_credential_output, list_credential_helpers, set_git_credential, unset_git_credential, ) STORE_AND_CACHE_HELPERS_CONFIG = """ [credential] helper = store helper = cache --timeout 30000 """ @pytest.mark.usefixtures("fx_cache_dir") class 
TestGitCredentials(unittest.TestCase): cache_dir: Path def setUp(self): """Initialize and configure a local repo. Avoid to configure git helpers globally on a contributor's machine. """ run_subprocess("git init", folder=self.cache_dir) with (self.cache_dir / ".git" / "config").open("w") as f: f.write(STORE_AND_CACHE_HELPERS_CONFIG) def test_list_credential_helpers(self) -> None: helpers = list_credential_helpers(folder=self.cache_dir) self.assertIn("cache", helpers) self.assertIn("store", helpers) def test_set_and_unset_git_credential(self) -> None: username = "hf_test_user_" + str(round(time.time())) # make username unique # Set credentials set_git_credential(token="hf_test_token", username=username, folder=self.cache_dir) # Check credentials are stored with run_interactive_subprocess("git credential fill", folder=self.cache_dir) as (stdin, stdout): stdin.write(f"url={ENDPOINT}\nusername={username}\n\n") stdin.flush() output = stdout.read() self.assertIn("password=hf_test_token", output) # Unset credentials unset_git_credential(username=username, folder=self.cache_dir) # Check credentials are NOT stored # Cannot check with `git credential fill` as it would hang forever: only # checking `store` helper instead. with run_interactive_subprocess("git credential-store get", folder=self.cache_dir) as (stdin, stdout): stdin.write(f"url={ENDPOINT}\nusername={username}\n\n") stdin.flush() output = stdout.read() self.assertEqual("", output) def test_git_credential_parsing_regex(self) -> None: output = """ credential.helper = store credential.helper = cache --timeout 30000 credential.helper = osxkeychain""" assert _parse_credential_output(output) == ["cache", "osxkeychain", "store"] huggingface_hub-0.31.1/tests/test_utils_headers.py000066400000000000000000000164231500667546600223340ustar00rootroot00000000000000import unittest from unittest.mock import Mock, patch from huggingface_hub.utils import get_hf_hub_version, get_python_version from huggingface_hub.utils._headers import _deduplicate_user_agent, _http_user_agent, build_hf_headers from .testing_utils import handle_injection_in_test # Only for tests that are not related to user agent DEFAULT_USER_AGENT = _http_user_agent() FAKE_TOKEN = "123456789" FAKE_TOKEN_ORG = "api_org_123456789" FAKE_TOKEN_HEADER = { "authorization": f"Bearer {FAKE_TOKEN}", "user-agent": DEFAULT_USER_AGENT, } NO_AUTH_HEADER = {"user-agent": DEFAULT_USER_AGENT} # @patch("huggingface_hub.utils._headers.HfFolder") # @handle_injection class TestAuthHeadersUtil(unittest.TestCase): def test_use_auth_token_str(self) -> None: self.assertEqual(build_hf_headers(use_auth_token=FAKE_TOKEN), FAKE_TOKEN_HEADER) @patch("huggingface_hub.utils._headers.get_token", return_value=None) def test_use_auth_token_true_no_cached_token(self, mock_get_token: Mock) -> None: with self.assertRaises(EnvironmentError): build_hf_headers(use_auth_token=True) @patch("huggingface_hub.utils._headers.get_token", return_value=FAKE_TOKEN) def test_use_auth_token_true_has_cached_token(self, mock_get_token: Mock) -> None: self.assertEqual(build_hf_headers(use_auth_token=True), FAKE_TOKEN_HEADER) @patch("huggingface_hub.utils._headers.get_token", return_value=FAKE_TOKEN) def test_use_auth_token_false(self, mock_get_token: Mock) -> None: self.assertEqual(build_hf_headers(use_auth_token=False), NO_AUTH_HEADER) @patch("huggingface_hub.utils._headers.get_token", return_value=None) def test_use_auth_token_none_no_cached_token(self, mock_get_token: Mock) -> None: self.assertEqual(build_hf_headers(), NO_AUTH_HEADER) 
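# --- Illustrative note (not part of the upstream test file) ---------------
# The surrounding tests cover how `build_hf_headers` resolves authentication:
# an explicit token, a cached token, or no token at all. A minimal, hedged
# usage sketch with a dummy token value:
#
#     from huggingface_hub.utils import build_hf_headers
#
#     headers = build_hf_headers(use_auth_token="hf_dummy_token")
#     # -> {"authorization": "Bearer hf_dummy_token", "user-agent": "..."}
#
#     headers = build_hf_headers(use_auth_token=False)
#     # -> {"user-agent": "..."}  (token explicitly not sent)
# ---------------------------------------------------------------------------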
@patch("huggingface_hub.utils._headers.get_token", return_value=FAKE_TOKEN) def test_use_auth_token_none_has_cached_token(self, mock_get_token: Mock) -> None: self.assertEqual(build_hf_headers(), FAKE_TOKEN_HEADER) @patch("huggingface_hub.utils._headers.get_token", return_value=FAKE_TOKEN) def test_implicit_use_disabled(self, mock_get_token: Mock) -> None: with patch( # not as decorator to avoid friction with @handle_injection "huggingface_hub.constants.HF_HUB_DISABLE_IMPLICIT_TOKEN", True ): self.assertEqual(build_hf_headers(), NO_AUTH_HEADER) # token is not sent @patch("huggingface_hub.utils._headers.get_token", return_value=FAKE_TOKEN) def test_implicit_use_disabled_but_explicit_use(self, mock_get_token: Mock) -> None: with patch( # not as decorator to avoid friction with @handle_injection "huggingface_hub.constants.HF_HUB_DISABLE_IMPLICIT_TOKEN", True ): # This is not an implicit use so we still send it self.assertEqual(build_hf_headers(use_auth_token=True), FAKE_TOKEN_HEADER) class TestUserAgentHeadersUtil(unittest.TestCase): def _get_user_agent(self, **kwargs) -> str: return build_hf_headers(**kwargs)["user-agent"] @patch("huggingface_hub.utils._headers.get_fastai_version") @patch("huggingface_hub.utils._headers.get_fastcore_version") @patch("huggingface_hub.utils._headers.get_tf_version") @patch("huggingface_hub.utils._headers.get_torch_version") @patch("huggingface_hub.utils._headers.is_fastai_available") @patch("huggingface_hub.utils._headers.is_fastcore_available") @patch("huggingface_hub.utils._headers.is_tf_available") @patch("huggingface_hub.utils._headers.is_torch_available") @handle_injection_in_test def test_default_user_agent( self, mock_get_fastai_version: Mock, mock_get_fastcore_version: Mock, mock_get_tf_version: Mock, mock_get_torch_version: Mock, mock_is_fastai_available: Mock, mock_is_fastcore_available: Mock, mock_is_tf_available: Mock, mock_is_torch_available: Mock, ) -> None: mock_get_fastai_version.return_value = "fastai_version" mock_get_fastcore_version.return_value = "fastcore_version" mock_get_tf_version.return_value = "tf_version" mock_get_torch_version.return_value = "torch_version" mock_is_fastai_available.return_value = True mock_is_fastcore_available.return_value = True mock_is_tf_available.return_value = True mock_is_torch_available.return_value = True self.assertEqual( self._get_user_agent(), f"unknown/None; hf_hub/{get_hf_hub_version()};" f" python/{get_python_version()}; torch/torch_version;" " tensorflow/tf_version; fastai/fastai_version;" " fastcore/fastcore_version", ) @patch("huggingface_hub.utils._headers.is_torch_available") @patch("huggingface_hub.utils._headers.is_tf_available") @handle_injection_in_test def test_user_agent_with_library_name_multiple_missing( self, mock_is_torch_available: Mock, mock_is_tf_available: Mock ) -> None: mock_is_torch_available.return_value = False mock_is_tf_available.return_value = False self.assertNotIn("torch", self._get_user_agent()) self.assertNotIn("tensorflow", self._get_user_agent()) def test_user_agent_with_library_name_and_version(self) -> None: self.assertTrue( self._get_user_agent( library_name="foo", library_version="bar", ).startswith("foo/bar;") ) def test_user_agent_with_library_name_no_version(self) -> None: self.assertTrue(self._get_user_agent(library_name="foo").startswith("foo/None;")) def test_user_agent_with_custom_agent_string(self) -> None: self.assertTrue(self._get_user_agent(user_agent="this is a custom agent").endswith("this is a custom agent")) def 
test_user_agent_with_custom_agent_dict(self) -> None: self.assertTrue(self._get_user_agent(user_agent={"a": "b", "c": "d"}).endswith("a/b; c/d")) def test_user_agent_deduplicate(self) -> None: self.assertEqual( _deduplicate_user_agent( "python/3.7; python/3.8; hf_hub/0.12; transformers/None; hf_hub/0.12; python/3.7; diffusers/0.12.1" ), # 1. "python" is kept twice with different values # 2. "python/3.7" second occurrence is removed # 3. "hf_hub" second occurrence is removed # 4. order is preserved "python/3.7; python/3.8; hf_hub/0.12; transformers/None; diffusers/0.12.1", ) @patch("huggingface_hub.utils._telemetry.constants.HF_HUB_USER_AGENT_ORIGIN", "custom-origin") def test_user_agent_with_origin(self) -> None: self.assertTrue(self._get_user_agent().endswith("origin/custom-origin")) @patch("huggingface_hub.utils._telemetry.constants.HF_HUB_USER_AGENT_ORIGIN", "custom-origin") def test_user_agent_with_origin_and_user_agent(self) -> None: self.assertTrue( self._get_user_agent(user_agent={"a": "b", "c": "d"}).endswith("a/b; c/d; origin/custom-origin") ) @patch("huggingface_hub.utils._telemetry.constants.HF_HUB_USER_AGENT_ORIGIN", "custom-origin") def test_user_agent_with_origin_and_user_agent_str(self) -> None: self.assertTrue(self._get_user_agent(user_agent="a/b;c/d").endswith("a/b; c/d; origin/custom-origin")) huggingface_hub-0.31.1/tests/test_utils_hf_folder.py000066400000000000000000000035401500667546600226450ustar00rootroot00000000000000# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contain tests for `HfFolder` utility.""" import os import unittest from uuid import uuid4 from huggingface_hub.utils import HfFolder def _generate_token() -> str: return f"token-{uuid4()}" class HfFolderTest(unittest.TestCase): def test_token_workflow(self): """ Test the whole token save/get/delete workflow, with the desired behavior with respect to non-existent tokens. """ token = _generate_token() HfFolder.save_token(token) self.assertEqual(HfFolder.get_token(), token) HfFolder.delete_token() HfFolder.delete_token() # ^^ not an error, we test that the # second call does not fail. self.assertEqual(HfFolder.get_token(), None) # test TOKEN in env self.assertEqual(HfFolder.get_token(), None) with unittest.mock.patch.dict(os.environ, {"HF_TOKEN": token}): self.assertEqual(HfFolder.get_token(), token) def test_token_strip(self): """ Test the workflow when the token is mistakenly finishing with new-line or space character. 
""" token = _generate_token() HfFolder.save_token(" " + token + "\n") self.assertEqual(HfFolder.get_token(), token) HfFolder.delete_token() huggingface_hub-0.31.1/tests/test_utils_http.py000066400000000000000000000302711500667546600216750ustar00rootroot00000000000000import os import threading import time import unittest from multiprocessing import Process, Queue from typing import Generator, Optional from unittest.mock import Mock, call, patch from uuid import UUID import pytest import requests from requests import ConnectTimeout, HTTPError from huggingface_hub.constants import ENDPOINT from huggingface_hub.utils._http import ( OfflineModeIsEnabled, _adjust_range_header, configure_http_backend, fix_hf_endpoint_in_url, get_session, http_backoff, reset_sessions, ) URL = "https://www.google.com" class TestHttpBackoff(unittest.TestCase): def setUp(self) -> None: get_session_mock = Mock() self.mock_request = get_session_mock().request self.patcher = patch("huggingface_hub.utils._http.get_session", get_session_mock) self.patcher.start() def tearDown(self) -> None: self.patcher.stop() def test_backoff_no_errors(self) -> None: """Test normal usage of `http_backoff`.""" data_mock = Mock() response = http_backoff("GET", URL, data=data_mock) self.mock_request.assert_called_once_with(method="GET", url=URL, data=data_mock) self.assertIs(response, self.mock_request()) def test_backoff_3_calls(self) -> None: """Test `http_backoff` with 2 fails.""" response_mock = Mock() self.mock_request.side_effect = (ValueError(), ValueError(), response_mock) response = http_backoff( # retry on ValueError, instant retry "GET", URL, retry_on_exceptions=ValueError, base_wait_time=0.0 ) self.assertEqual(self.mock_request.call_count, 3) self.mock_request.assert_has_calls( calls=[ call(method="GET", url=URL), call(method="GET", url=URL), call(method="GET", url=URL), ] ) self.assertIs(response, response_mock) def test_backoff_on_exception_until_max(self) -> None: """Test `http_backoff` until max limit is reached with exceptions.""" self.mock_request.side_effect = ConnectTimeout() with self.assertRaises(ConnectTimeout): http_backoff("GET", URL, base_wait_time=0.0, max_retries=3) self.assertEqual(self.mock_request.call_count, 4) def test_backoff_on_status_code_until_max(self) -> None: """Test `http_backoff` until max limit is reached with status codes.""" mock_503 = Mock() mock_503.status_code = 503 mock_504 = Mock() mock_504.status_code = 504 mock_504.raise_for_status.side_effect = HTTPError() self.mock_request.side_effect = (mock_503, mock_504, mock_503, mock_504) with self.assertRaises(HTTPError): http_backoff( "GET", URL, base_wait_time=0.0, max_retries=3, retry_on_status_codes=(503, 504), ) self.assertEqual(self.mock_request.call_count, 4) def test_backoff_on_exceptions_and_status_codes(self) -> None: """Test `http_backoff` until max limit with status codes and exceptions.""" mock_503 = Mock() mock_503.status_code = 503 self.mock_request.side_effect = (mock_503, ConnectTimeout()) with self.assertRaises(ConnectTimeout): http_backoff("GET", URL, base_wait_time=0.0, max_retries=1) self.assertEqual(self.mock_request.call_count, 2) def test_backoff_on_valid_status_code(self) -> None: """Test `http_backoff` until max limit with a valid status code. Quite a corner case: the user wants to retry is status code is 200. Requests are retried but in the end, the HTTP 200 response is returned if the server returned only 200 responses. 
""" mock_200 = Mock() mock_200.status_code = 200 self.mock_request.side_effect = (mock_200, mock_200, mock_200, mock_200) response = http_backoff("GET", URL, base_wait_time=0.0, max_retries=3, retry_on_status_codes=200) self.assertEqual(self.mock_request.call_count, 4) self.assertIs(response, mock_200) def test_backoff_sleep_time(self) -> None: """Test `http_backoff` sleep time goes exponential until max limit. Since timing between 2 requests is sleep duration + some other stuff, this test can be unstable. However, sleep durations between 10ms and 50ms should be enough to make the approximation that measured durations are the "sleep time" waited by `http_backoff`. If this is not the case, just increase `base_wait_time`, `max_wait_time` and `expected_sleep_times` with bigger values. """ sleep_times = [] def _side_effect_timer() -> Generator[ConnectTimeout, None, None]: t0 = time.time() while True: yield ConnectTimeout() t1 = time.time() sleep_times.append(round(t1 - t0, 1)) t0 = t1 self.mock_request.side_effect = _side_effect_timer() with self.assertRaises(ConnectTimeout): http_backoff("GET", URL, base_wait_time=0.1, max_wait_time=0.5, max_retries=5) self.assertEqual(self.mock_request.call_count, 6) # Assert sleep times are exponential until plateau expected_sleep_times = [0.1, 0.2, 0.4, 0.5, 0.5] self.assertListEqual(sleep_times, expected_sleep_times) class TestConfigureSession(unittest.TestCase): def setUp(self) -> None: # Reconfigure + clear session cache between each test configure_http_backend() @classmethod def tearDownClass(cls) -> None: # Clear all sessions after tests configure_http_backend() @staticmethod def _factory() -> requests.Session: session = requests.Session() session.headers.update({"x-test-header": 4}) return session def test_default_configuration(self) -> None: session = get_session() self.assertEqual(session.headers["connection"], "keep-alive") # keep connection alive by default self.assertIsNone(session.auth) self.assertEqual(session.proxies, {}) self.assertEqual(session.verify, True) self.assertIsNone(session.cert) self.assertEqual(session.max_redirects, 30) self.assertEqual(session.trust_env, True) self.assertEqual(session.hooks, {"response": []}) def test_set_configuration(self) -> None: configure_http_backend(backend_factory=self._factory) # Check headers have been set correctly session = get_session() self.assertNotEqual(session.headers, {"x-test-header": 4}) self.assertEqual(session.headers["x-test-header"], 4) def test_get_session_twice(self): session_1 = get_session() session_2 = get_session() self.assertIs(session_1, session_2) # exact same instance def test_get_session_twice_but_reconfigure_in_between(self): """Reconfiguring the session clears the cache.""" session_1 = get_session() configure_http_backend(backend_factory=self._factory) session_2 = get_session() self.assertIsNot(session_1, session_2) self.assertIsNone(session_1.headers.get("x-test-header")) self.assertEqual(session_2.headers["x-test-header"], 4) def test_get_session_multiple_threads(self): N = 3 sessions = [None] * N def _get_session_in_thread(index: int) -> None: time.sleep(0.1) sessions[index] = get_session() # Get main thread session main_session = get_session() # Start 3 threads and get sessions in each of them threads = [threading.Thread(target=_get_session_in_thread, args=(index,)) for index in range(N)] for th in threads: th.start() print(th) for th in threads: th.join() # Check all sessions are different for i in range(N): self.assertIsNot(main_session, sessions[i]) for j in 
range(N): if i != j: self.assertIsNot(sessions[i], sessions[j]) @unittest.skipIf(os.name == "nt", "Works differently on Windows.") def test_get_session_in_forked_process(self): # Get main process session main_session = get_session() def _child_target(): # Put `repr(session)` in queue because putting the `Session` object directly would duplicate it. # Repr looks like this: "" process_queue.put(repr(get_session())) # Fork a new process and get session in it process_queue = Queue() Process(target=_child_target).start() child_session = process_queue.get() # Check sessions are different self.assertNotEqual(repr(main_session), child_session) class OfflineModeSessionTest(unittest.TestCase): def tearDown(self) -> None: reset_sessions() return super().tearDown() @patch("huggingface_hub.constants.HF_HUB_OFFLINE", True) def test_offline_mode(self): configure_http_backend() session = get_session() with self.assertRaises(OfflineModeIsEnabled): session.get("https://huggingface.co") class TestUniqueRequestId(unittest.TestCase): api_endpoint = ENDPOINT + "/api/tasks" # any endpoint is fine def test_request_id_is_used_by_server(self): response = get_session().get(self.api_endpoint) request_id = response.request.headers.get("X-Amzn-Trace-Id") response_id = response.headers.get("x-request-id") self.assertIn(request_id, response_id) self.assertTrue(_is_uuid(request_id)) def test_request_id_is_unique(self): response_1 = get_session().get(self.api_endpoint) response_2 = get_session().get(self.api_endpoint) request_id_1 = response_1.request.headers["X-Amzn-Trace-Id"] request_id_2 = response_2.request.headers["X-Amzn-Trace-Id"] self.assertNotEqual(request_id_1, request_id_2) self.assertTrue(_is_uuid(request_id_1)) self.assertTrue(_is_uuid(request_id_2)) def test_request_id_not_overwritten(self): response = get_session().get(self.api_endpoint, headers={"x-request-id": "custom-id"}) request_id = response.request.headers["x-request-id"] self.assertEqual(request_id, "custom-id") response_id = response.headers["x-request-id"] self.assertEqual(response_id, "custom-id") def _is_uuid(string: str) -> bool: # Taken from https://stackoverflow.com/a/33245493 try: uuid_obj = UUID(string) except ValueError: return False return str(uuid_obj) == string @pytest.mark.parametrize( ("base_url", "endpoint", "expected_url"), [ # Staging url => unchanged ("https://hub-ci.huggingface.co/resolve/...", None, "https://hub-ci.huggingface.co/resolve/..."), # Prod url => unchanged ("https://huggingface.co/resolve/...", None, "https://huggingface.co/resolve/..."), # Custom endpoint + staging url => fixed ("https://hub-ci.huggingface.co/api/models", "https://mirror.co", "https://mirror.co/api/models"), # Custom endpoint + prod url => fixed ("https://huggingface.co/api/models", "https://mirror.co", "https://mirror.co/api/models"), ], ) def test_fix_hf_endpoint_in_url(base_url: str, endpoint: Optional[str], expected_url: str) -> None: assert fix_hf_endpoint_in_url(base_url, endpoint) == expected_url def test_adjust_range_header(): # Basic cases assert _adjust_range_header(None, 10) == "bytes=10-" assert _adjust_range_header("bytes=0-100", 10) == "bytes=10-100" assert _adjust_range_header("bytes=-100", 10) == "bytes=-90" assert _adjust_range_header("bytes=100-", 10) == "bytes=110-" with pytest.raises(RuntimeError): _adjust_range_header("invalid", 10) with pytest.raises(RuntimeError): _adjust_range_header("bytes=-", 10) # Multiple ranges with pytest.raises(ValueError): _adjust_range_header("bytes=0-100,200-300", 10) # Resume size exceeds range with 
pytest.raises(RuntimeError): _adjust_range_header("bytes=0-100", 150) with pytest.raises(RuntimeError): _adjust_range_header("bytes=-50", 100) huggingface_hub-0.31.1/tests/test_utils_pagination.py000066400000000000000000000055201500667546600230460ustar00rootroot00000000000000import unittest from unittest.mock import Mock, call, patch from huggingface_hub.utils._pagination import paginate from .testing_utils import handle_injection_in_test class TestPagination(unittest.TestCase): @patch("huggingface_hub.utils._pagination.get_session") @patch("huggingface_hub.utils._pagination.http_backoff") @patch("huggingface_hub.utils._pagination.hf_raise_for_status") @handle_injection_in_test def test_mocked_paginate( self, mock_get_session: Mock, mock_http_backoff: Mock, mock_hf_raise_for_status: Mock ) -> None: mock_get = mock_get_session().get mock_params = Mock() mock_headers = Mock() # Simulate page 1 mock_response_page_1 = Mock() mock_response_page_1.json.return_value = [1, 2, 3] mock_response_page_1.links = {"next": {"url": "url_p2"}} # Simulate page 2 mock_response_page_2 = Mock() mock_response_page_2.json.return_value = [4, 5, 6] mock_response_page_2.links = {"next": {"url": "url_p3"}} # Simulate page 3 mock_response_page_3 = Mock() mock_response_page_3.json.return_value = [7, 8] mock_response_page_3.links = {} # Mock response mock_get.side_effect = [ mock_response_page_1, ] mock_http_backoff.side_effect = [ mock_response_page_2, mock_response_page_3, ] results = paginate("url", params=mock_params, headers=mock_headers) # Requests are made only when generator is yielded assert mock_get.call_count == 0 # Results after concatenating pages assert list(results) == [1, 2, 3, 4, 5, 6, 7, 8] # All pages requested: 3 requests, 3 raise for status # First request with `get_session.get` (we want at least 1 request to succeed correctly) and 2 with `http_backoff` assert mock_get.call_count == 1 assert mock_http_backoff.call_count == 2 assert mock_hf_raise_for_status.call_count == 3 # Params not passed to next pages assert mock_get.call_args_list == [call("url", params=mock_params, headers=mock_headers)] assert mock_http_backoff.call_args_list == [ call("GET", "url_p2", max_retries=20, retry_on_status_codes=429, headers=mock_headers), call("GET", "url_p3", max_retries=20, retry_on_status_codes=429, headers=mock_headers), ] def test_paginate_github_api(self) -> None: # Real test: paginate over huggingface repos on Github # Use enumerate and stop after first page to avoid loading all repos for num, _ in enumerate( paginate("https://api.github.com/orgs/huggingface/repos?limit=4", params={}, headers={}) ): if num == 6: break else: self.fail("Did not get more than 6 repos") huggingface_hub-0.31.1/tests/test_utils_paths.py000066400000000000000000000115161500667546600220360ustar00rootroot00000000000000import unittest from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, List, Optional, Union from huggingface_hub.utils import DEFAULT_IGNORE_PATTERNS, filter_repo_objects @dataclass class DummyObject: path: Path DUMMY_FILES = ["not_hidden.pdf", "profile.jpg", ".hidden.pdf", ".hidden_picture.png"] DUMMY_PATHS = [Path(path) for path in DUMMY_FILES] DUMMY_OBJECTS = [DummyObject(path=path) for path in DUMMY_FILES] class TestPathsUtils(unittest.TestCase): def test_get_all_pdfs(self) -> None: """Get all PDFs even hidden ones.""" self._check( items=DUMMY_FILES, expected_items=["not_hidden.pdf", ".hidden.pdf"], allow_patterns=["*.pdf"], ) def test_get_all_pdfs_except_hidden(self) -> None: 
"""Get all PDFs except hidden ones.""" self._check( items=DUMMY_FILES, expected_items=["not_hidden.pdf"], allow_patterns=["*.pdf"], ignore_patterns=[".*"], ) def test_get_all_pdfs_except_hidden_using_single_pattern(self) -> None: """Get all PDFs except hidden ones, using single pattern.""" self._check( items=DUMMY_FILES, expected_items=["not_hidden.pdf"], allow_patterns="*.pdf", # not a list ignore_patterns=".*", # not a list ) def test_get_all_images(self) -> None: """Get all images.""" self._check( items=DUMMY_FILES, expected_items=["profile.jpg", ".hidden_picture.png"], allow_patterns=["*.png", "*.jpg"], ) def test_get_all_images_except_hidden_from_paths(self) -> None: """Get all images except hidden ones, from Path list.""" self._check( items=DUMMY_PATHS, expected_items=[Path("profile.jpg")], allow_patterns=["*.png", "*.jpg"], ignore_patterns=".*", ) def test_get_all_images_except_hidden_from_objects(self) -> None: """Get all images except hidden ones, from object list.""" self._check( items=DUMMY_OBJECTS, expected_items=[DummyObject(path="profile.jpg")], allow_patterns=["*.png", "*.jpg"], ignore_patterns=".*", key=lambda x: x.path, ) def test_filter_objects_key_not_provided(self) -> None: """Test ValueError is raised if filtering non-string objects.""" with self.assertRaisesRegex(ValueError, "Please provide `key` argument"): list( filter_repo_objects( items=DUMMY_OBJECTS, allow_patterns=["*.png", "*.jpg"], ignore_patterns=".*", ) ) def test_filter_object_with_folder(self) -> None: self._check( items=[ "file.txt", "lfs.bin", "path/to/file.txt", "path/to/lfs.bin", "nested/path/to/file.txt", "nested/path/to/lfs.bin", ], expected_items=["path/to/file.txt", "path/to/lfs.bin"], allow_patterns=["path/to/"], ) def _check( self, items: List[Any], expected_items: List[Any], allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, key: Optional[Callable[[Any], str]] = None, ) -> None: """Run `filter_repo_objects` and check output against expected result.""" self.assertListEqual( list( filter_repo_objects( items=items, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=key, ) ), expected_items, ) class TestDefaultIgnorePatterns(unittest.TestCase): PATHS_TO_IGNORE = [ ".git", ".git/file.txt", ".git/folder/file.txt", "path/to/folder/.git", "path/to/folder/.git/file.txt", "path/to/.git/folder/file.txt", ".cache/huggingface", ".cache/huggingface/file.txt", ".cache/huggingface/folder/file.txt", "path/to/.cache/huggingface", "path/to/.cache/huggingface/file.txt", ] VALID_PATHS = [ ".gitignore", "path/foo.git/file.txt", "path/.git_bar/file.txt", "path/to/file.git", "file.huggingface", "path/file.huggingface", ".cache/huggingface_folder", ".cache/huggingface_folder/file.txt", ] def test_exclude_git_folder(self): filtered_paths = filter_repo_objects( items=self.PATHS_TO_IGNORE + self.VALID_PATHS, ignore_patterns=DEFAULT_IGNORE_PATTERNS ) self.assertListEqual(list(filtered_paths), self.VALID_PATHS) huggingface_hub-0.31.1/tests/test_utils_runtime.py000066400000000000000000000005671500667546600224060ustar00rootroot00000000000000import unittest from huggingface_hub.utils._runtime import is_google_colab, is_notebook class TestRuntimeUtils(unittest.TestCase): def test_is_notebook(self) -> None: """Test `is_notebook`.""" self.assertFalse(is_notebook()) def test_is_google_colab(self) -> None: """Test `is_google_colab`.""" self.assertFalse(is_google_colab()) 
huggingface_hub-0.31.1/tests/test_utils_sha.py000066400000000000000000000026141500667546600214710ustar00rootroot00000000000000import os import subprocess from hashlib import sha256 from io import BytesIO from huggingface_hub.utils import SoftTemporaryDirectory from huggingface_hub.utils.sha import git_hash, sha_fileobj def test_sha_fileobj(): with SoftTemporaryDirectory() as tmpdir: content = b"Random content" * 1000 sha = sha256(content).digest() # Test with file object filepath = os.path.join(tmpdir, "file.bin") with open(filepath, "wb+") as file: file.write(content) with open(filepath, "rb") as fileobj: assert sha_fileobj(fileobj, None) == sha with open(filepath, "rb") as fileobj: assert sha_fileobj(fileobj, 50) == sha with open(filepath, "rb") as fileobj: assert sha_fileobj(fileobj, 50_000) == sha # Test with in-memory file object assert sha_fileobj(BytesIO(content), None) == sha assert sha_fileobj(BytesIO(content), 50) == sha assert sha_fileobj(BytesIO(content), 50_000) == sha def test_git_hash(tmpdir): """Test the `git_hash` output is the same as `git hash-object` command.""" path = os.path.join(tmpdir, "file.txt") with open(path, "wb") as file: file.write(b"Hello, World!") output = subprocess.run(f"git hash-object -t blob {path}", shell=True, capture_output=True, text=True) assert output.stdout.strip() == git_hash(b"Hello, World!") huggingface_hub-0.31.1/tests/test_utils_telemetry.py000066400000000000000000000104041500667546600227240ustar00rootroot00000000000000import unittest from queue import Queue from unittest.mock import Mock, patch from huggingface_hub.utils._telemetry import send_telemetry from .testing_constants import ENDPOINT_STAGING @patch("huggingface_hub.utils._telemetry._TELEMETRY_QUEUE", new_callable=Queue) @patch("huggingface_hub.utils._telemetry._TELEMETRY_THREAD", None) class TestSendTelemetry(unittest.TestCase): def setUp(self) -> None: get_session_mock = Mock() self.mock_head = get_session_mock().head self.patcher = patch("huggingface_hub.utils._telemetry.get_session", get_session_mock) self.patcher.start() def tearDown(self) -> None: self.patcher.stop() def test_topic_normal(self, queue: Queue) -> None: send_telemetry(topic="examples") queue.join() # Wait for the telemetry tasks to be completed self.mock_head.assert_called_once() self.assertEqual(self.mock_head.call_args[0][0], f"{ENDPOINT_STAGING}/api/telemetry/examples") def test_topic_multiple(self, queue: Queue) -> None: send_telemetry(topic="example1") send_telemetry(topic="example2") send_telemetry(topic="example3") queue.join() # Wait for the telemetry tasks to be completed self.assertEqual(self.mock_head.call_count, 3) # 3 calls and order is preserved self.assertEqual(self.mock_head.call_args_list[0][0][0], f"{ENDPOINT_STAGING}/api/telemetry/example1") self.assertEqual(self.mock_head.call_args_list[1][0][0], f"{ENDPOINT_STAGING}/api/telemetry/example2") self.assertEqual(self.mock_head.call_args_list[2][0][0], f"{ENDPOINT_STAGING}/api/telemetry/example3") def test_topic_with_subtopic(self, queue: Queue) -> None: send_telemetry(topic="gradio/image/this_one") queue.join() # Wait for the telemetry tasks to be completed self.mock_head.assert_called_once() self.assertEqual(self.mock_head.call_args[0][0], f"{ENDPOINT_STAGING}/api/telemetry/gradio/image/this_one") def test_topic_quoted(self, queue: Queue) -> None: send_telemetry(topic="foo bar") queue.join() # Wait for the telemetry tasks to be completed self.mock_head.assert_called_once() self.assertEqual(self.mock_head.call_args[0][0], 
f"{ENDPOINT_STAGING}/api/telemetry/foo%20bar") @patch("huggingface_hub.utils._telemetry.constants.HF_HUB_OFFLINE", True) def test_hub_offline(self, queue: Queue) -> None: send_telemetry(topic="topic") self.assertTrue(queue.empty()) # no tasks self.mock_head.assert_not_called() @patch("huggingface_hub.utils._telemetry.constants.HF_HUB_DISABLE_TELEMETRY", True) def test_telemetry_disabled(self, queue: Queue) -> None: send_telemetry(topic="topic") self.assertTrue(queue.empty()) # no tasks self.mock_head.assert_not_called() @patch("huggingface_hub.utils._telemetry.build_hf_headers") def test_telemetry_use_build_hf_headers(self, mock_headers: Mock, queue: Queue) -> None: send_telemetry(topic="topic") queue.join() # Wait for the telemetry tasks to be completed self.mock_head.assert_called_once() mock_headers.assert_called_once() self.assertEqual(self.mock_head.call_args[1]["headers"], mock_headers.return_value) @patch("huggingface_hub.utils._telemetry._TELEMETRY_QUEUE", new_callable=Queue) @patch("huggingface_hub.utils._telemetry._TELEMETRY_THREAD", None) class TestSendTelemetryConnectionError(unittest.TestCase): def setUp(self) -> None: get_session_mock = Mock() get_session_mock().head.side_effect = Exception("whatever") self.patcher = patch("huggingface_hub.utils._telemetry.get_session", get_session_mock) self.patcher.start() def tearDown(self) -> None: self.patcher.stop() def test_telemetry_exception_silenced(self, queue: Queue) -> None: with self.assertLogs(logger="huggingface_hub.utils._telemetry", level="DEBUG") as captured: send_telemetry(topic="topic") queue.join() # Assert debug message with traceback for debug purposes self.assertEqual(len(captured.output), 1) self.assertEqual( captured.output[0], "DEBUG:huggingface_hub.utils._telemetry:Error while sending telemetry: whatever", ) huggingface_hub-0.31.1/tests/test_utils_tqdm.py000066400000000000000000000222551500667546600216660ustar00rootroot00000000000000import time import unittest from pathlib import Path from unittest.mock import patch import pytest from pytest import CaptureFixture from huggingface_hub.utils import ( SoftTemporaryDirectory, are_progress_bars_disabled, disable_progress_bars, enable_progress_bars, tqdm, tqdm_stream_file, ) class CapsysBaseTest(unittest.TestCase): @pytest.fixture(autouse=True) def capsys(self, capsys: CaptureFixture) -> None: """Workaround to make capsys work in unittest framework. Capsys is a convenient pytest fixture to capture stdout. See https://waylonwalker.com/pytest-capsys/. Taken from https://github.com/pytest-dev/pytest/issues/2504#issuecomment-309475790. """ self.capsys = capsys class TestTqdmUtils(CapsysBaseTest): def setUp(self) -> None: """Get verbosity to set it back after the tests.""" self._previous_are_progress_bars_disabled = are_progress_bars_disabled() return super().setUp() def tearDown(self) -> None: """Set back progress bars verbosity as before testing.""" if self._previous_are_progress_bars_disabled: disable_progress_bars() else: enable_progress_bars() @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None) def test_tqdm_helpers(self) -> None: """Test helpers to enable/disable progress bars.""" disable_progress_bars() assert are_progress_bars_disabled() enable_progress_bars() assert not are_progress_bars_disabled() @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", True) def test_cannot_enable_tqdm_when_env_variable_is_set(self) -> None: """ Test helpers cannot enable/disable progress bars when `HF_HUB_DISABLE_PROGRESS_BARS` is set. 
""" disable_progress_bars() assert are_progress_bars_disabled() with self.assertWarns(UserWarning): enable_progress_bars() assert are_progress_bars_disabled() # Still disabled @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", False) def test_cannot_disable_tqdm_when_env_variable_is_set(self) -> None: """ Test helpers cannot enable/disable progress bars when `HF_HUB_DISABLE_PROGRESS_BARS` is set. """ enable_progress_bars() assert not are_progress_bars_disabled() with self.assertWarns(UserWarning): disable_progress_bars() assert not are_progress_bars_disabled() # Still enabled @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None) def test_tqdm_disabled(self) -> None: """Test TQDM not outputting anything when globally disabled.""" disable_progress_bars() for _ in tqdm(range(10)): pass captured = self.capsys.readouterr() self.assertEqual(captured.out, "") self.assertEqual(captured.err, "") @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None) def test_tqdm_disabled_cannot_be_forced(self) -> None: """Test TQDM cannot be forced when globally disabled.""" disable_progress_bars() for _ in tqdm(range(10), disable=False): pass captured = self.capsys.readouterr() self.assertEqual(captured.out, "") self.assertEqual(captured.err, "") @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None) def test_tqdm_can_be_disabled_when_globally_enabled(self) -> None: """Test TQDM can still be locally disabled even when globally enabled.""" enable_progress_bars() for _ in tqdm(range(10), disable=True): pass captured = self.capsys.readouterr() self.assertEqual(captured.out, "") self.assertEqual(captured.err, "") @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None) def test_tqdm_enabled(self) -> None: """Test TQDM work normally when globally enabled.""" enable_progress_bars() for _ in tqdm(range(10)): pass captured = self.capsys.readouterr() self.assertEqual(captured.out, "") self.assertIn("10/10", captured.err) # tqdm log def test_tqdm_stream_file(self) -> None: with SoftTemporaryDirectory() as tmpdir: filepath = Path(tmpdir) / "config.json" with filepath.open("w") as f: f.write("#" * 1000) with tqdm_stream_file(filepath) as f: while True: data = f.read(100) if not data: break time.sleep(0.001) # Simulate a delay between each chunk captured = self.capsys.readouterr() self.assertEqual(captured.out, "") self.assertIn("config.json: 100%", captured.err) # log file name self.assertIn("|█████████", captured.err) # tqdm bar self.assertIn("1.00k/1.00k", captured.err) # size in B class TestTqdmGroup(CapsysBaseTest): def setUp(self): """Set up the initial condition for each test.""" super().setUp() enable_progress_bars() # Ensure all are enabled before each test def tearDown(self): """Clean up after each test.""" super().tearDown() enable_progress_bars() @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None) def test_disable_specific_group(self): """Test disabling a specific group only affects that group and its subgroups.""" disable_progress_bars("peft.foo") assert not are_progress_bars_disabled("peft") assert not are_progress_bars_disabled("peft.something") assert are_progress_bars_disabled("peft.foo") assert are_progress_bars_disabled("peft.foo.bar") @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None) def test_enable_specific_subgroup(self): """Test that enabling a subgroup does not affect the disabled state of its parent.""" disable_progress_bars("peft.foo") enable_progress_bars("peft.foo.bar") 
assert are_progress_bars_disabled("peft.foo") assert not are_progress_bars_disabled("peft.foo.bar") @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", True) def test_disable_override_by_environment_variable(self): """Ensure progress bars are disabled regardless of local settings when environment variable is set.""" with self.assertWarns(UserWarning): enable_progress_bars() assert are_progress_bars_disabled("peft") assert are_progress_bars_disabled("peft.foo") @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", False) def test_enable_override_by_environment_variable(self): """Ensure progress bars are enabled regardless of local settings when environment variable is set.""" with self.assertWarns(UserWarning): disable_progress_bars("peft.foo") assert not are_progress_bars_disabled("peft.foo") @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None) def test_partial_group_name_not_affected(self): """Ensure groups with similar names but not exactly matching are not affected.""" disable_progress_bars("peft.foo") assert not are_progress_bars_disabled("peft.footprint") @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None) def test_nested_subgroup_behavior(self): """Test enabling and disabling nested subgroups.""" disable_progress_bars("peft") enable_progress_bars("peft.foo") disable_progress_bars("peft.foo.bar") assert are_progress_bars_disabled("peft") assert not are_progress_bars_disabled("peft.foo") assert are_progress_bars_disabled("peft.foo.bar") @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None) def test_empty_group_is_root(self): """Test the behavior with invalid or empty group names.""" disable_progress_bars("") assert not are_progress_bars_disabled("peft") enable_progress_bars("123.invalid.name") assert not are_progress_bars_disabled("123.invalid.name") @patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", None) def test_multiple_level_toggling(self): """Test multiple levels of enabling and disabling.""" disable_progress_bars("peft") enable_progress_bars("peft.foo") disable_progress_bars("peft.foo.bar.something") assert are_progress_bars_disabled("peft") assert not are_progress_bars_disabled("peft.foo") assert are_progress_bars_disabled("peft.foo.bar.something") def test_progress_bar_respects_group(self) -> None: disable_progress_bars("foo.bar") for _ in tqdm(range(10), name="foo.bar.something"): pass captured = self.capsys.readouterr() assert captured.out == "" assert captured.err == "" enable_progress_bars("foo.bar.something") for _ in tqdm(range(10), name="foo.bar.something"): pass captured = self.capsys.readouterr() assert captured.out == "" assert "10/10" in captured.err huggingface_hub-0.31.1/tests/test_utils_typing.py000066400000000000000000000065301500667546600222310ustar00rootroot00000000000000import json import sys from typing import Optional, Type, Union import pytest from huggingface_hub.utils._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type class NotSerializableClass: pass class CustomType: pass OBJ_WITH_CIRCULAR_REF = {"hello": "world"} OBJ_WITH_CIRCULAR_REF["recursive"] = OBJ_WITH_CIRCULAR_REF @pytest.mark.parametrize( "data", [ 123, # 3.14, "Hello, world!", True, None, [], [1, 2, 3], [(1, 2.0, "string"), True], {}, {"name": "Alice", "age": 30}, {0: "LABEL_0", 1.0: "LABEL_1"}, ], ) def test_is_jsonable_success(data): assert is_jsonable(data) json.dumps(data) @pytest.mark.parametrize( "data", [ set([1, 2, 3]), lambda x: x + 1, 
NotSerializableClass(), {"obj": NotSerializableClass()}, OBJ_WITH_CIRCULAR_REF, ], ) def test_is_jsonable_failure(data): assert not is_jsonable(data) with pytest.raises((TypeError, ValueError)): json.dumps(data) @pytest.mark.parametrize( "type_, is_optional", [ (Optional[int], True), (Union[None, int], True), (Union[int, None], True), (Optional[CustomType], True), (Union[None, CustomType], True), (Union[CustomType, None], True), (int, False), (None, False), (Union[int, float, None], False), (Union[Union[int, float], None], False), (Optional[Union[int, float]], False), ], ) def test_is_simple_optional_type(type_: Type, is_optional: bool): assert is_simple_optional_type(type_) is is_optional @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") @pytest.mark.parametrize( "type_, is_optional", [ ("int | None", True), ("None | int", True), ("CustomType | None", True), ("None | CustomType", True), ("int | float", False), ("int | float | None", False), ("(int | float) | None", False), ("Union[int, float] | None", False), ], ) def test_is_simple_optional_type_pipe(type_: str, is_optional: bool): assert is_simple_optional_type(eval(type_)) is is_optional @pytest.mark.parametrize( "optional_type, inner_type", [ (Optional[int], int), (Union[int, None], int), (Union[None, int], int), (Optional[CustomType], CustomType), (Union[CustomType, None], CustomType), (Union[None, CustomType], CustomType), ], ) def test_unwrap_simple_optional_type(optional_type: Type, inner_type: Type): assert unwrap_simple_optional_type(optional_type) is inner_type @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") @pytest.mark.parametrize( "optional_type, inner_type", [ ("None | int", int), ("int | None", int), ("None | CustomType", CustomType), ("CustomType | None", CustomType), ], ) def test_unwrap_simple_optional_type_pipe(optional_type: str, inner_type: Type): assert unwrap_simple_optional_type(eval(optional_type)) is inner_type @pytest.mark.parametrize("non_optional_type", [int, None, CustomType]) def test_unwrap_simple_optional_type_fail(non_optional_type: Type): with pytest.raises(ValueError): unwrap_simple_optional_type(non_optional_type) huggingface_hub-0.31.1/tests/test_utils_validators.py000066400000000000000000000104221500667546600230620ustar00rootroot00000000000000import unittest from pathlib import Path from unittest.mock import Mock, patch from huggingface_hub.utils import ( HFValidationError, smoothly_deprecate_use_auth_token, validate_hf_hub_args, validate_repo_id, ) @patch("huggingface_hub.utils._validators.validate_repo_id") class TestHfHubValidator(unittest.TestCase): """Test `validate_hf_hub_args` decorator calls all default validators.""" def test_validate_repo_id_as_arg(self, validate_repo_id_mock: Mock) -> None: """Test `validate_repo_id` is called when `repo_id` is passed as arg.""" self.dummy_function(123) validate_repo_id_mock.assert_called_once_with(123) def test_validate_repo_id_as_kwarg(self, validate_repo_id_mock: Mock) -> None: """Test `validate_repo_id` is called when `repo_id` is passed as kwarg.""" self.dummy_function(repo_id=123) validate_repo_id_mock.assert_called_once_with(123) @staticmethod @validate_hf_hub_args def dummy_function(repo_id: str) -> None: pass class TestRepoIdValidator(unittest.TestCase): VALID_VALUES = ( "123", "foo", "foo/bar", "Foo-BAR_foo.bar123", ) NOT_VALID_VALUES = ( Path("foo/bar"), # Must be a string "a" * 100, # Too long "datasets/foo/bar", # Repo_type forbidden in repo_id ".repo_id", # Cannot 
start with . "repo_id.", # Cannot end with . "foo--bar", # Cannot contain "--" "foo..bar", # Cannot contain "." "foo.git", # Cannot end with ".git" ) def test_valid_repo_ids(self) -> None: """Test `repo_id` validation on valid values.""" for repo_id in self.VALID_VALUES: validate_repo_id(repo_id) def test_not_valid_repo_ids(self) -> None: """Test `repo_id` validation on not valid values.""" for repo_id in self.NOT_VALID_VALUES: with self.assertRaises(HFValidationError, msg=f"'{repo_id}' must not be valid"): validate_repo_id(repo_id) class TestSmoothlyDeprecateUseAuthToken(unittest.TestCase): def test_token_normal_usage_as_arg(self) -> None: self.assertEqual( self.dummy_token_function("this_is_a_token"), ("this_is_a_token", {}), ) def test_token_normal_usage_as_kwarg(self) -> None: self.assertEqual( self.dummy_token_function(token="this_is_a_token"), ("this_is_a_token", {}), ) def test_token_normal_usage_with_more_kwargs(self) -> None: self.assertEqual( self.dummy_token_function(token="this_is_a_token", foo="bar"), ("this_is_a_token", {"foo": "bar"}), ) def test_token_with_smoothly_deprecated_use_auth_token(self) -> None: self.assertEqual( self.dummy_token_function(use_auth_token="this_is_a_use_auth_token"), ("this_is_a_use_auth_token", {}), ) def test_input_kwargs_not_mutated_by_smooth_deprecation(self) -> None: initial_kwargs = {"a": "b", "use_auth_token": "token"} kwargs = smoothly_deprecate_use_auth_token(fn_name="name", has_token=False, kwargs=initial_kwargs) self.assertEqual(kwargs, {"a": "b", "token": "token"}) self.assertEqual(initial_kwargs, {"a": "b", "use_auth_token": "token"}) # not mutated! def test_with_both_token_and_use_auth_token(self) -> None: with self.assertWarns(UserWarning): # `use_auth_token` is ignored ! self.assertEqual( self.dummy_token_function(token="this_is_a_token", use_auth_token="this_is_a_use_auth_token"), ("this_is_a_token", {}), ) def test_not_deprecated_use_auth_token(self) -> None: # `use_auth_token` is accepted by `dummy_use_auth_token_function` # => `smoothly_deprecate_use_auth_token` is not called self.assertEqual( self.dummy_use_auth_token_function(use_auth_token="this_is_a_use_auth_token"), ("this_is_a_use_auth_token", {}), ) @staticmethod @validate_hf_hub_args def dummy_token_function(token: str, **kwargs) -> None: return token, kwargs @staticmethod @validate_hf_hub_args def dummy_use_auth_token_function(use_auth_token: str, **kwargs) -> None: return use_auth_token, kwargs huggingface_hub-0.31.1/tests/test_webhooks_server.py000066400000000000000000000237371500667546600227160ustar00rootroot00000000000000import unittest from unittest.mock import patch from fastapi import Request from huggingface_hub.utils import capture_output, is_gradio_available from .testing_utils import requires if is_gradio_available(): import gradio as gr from fastapi.testclient import TestClient import huggingface_hub._webhooks_server from huggingface_hub import WebhookPayload, WebhooksServer # Taken from https://huggingface.co/docs/hub/webhooks#event WEBHOOK_PAYLOAD_CREATE_DISCUSSION = { "event": {"action": "create", "scope": "discussion"}, "repo": { "type": "model", "name": "gpt2", "id": "621ffdc036468d709f17434d", "private": False, "url": {"web": "https://huggingface.co/gpt2", "api": "https://huggingface.co/api/models/gpt2"}, "owner": {"id": "628b753283ef59b5be89e937"}, }, "discussion": { "id": "6399f58518721fdd27fc9ca9", "title": "Update co2 emissions", "url": { "web": "https://huggingface.co/gpt2/discussions/19", "api": 
"https://huggingface.co/api/models/gpt2/discussions/19", }, "status": "open", "author": {"id": "61d2f90c3c2083e1c08af22d"}, "num": 19, "isPullRequest": True, "changes": {"base": "refs/heads/main"}, }, "comment": { "id": "6399f58518721fdd27fc9caa", "author": {"id": "61d2f90c3c2083e1c08af22d"}, "content": "Add co2 emissions information to the model card", "hidden": False, "url": {"web": "https://huggingface.co/gpt2/discussions/19#6399f58518721fdd27fc9caa"}, }, "webhook": {"id": "6390e855e30d9209411de93b", "version": 3}, } WEBHOOK_PAYLOAD_UPDATE_DISCUSSION = { # valid payload but doesn't have a "comment" value "event": {"action": "update", "scope": "discussion"}, "repo": { "type": "space", "name": "Wauplin/leaderboard", "id": "656896965808298301ed7ccf", "private": False, "url": { "web": "https://huggingface.co/spaces/Wauplin/leaderboard", "api": "https://huggingface.co/api/spaces/Wauplin/leaderboard", }, "owner": {"id": "6273f303f6d63a28483fde12"}, }, "discussion": { "id": "656a0dfcadba74cd5ef4545b", "title": "Update space_ci/webhook.py", "url": { "web": "https://huggingface.co/spaces/Wauplin/leaderboard/discussions/4", "api": "https://huggingface.co/api/spaces/Wauplin/leaderboard/discussions/4", }, "status": "closed", "author": {"id": "6273f303f6d63a28483fde12"}, "num": 4, "isPullRequest": True, "changes": {"base": "refs/heads/main"}, }, "webhook": {"id": "656a05348c99518820a4dd54", "version": 3}, } WEBHOOK_PAYLOAD_WITH_UPDATED_REFS = { "event": {"action": "update", "scope": "repo.content"}, "repo": { "type": "space", "name": "Wauplin/gradio-user-history", "id": "651311c46de9c503f3f34a9e", "private": False, "subdomain": "wauplin-gradio-user-history", "url": { "web": "https://huggingface.co/spaces/Wauplin/gradio-user-history", "api": "https://huggingface.co/api/spaces/Wauplin/gradio-user-history", }, "headSha": "5e7f29fffcc579cb52539fddb14a1a4f85f39e44", "owner": { "id": "6273f303f6d63a28483fde12", }, }, "webhook": { "id": "65a14fd933eca76f4639fc84", "version": 3, }, "updatedRefs": [ { "ref": "refs/pr/5", "oldSha": None, "newSha": "227c78346870a85e5de4fff8a585db68df975406", } ], } def test_deserialize_payload_example_with_comment() -> None: """Confirm that the test stub can actually be deserialized.""" payload = WebhookPayload.model_validate(WEBHOOK_PAYLOAD_CREATE_DISCUSSION) assert payload.event.scope == WEBHOOK_PAYLOAD_CREATE_DISCUSSION["event"]["scope"] assert payload.comment is not None assert payload.comment.content == "Add co2 emissions information to the model card" def test_deserialize_payload_example_without_comment() -> None: """Confirm that the test stub can actually be deserialized.""" payload = WebhookPayload.model_validate(WEBHOOK_PAYLOAD_UPDATE_DISCUSSION) assert payload.event.scope == WEBHOOK_PAYLOAD_UPDATE_DISCUSSION["event"]["scope"] assert payload.comment is None def test_deserialize_payload_example_with_updated_refs() -> None: """Confirm that the test stub can actually be deserialized.""" payload = WebhookPayload.model_validate(WEBHOOK_PAYLOAD_WITH_UPDATED_REFS) assert payload.updatedRefs is not None assert payload.updatedRefs[0].ref == "refs/pr/5" assert payload.updatedRefs[0].oldSha is None assert payload.updatedRefs[0].newSha == "227c78346870a85e5de4fff8a585db68df975406" @requires("gradio") class TestWebhooksServerDontRun(unittest.TestCase): def test_add_webhook_implicit_path(self): # Test adding a webhook app = WebhooksServer() @app.add_webhook async def handler(): pass self.assertIn("/webhooks/handler", app.registered_webhooks) def 
test_add_webhook_explicit_path(self): # Test adding a webhook app = WebhooksServer() @app.add_webhook(path="/test_webhook") async def handler(): pass self.assertIn("/webhooks/test_webhook", app.registered_webhooks) # still registered under /webhooks def test_add_webhook_twice_should_fail(self): # Test adding a webhook app = WebhooksServer() @app.add_webhook("my_webhook") async def test_webhook(): pass # Registering twice the same webhook should raise an error with self.assertRaises(ValueError): @app.add_webhook("my_webhook") async def test_webhook_2(): pass @requires("gradio") class TestWebhooksServerRun(unittest.TestCase): HEADERS_VALID_SECRET = {"x-webhook-secret": "my_webhook_secret"} HEADERS_WRONG_SECRET = {"x-webhook-secret": "wrong_webhook_secret"} def setUp(self) -> None: with gr.Blocks() as ui: gr.Markdown("Hello World!") app = WebhooksServer(ui=ui, webhook_secret="my_webhook_secret") # Route to check payload parsing @app.add_webhook async def test_webhook(payload: WebhookPayload) -> None: return {"scope": payload.event.scope} # Routes to check secret validation # Checks all 4 cases (async/sync, with/without request parameter) @app.add_webhook async def async_with_request(request: Request) -> None: return {"success": True} @app.add_webhook def sync_with_request(request: Request) -> None: return {"success": True} @app.add_webhook async def async_no_request() -> None: return {"success": True} @app.add_webhook def sync_no_request() -> None: return {"success": True} # Route to check explicit path @app.add_webhook(path="/explicit_path") async def with_explicit_path() -> None: return {"success": True} self.ui = ui self.app = app self.client = self.mocked_run_app() def tearDown(self) -> None: self.ui.server.close() def mocked_run_app(self) -> "TestClient": with patch.object(self.ui, "block_thread"): # Run without blocking with patch.object(huggingface_hub._webhooks_server, "_is_local", False): # Run without tunnel self.app.launch() return TestClient(self.app.fastapi_app) def test_run_print_instructions(self): """Test that the instructions are printed when running the app.""" # Test running the app with capture_output() as output: self.mocked_run_app() instructions = output.getvalue() assert "Webhooks are correctly setup and ready to use:" in instructions assert "- POST http://127.0.0.1:" in instructions # port is usually 7860 but can be dynamic assert "/webhooks/test_webhook" in instructions def test_run_parse_payload(self): """Test that the payload is correctly parsed when running the app.""" response = self.client.post( "/webhooks/test_webhook", headers=self.HEADERS_VALID_SECRET, json=WEBHOOK_PAYLOAD_CREATE_DISCUSSION ) self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), {"scope": "discussion"}) def test_with_webhook_secret_should_succeed(self): """Test success if valid secret is sent.""" for path in ["async_with_request", "sync_with_request", "async_no_request", "sync_no_request"]: with self.subTest(path): response = self.client.post(f"/webhooks/{path}", headers=self.HEADERS_VALID_SECRET) self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), {"success": True}) def test_no_webhook_secret_should_be_unauthorized(self): """Test failure if valid secret is sent.""" for path in ["async_with_request", "sync_with_request", "async_no_request", "sync_no_request"]: with self.subTest(path): response = self.client.post(f"/webhooks/{path}") self.assertEqual(response.status_code, 401) def test_wrong_webhook_secret_should_be_forbidden(self): """Test 
failure if a wrong secret is sent."""
        for path in ["async_with_request", "sync_with_request", "async_no_request", "sync_no_request"]:
            with self.subTest(path):
                response = self.client.post(f"/webhooks/{path}", headers=self.HEADERS_WRONG_SECRET)
                self.assertEqual(response.status_code, 403)

    def test_route_with_explicit_path(self):
        """Test that the route with an explicit path is correctly registered."""
        response = self.client.post("/webhooks/explicit_path", headers=self.HEADERS_VALID_SECRET)
        self.assertEqual(response.status_code, 200)
huggingface_hub-0.31.1/tests/test_windows.py000066400000000000000000000007311500667546600211660ustar00rootroot00000000000000"""Contains tests that are specific to windows machines."""

import os
import unittest

from huggingface_hub.file_download import are_symlinks_supported


def require_windows(test_case):
    if os.name != "nt":
        return unittest.skip("test specific to Windows machines")(test_case)
    else:
        return test_case


@require_windows
class WindowsTests(unittest.TestCase):
    def test_are_symlink_supported(self) -> None:
        self.assertFalse(are_symlinks_supported())
huggingface_hub-0.31.1/tests/test_xet_download.py000066400000000000000000000277571500667546600222040ustar00rootroot00000000000000import os
from contextlib import contextmanager
from pathlib import Path
from typing import Tuple
from unittest.mock import DEFAULT, Mock, patch

from huggingface_hub import snapshot_download
from huggingface_hub.file_download import (
    HfFileMetadata,
    get_hf_file_metadata,
    hf_hub_download,
    hf_hub_url,
    try_to_load_from_cache,
    xet_get,
)
from huggingface_hub.utils import (
    XetConnectionInfo,
    XetFileData,
    refresh_xet_connection_info,
)

from .testing_utils import (
    DUMMY_XET_FILE,
    DUMMY_XET_MODEL_ID,
    requires,
    with_production_testing,
)


@requires("hf_xet")
@with_production_testing
class TestXetFileDownload:
    @contextmanager
    def _patch_xet_file_metadata(self, with_xet_data: bool):
        patcher = patch("huggingface_hub.file_download.get_hf_file_metadata")
        mock_metadata = patcher.start()
        mock_metadata.return_value = HfFileMetadata(
            commit_hash="mock_commit",
            etag="mock_etag",
            location="mock_location",
            size=1024,
            xet_file_data=XetFileData(file_hash="mock_hash", refresh_route="mock/route") if with_xet_data else None,
        )
        try:
            yield mock_metadata
        finally:
            patcher.stop()

    @contextmanager
    def _patch_get_refresh_xet_connection_info(self):
        patcher = patch("huggingface_hub.utils.refresh_xet_connection_info")
        connection_info = XetConnectionInfo(
            endpoint="mock_endpoint",
            access_token="mock_token",
            expiration_unix_epoch=9999999999,
        )
        mock_xet_connection = patcher.start()
        mock_xet_connection.return_value = connection_info
        try:
            yield mock_xet_connection
        finally:
            patcher.stop()

    def test_xet_get_called_when_xet_metadata_present(self, tmp_path):
        """Test that xet_get is called when xet metadata is present."""
        with self._patch_xet_file_metadata(with_xet_data=True) as mock_file_metadata:
            with self._patch_get_refresh_xet_connection_info():
                with patch("huggingface_hub.file_download.xet_get") as mock_xet_get:
                    with patch("huggingface_hub.file_download._create_symlink"):
                        hf_hub_download(
                            DUMMY_XET_MODEL_ID,
                            filename=DUMMY_XET_FILE,
                            cache_dir=tmp_path,
                            force_download=True,
                        )

                        # Verify xet_get was called with correct parameters
                        mock_xet_get.assert_called_once()
                        _, kwargs = mock_xet_get.call_args
                        assert "xet_file_data" in kwargs
                        assert kwargs["xet_file_data"] == mock_file_metadata.return_value.xet_file_data

    def test_backward_compatibility_no_xet_metadata(self, tmp_path):
        """Test backward compatibility when response has no xet
metadata.""" with self._patch_xet_file_metadata(with_xet_data=False): with patch("huggingface_hub.file_download.http_get") as mock_http_get: with patch("huggingface_hub.file_download._create_symlink"): hf_hub_download( DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, cache_dir=tmp_path, force_download=True, ) # Verify http_get was called mock_http_get.assert_called_once() def test_get_xet_file_metadata_basic(self) -> None: """Test getting metadata from a file on the Hub.""" url = hf_hub_url( repo_id=DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, ) metadata = get_hf_file_metadata(url) assert metadata.xet_file_data is not None assert metadata.xet_file_data.file_hash is not None connection_info = refresh_xet_connection_info(file_data=metadata.xet_file_data, headers={}) assert connection_info is not None assert connection_info.endpoint is not None assert connection_info.access_token is not None assert isinstance(connection_info.expiration_unix_epoch, int) def test_basic_download(self, tmp_path): # Make sure that xet_get is called with patch("huggingface_hub.file_download.xet_get", wraps=xet_get) as _xet_get: filepath = hf_hub_download( DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, cache_dir=tmp_path, ) assert os.path.exists(filepath) assert os.path.getsize(filepath) > 0 _xet_get.assert_called_once() def test_try_to_load_from_cache(self, tmp_path): cached_path = try_to_load_from_cache( DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, cache_dir=tmp_path, ) assert cached_path is None downloaded_path = hf_hub_download( DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, cache_dir=tmp_path, ) # Now should find it in cache cached_path = try_to_load_from_cache( DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, cache_dir=tmp_path, ) assert cached_path == downloaded_path def test_cache_reuse(self, tmp_path): path1 = hf_hub_download( DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, cache_dir=tmp_path, ) assert os.path.exists(path1) with patch("huggingface_hub.file_download._download_to_tmp_and_move") as mock: # Second download should use cache path2 = hf_hub_download( DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, cache_dir=tmp_path, ) assert path1 == path2 mock.assert_not_called() def test_download_to_local_dir(self, tmp_path): local_dir = tmp_path / "local_dir" local_dir.mkdir(exist_ok=True, parents=True) cache_dir = tmp_path / "cache" cache_dir.mkdir(exist_ok=True, parents=True) # Download to local dir returned_path = hf_hub_download( DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, local_dir=local_dir, cache_dir=cache_dir, ) assert local_dir in Path(returned_path).parents for path in cache_dir.glob("**/blobs/**"): assert not path.is_file() for path in cache_dir.glob("**/snapshots/**"): assert not path.is_file() def test_force_download(self, tmp_path): # First download path1 = hf_hub_download( DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, cache_dir=tmp_path, ) # Force download should re-download even if in cache with patch("huggingface_hub.file_download.xet_get") as mock_xet_get: path2 = hf_hub_download( DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, cache_dir=tmp_path, force_download=True, ) assert path1 == path2 mock_xet_get.assert_called_once() def test_fallback_to_http_when_xet_not_available(self, tmp_path): """Test that http_get is used when hf_xet is not available.""" with self._patch_xet_file_metadata(with_xet_data=True): with self._patch_get_refresh_xet_connection_info(): # Mock is_xet_available to return False with patch.multiple( "huggingface_hub.file_download", is_xet_available=Mock(return_value=False), http_get=DEFAULT, 
xet_get=DEFAULT, _create_symlink=DEFAULT, ) as mocks: hf_hub_download( DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, cache_dir=tmp_path, force_download=True, ) # Verify http_get was called and xet_get was not mocks["http_get"].assert_called_once() mocks["xet_get"].assert_not_called() def test_use_xet_when_available(self, tmp_path): """Test that xet_get is used when hf_xet is available.""" with self._patch_xet_file_metadata(with_xet_data=True): with self._patch_get_refresh_xet_connection_info(): with patch.multiple( "huggingface_hub.file_download", is_xet_available=Mock(return_value=True), http_get=DEFAULT, xet_get=DEFAULT, _create_symlink=DEFAULT, ) as mocks: hf_hub_download( DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, cache_dir=tmp_path, force_download=True, ) # Verify xet_get was called and http_get was not mocks["xet_get"].assert_called_once() mocks["http_get"].assert_not_called() @requires("hf_xet") @with_production_testing class TestXetSnapshotDownload: def test_download_model(self, tmp_path): """Test that snapshot_download works with Xet storage.""" storage_folder = snapshot_download( DUMMY_XET_MODEL_ID, cache_dir=tmp_path, ) assert os.path.exists(storage_folder) assert os.path.isdir(storage_folder) assert os.path.exists(os.path.join(storage_folder, DUMMY_XET_FILE)) with open(os.path.join(storage_folder, DUMMY_XET_FILE), "rb") as f: content = f.read() assert len(content) > 0 def test_snapshot_download_cache_reuse(self, tmp_path): """Test that snapshot_download reuses cached files.""" # First download storage_folder1 = snapshot_download( DUMMY_XET_MODEL_ID, cache_dir=tmp_path, ) with patch("huggingface_hub.file_download.xet_get") as mock_xet_get: # Second download should use cache storage_folder2 = snapshot_download( DUMMY_XET_MODEL_ID, cache_dir=tmp_path, ) # Verify same folder is returned assert storage_folder1 == storage_folder2 # Verify xet_get was not called (files were cached) mock_xet_get.assert_not_called() def test_download_backward_compatibility(self, tmp_path): """Test that xet download works with the old pointer file protocol. Until the next major version of hf-xet is released, we need to support the old pointer file based download to support old huggingface_hub versions. """ file_path = os.path.join(tmp_path, DUMMY_XET_FILE) file_metadata = get_hf_file_metadata( hf_hub_url( repo_id=DUMMY_XET_MODEL_ID, filename=DUMMY_XET_FILE, ) ) xet_file_data = file_metadata.xet_file_data # Mock the response to not include xet metadata from hf_xet import PyPointerFile, download_files connection_info = refresh_xet_connection_info(file_data=xet_file_data, headers={}) def token_refresher() -> Tuple[str, int]: connection_info = refresh_xet_connection_info(file_data=xet_file_data, headers={}) return connection_info.access_token, connection_info.expiration_unix_epoch pointer_files = [PyPointerFile(path=file_path, hash=xet_file_data.file_hash, filesize=file_metadata.size)] download_files( pointer_files, endpoint=connection_info.endpoint, token_info=(connection_info.access_token, connection_info.expiration_unix_epoch), token_refresher=token_refresher, progress_updater=None, ) assert os.path.exists(file_path) huggingface_hub-0.31.1/tests/test_xet_upload.py000066400000000000000000000347251500667546600216520ustar00rootroot00000000000000# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager from io import BytesIO from pathlib import Path from typing import Tuple from unittest.mock import MagicMock, patch import pytest from huggingface_hub import HfApi, RepoUrl from huggingface_hub._commit_api import _upload_lfs_files, _upload_xet_files from huggingface_hub.file_download import ( _get_metadata_or_catch_error, get_hf_file_metadata, hf_hub_download, hf_hub_url, ) from huggingface_hub.utils import build_hf_headers, refresh_xet_connection_info from .testing_constants import ENDPOINT_STAGING, TOKEN from .testing_utils import repo_name, requires @contextmanager def assert_upload_mode(mode: str): if mode not in ("xet", "lfs"): raise ValueError("Mode must be either 'xet' or 'lfs'") with patch("huggingface_hub.hf_api._upload_xet_files", wraps=_upload_xet_files) as mock_xet: with patch("huggingface_hub.hf_api._upload_lfs_files", wraps=_upload_lfs_files) as mock_lfs: yield assert mock_xet.called == (mode == "xet"), ( f"Expected {'XET' if mode == 'xet' else 'LFS'} upload to be used" ) assert mock_lfs.called == (mode == "lfs"), ( f"Expected {'LFS' if mode == 'lfs' else 'XET'} upload to be used" ) @pytest.fixture(scope="module") def api(): return HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) @pytest.fixture def repo_url(api, repo_type: str = "model"): repo_url = api.create_repo(repo_id=repo_name(prefix=repo_type), repo_type=repo_type) api.update_repo_settings(repo_id=repo_url.repo_id, xet_enabled=True) yield repo_url api.delete_repo(repo_id=repo_url.repo_id, repo_type=repo_type) @requires("hf_xet") class TestXetUpload: @pytest.fixture(autouse=True) def setup(self, tmp_path): self.folder_path = tmp_path # Create a regular text file text_file = self.folder_path / "text_file.txt" self.text_content = "This is a regular text file" text_file.write_text(self.text_content) # Create a binary file self.bin_file = self.folder_path / "binary_file.bin" self.bin_content = b"0" * (1 * 1024 * 1024) self.bin_file.write_bytes(self.bin_content) # Create nested directory structure nested_dir = self.folder_path / "nested" nested_dir.mkdir() # Create a nested text file nested_text_file = nested_dir / "nested_text.txt" self.nested_text_content = "This is a nested text file" nested_text_file.write_text(self.nested_text_content) # Create a nested binary file nested_bin_file = nested_dir / "nested_binary.safetensors" self.nested_bin_content = b"1" * (1 * 1024 * 1024) nested_bin_file.write_bytes(self.nested_bin_content) def test_upload_file(self, api, tmp_path, repo_url): filename_in_repo = "binary_file.bin" repo_id = repo_url.repo_id with assert_upload_mode("xet"): return_val = api.upload_file( path_or_fileobj=self.bin_file, path_in_repo=filename_in_repo, repo_id=repo_id, ) assert return_val == f"{api.endpoint}/{repo_id}/blob/main/{filename_in_repo}" # Download and verify content downloaded_file = hf_hub_download(repo_id=repo_id, filename=filename_in_repo, cache_dir=tmp_path) with open(downloaded_file, "rb") as f: downloaded_content = f.read() assert downloaded_content == self.bin_content # Check xet metadata url = hf_hub_url( repo_id=repo_id, filename=filename_in_repo, ) 
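        # For context (illustrative values only, not fetched from the Hub): a xet-backed file is
        # advertised through response headers such as "X-Xet-Hash" and a refresh route, which
        # `get_hf_file_metadata` surfaces as an `XetFileData` instance, roughly of the shape:
        #     XetFileData(file_hash="sha256:abcdef", refresh_route=".../xet-read-token/...")
        # The assertions below rely only on the fields checked explicitly, not on these sample values.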
metadata = get_hf_file_metadata(url) assert metadata.xet_file_data is not None xet_connection = refresh_xet_connection_info(file_data=metadata.xet_file_data, headers={}) assert xet_connection is not None def test_upload_file_with_bytesio(self, api, tmp_path, repo_url): repo_id = repo_url.repo_id content = BytesIO(self.bin_content) with assert_upload_mode("lfs"): api.upload_file( path_or_fileobj=content, path_in_repo="bytesio_file.bin", repo_id=repo_id, ) # Download and verify content downloaded_file = hf_hub_download(repo_id=repo_id, filename="bytesio_file.bin", cache_dir=tmp_path) with open(downloaded_file, "rb") as f: downloaded_content = f.read() assert downloaded_content == self.bin_content def test_upload_file_with_byte_array(self, api, tmp_path, repo_url): repo_id = repo_url.repo_id content = self.bin_content with assert_upload_mode("xet"): api.upload_file( path_or_fileobj=content, path_in_repo="bytearray_file.bin", repo_id=repo_id, ) # Download and verify content downloaded_file = hf_hub_download(repo_id=repo_id, filename="bytearray_file.bin", cache_dir=tmp_path) with open(downloaded_file, "rb") as f: downloaded_content = f.read() assert downloaded_content == self.bin_content def test_fallback_to_lfs_when_xet_not_available(self, api, repo_url): repo_id = repo_url.repo_id with patch("huggingface_hub.hf_api.is_xet_available", return_value=False): with assert_upload_mode("lfs"): api.upload_file( path_or_fileobj=self.bin_file, path_in_repo="fallback_file.bin", repo_id=repo_id, ) def test_upload_based_on_xet_enabled_setting(self, api, repo_url): repo_id = repo_url.repo_id # Test when xet is enabled -> use Xet upload with patch("huggingface_hub.hf_api.HfApi.repo_info") as mock_repo_info: mock_repo_info.return_value.xet_enabled = True with assert_upload_mode("xet"): api.upload_file( path_or_fileobj=self.bin_file, path_in_repo="xet_enabled.bin", repo_id=repo_id, ) # Test when xet is disabled -> use LFS upload with patch("huggingface_hub.hf_api.HfApi.repo_info") as mock_repo_info: mock_repo_info.return_value.xet_enabled = False with assert_upload_mode("lfs"): api.upload_file( path_or_fileobj=self.bin_file, path_in_repo="xet_disabled.bin", repo_id=repo_id, ) def test_upload_folder(self, api, repo_url): repo_id = repo_url.repo_id folder_in_repo = "temp" with assert_upload_mode("xet"): return_val = api.upload_folder( folder_path=self.folder_path, path_in_repo=folder_in_repo, repo_id=repo_id, ) assert return_val == f"{api.endpoint}/{repo_id}/tree/main/{folder_in_repo}" files_in_repo = set(api.list_repo_files(repo_id=repo_id)) files = { f"{folder_in_repo}/text_file.txt", f"{folder_in_repo}/binary_file.bin", f"{folder_in_repo}/nested/nested_text.txt", f"{folder_in_repo}/nested/nested_binary.safetensors", } assert all(file in files_in_repo for file in files) for rpath in files: local_file = Path(rpath).relative_to(folder_in_repo) local_path = self.folder_path / local_file filepath = hf_hub_download(repo_id=repo_id, filename=rpath) assert Path(local_path).read_bytes() == Path(filepath).read_bytes() def test_upload_folder_create_pr(self, api, repo_url) -> None: repo_id = repo_url.repo_id folder_in_repo = "temp_create_pr" with assert_upload_mode("xet"): return_val = api.upload_folder( folder_path=self.folder_path, path_in_repo=folder_in_repo, repo_id=repo_id, create_pr=True, ) assert return_val == f"{api.endpoint}/{repo_id}/tree/refs%2Fpr%2F1/{folder_in_repo}" for rpath in ["text_file.txt", "nested/nested_binary.safetensors"]: local_path = self.folder_path / rpath filepath = hf_hub_download( 
repo_id=repo_id, filename=f"{folder_in_repo}/{rpath}", revision=return_val.pr_revision ) assert Path(local_path).read_bytes() == Path(filepath).read_bytes() @requires("hf_xet") class TestXetLargeUpload: def test_upload_large_folder(self, api, tmp_path, repo_url: RepoUrl) -> None: N_FILES_PER_FOLDER = 4 repo_id = repo_url.repo_id folder = Path(tmp_path) / "large_folder" for i in range(N_FILES_PER_FOLDER): subfolder = folder / f"subfolder_{i}" subfolder.mkdir(parents=True, exist_ok=True) for j in range(N_FILES_PER_FOLDER): (subfolder / f"file_xet_{i}_{j}.bin").write_bytes(f"content_lfs_{i}_{j}".encode()) (subfolder / f"file_regular_{i}_{j}.txt").write_bytes(f"content_regular_{i}_{j}".encode()) with assert_upload_mode("xet"): api.upload_large_folder(repo_id=repo_id, repo_type="model", folder_path=folder, num_workers=4) # Check all files have been uploaded uploaded_files = api.list_repo_files(repo_id=repo_id) # Download and verify content local_dir = Path(tmp_path) / "snapshot" local_dir.mkdir() api.snapshot_download(repo_id=repo_id, local_dir=local_dir, cache_dir=None) for i in range(N_FILES_PER_FOLDER): for j in range(N_FILES_PER_FOLDER): assert f"subfolder_{i}/file_xet_{i}_{j}.bin" in uploaded_files assert f"subfolder_{i}/file_regular_{i}_{j}.txt" in uploaded_files # Check xet metadata url = hf_hub_url( repo_id=repo_id, filename=f"subfolder_{i}/file_xet_{i}_{j}.bin", ) metadata = get_hf_file_metadata(url) xet_filedata = metadata.xet_file_data assert xet_filedata is not None # Verify xet files xet_file = local_dir / f"subfolder_{i}/file_xet_{i}_{j}.bin" assert xet_file.read_bytes() == f"content_lfs_{i}_{j}".encode() # Verify regular files regular_file = local_dir / f"subfolder_{i}/file_regular_{i}_{j}.txt" assert regular_file.read_bytes() == f"content_regular_{i}_{j}".encode() @requires("hf_xet") class TestXetE2E(TestXetUpload): def test_hf_xet_with_token_refresher(self, api, tmp_path, repo_url): """ Test the hf_xet.download_files function with a token refresher. This test manually calls the hf_xet.download_files function with a token refresher function to verify that the token refresh mechanism works as expected. It aims to identify regressions in the hf_xet.download_files function. * Define a token refresher function that issues a token refresh by returning a new access token and expiration time. * Mock the token refresher function. * Construct the necessary headers and metadata for the file to be downloaded. * Call the download_files function with the token refresher, forcing a token refresh. * Assert that the token refresher function was called as expected. This test ensures that the downloaded file is the same as the uploaded file. """ from hf_xet import PyXetDownloadInfo, download_files filename_in_repo = "binary_file.bin" repo_id = repo_url.repo_id # Upload a file api.upload_file( path_or_fileobj=self.bin_file, path_in_repo=filename_in_repo, repo_id=repo_id, ) # headers headers = build_hf_headers(token=TOKEN) # metadata for url (url_to_download, etag, commit_hash, expected_size, xet_filedata, head_call_error) = ( _get_metadata_or_catch_error( repo_id=repo_id, filename=filename_in_repo, revision="main", repo_type="model", headers=headers, endpoint=api.endpoint, token=TOKEN, proxies=None, etag_timeout=None, local_files_only=False, ) ) xet_connection_info = refresh_xet_connection_info(file_data=xet_filedata, headers=headers) # manually construct parameters to hf_xet.download_files and use a locally defined token_refresher function # to verify that token refresh works as expected. 
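        # Sketch of the contract exercised below (assumption spelled out for readability): `token_info`
        # is an (access_token, expiration_unix_epoch) tuple, and `token_refresher` is a zero-argument
        # callable returning a tuple of the same shape. hf_xet invokes the refresher once the provided
        # expiration has passed, which is why an expiration of 0 is passed further down to force a refresh.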
def token_refresher() -> Tuple[str, int]: # Issue a token refresh by returning a new access token and expiration time new_connection = refresh_xet_connection_info(file_data=xet_filedata, headers=headers) return new_connection.access_token, new_connection.expiration_unix_epoch mock_token_refresher = MagicMock(side_effect=token_refresher) incomplete_path = Path(tmp_path) / "file.bin.incomplete" file_info = [ PyXetDownloadInfo( destination_path=str(incomplete_path.absolute()), hash=xet_filedata.file_hash, file_size=expected_size ) ] # Call the download_files function with the token refresher, set expiration to 0 forcing a refresh download_files( file_info, endpoint=xet_connection_info.endpoint, token_info=(xet_connection_info.access_token, 0), token_refresher=mock_token_refresher, progress_updater=None, ) # assert that our local token_refresher function was called by hfxet as expected. mock_token_refresher.assert_called_once() # Check that the downloaded file is the same as the uploaded file with open(incomplete_path, "rb") as f: downloaded_content = f.read() assert downloaded_content == self.bin_content huggingface_hub-0.31.1/tests/test_xet_utils.py000066400000000000000000000213221500667546600215130ustar00rootroot00000000000000from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch from huggingface_hub import constants from huggingface_hub.utils._xet import ( XetFileData, _fetch_xet_connection_info_with_url, parse_xet_connection_info_from_headers, parse_xet_file_data_from_response, refresh_xet_connection_info, ) def test_parse_valid_headers_file_info() -> None: mock_response = MagicMock() mock_response.headers = { "X-Xet-Hash": "sha256:abcdef", "X-Xet-Refresh-Route": "/api/refresh", } mock_response.links = {} file_data = parse_xet_file_data_from_response(mock_response) assert file_data is not None assert file_data.refresh_route == "/api/refresh" assert file_data.file_hash == "sha256:abcdef" def test_parse_valid_headers_file_info_with_link() -> None: mock_response = MagicMock() mock_response.headers = { "X-Xet-Hash": "sha256:abcdef", } mock_response.links = { "xet-auth": {"url": "/api/refresh"}, } file_data = parse_xet_file_data_from_response(mock_response) assert file_data is not None assert file_data.refresh_route == "/api/refresh" assert file_data.file_hash == "sha256:abcdef" def test_parse_invalid_headers_file_info() -> None: mock_response = MagicMock() mock_response.headers = {"X-foo": "bar"} mock_response.links = {} assert parse_xet_file_data_from_response(mock_response) is None def test_parse_valid_headers_connection_info() -> None: headers = { "X-Xet-Cas-Url": "https://xet.example.com", "X-Xet-Access-Token": "xet_token_abc", "X-Xet-Token-Expiration": "1234567890", } connection_info = parse_xet_connection_info_from_headers(headers) assert connection_info is not None assert connection_info.endpoint == "https://xet.example.com" assert connection_info.access_token == "xet_token_abc" assert connection_info.expiration_unix_epoch == 1234567890 def test_parse_valid_headers_full() -> None: mock_response = MagicMock() mock_response.headers = { "X-Xet-Cas-Url": "https://xet.example.com", "X-Xet-Access-Token": "xet_token_abc", "X-Xet-Token-Expiration": "1234567890", "X-Xet-Refresh-Route": "/api/refresh", "X-Xet-Hash": "sha256:abcdef", } mock_response.links = {} file_metadata = parse_xet_file_data_from_response(mock_response) connection_info = parse_xet_connection_info_from_headers(mock_response.headers) assert file_metadata is not None assert 
file_metadata.refresh_route == "/api/refresh" assert file_metadata.file_hash == "sha256:abcdef" assert connection_info is not None assert connection_info.endpoint == "https://xet.example.com" assert connection_info.access_token == "xet_token_abc" assert connection_info.expiration_unix_epoch == 1234567890 @pytest.mark.parametrize( "missing_key", [ "X-Xet-Cas-Url", "X-Xet-Access-Token", "X-Xet-Token-Expiration", ], ) def test_parse_missing_required_header(missing_key: str) -> None: headers = { "X-Xet-Cas-Url": "https://xet.example.com", "X-Xet-Access-Token": "xet_token_abc", "X-Xet-Token-Expiration": "1234567890", } # Remove the key to test headers.pop(missing_key) connection_info = parse_xet_connection_info_from_headers(headers) assert connection_info is None def test_parse_invalid_expiration() -> None: """Test parsing headers with invalid expiration format returns None.""" headers = { "X-Xet-Cas-Url": "https://xet.example.com", "X-Xet-Access-Token": "xet_token_abc", "X-Xet-Token-Expiration": "not-a-number", } connection_info = parse_xet_connection_info_from_headers(headers) assert connection_info is None def test_refresh_metadata_success(mocker) -> None: # Mock headers for the refreshed response mock_response = MagicMock() mock_response.headers = { "X-Xet-Cas-Url": "https://example.xethub.hf.co", "X-Xet-Access-Token": "new_token", "X-Xet-Token-Expiration": "1234599999", "X-Xet-Refresh-Route": f"{constants.ENDPOINT}/api/models/username/repo_name/xet-read-token/token", } mock_session = MagicMock() mock_session.get.return_value = mock_response mocker.patch("huggingface_hub.utils._xet.get_session", return_value=mock_session) headers = {"user-agent": "user-agent-example"} refreshed_connection = refresh_xet_connection_info( file_data=XetFileData( refresh_route=f"{constants.ENDPOINT}/api/models/username/repo_name/xet-read-token/token", file_hash="sha256:abcdef", ), headers=headers, ) # Verify the request expected_url = f"{constants.ENDPOINT}/api/models/username/repo_name/xet-read-token/token" mock_session.get.assert_called_once_with( headers=headers, url=expected_url, params=None, ) assert refreshed_connection.endpoint == "https://example.xethub.hf.co" assert refreshed_connection.access_token == "new_token" assert refreshed_connection.expiration_unix_epoch == 1234599999 def test_refresh_metadata_custom_endpoint(mocker) -> None: custom_endpoint = "https://custom.xethub.hf.co" # Mock headers for the refreshed response mock_response = MagicMock() mock_response.headers = { "X-Xet-Cas-Url": "https://custom.xethub.hf.co", "X-Xet-Access-Token": "new_token", "X-Xet-Token-Expiration": "1234599999", } mock_session = MagicMock() mock_session.get.return_value = mock_response mocker.patch("huggingface_hub.utils._xet.get_session", return_value=mock_session) headers = {"user-agent": "user-agent-example"} refresh_xet_connection_info( file_data=XetFileData( refresh_route=f"{custom_endpoint}/api/models/username/repo_name/xet-read-token/token", file_hash="sha256:abcdef", ), headers=headers, ) # Verify the request used the custom endpoint expected_url = f"{custom_endpoint}/api/models/username/repo_name/xet-read-token/token" mock_session.get.assert_called_once_with( headers=headers, url=expected_url, params=None, ) def test_refresh_metadata_missing_refresh_route() -> None: # Create metadata without refresh_route headers = {"user-agent": "user-agent-example"} # Verify it raises ValueError with pytest.raises(ValueError, match="The provided xet metadata does not contain a refresh endpoint."): 
refresh_xet_connection_info( file_data=XetFileData( refresh_route=None, file_hash="sha256:abcdef", ), headers=headers, ) def test_fetch_xet_metadata_with_url(mocker) -> None: mock_response = MagicMock() mock_response.headers = { "X-Xet-Cas-Url": "https://example.xethub.hf.co", "X-Xet-Access-Token": "xet_token123", "X-Xet-Token-Expiration": "1234567890", } # Mock the session.get method mock_session = MagicMock() mock_session.get.return_value = mock_response mocker.patch("huggingface_hub.utils._xet.get_session", return_value=mock_session) # Call the function url = "https://example.xethub.hf.co/api/models/username/repo_name/xet-read-token/token" headers = {"user-agent": "user-agent-example"} metadata = _fetch_xet_connection_info_with_url(url=url, headers=headers) # Verify the request mock_session.get.assert_called_once_with( headers=headers, url=url, params=None, ) # Verify returned metadata assert metadata.endpoint == "https://example.xethub.hf.co" assert metadata.access_token == "xet_token123" assert metadata.expiration_unix_epoch == 1234567890 def test_fetch_xet_metadata_with_url_invalid_response(mocker) -> None: mock_response = MagicMock() mock_response.headers = {"Content-Type": "application/json"} # No XET headers # Mock the session.get method mock_session = MagicMock() mock_session.get.return_value = mock_response mocker.patch("huggingface_hub.utils._xet.get_session", return_value=mock_session) url = "https://example.xethub.hf.co/api/models/username/repo_name/xet-read-token/token" headers = {"user-agent": "user-agent-example"} with pytest.raises(ValueError, match="Xet headers have not been correctly set by the server."): _fetch_xet_connection_info_with_url(url=url, headers=headers) def test_env_var_hf_hub_disable_xet() -> None: """Test that setting HF_HUB_DISABLE_XET results in is_xet_available() returning False.""" from huggingface_hub.utils._runtime import is_xet_available monkeypatch = MonkeyPatch() monkeypatch.setenv("HF_HUB_DISABLE_XET", "1") assert not is_xet_available() huggingface_hub-0.31.1/tests/testing_constants.py000066400000000000000000000015071500667546600222100ustar00rootroot00000000000000USER = "__DUMMY_TRANSFORMERS_USER__" FULL_NAME = "Dummy User" PASS = "__DUMMY_TRANSFORMERS_PASS__" # Not critical, only usable on the sandboxed CI instance. TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL" # Used to create repos that we don't own (example: for gated repo) # Token is not critical. 
Also public in https://github.com/huggingface/datasets-server OTHER_USER = "DVUser" OTHER_TOKEN = "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" # Used to test enterprise features, typically creating private repos by default ENTERPRISE_USER = "EnterpriseAdmin" ENTERPRISE_ORG = "EnterpriseOrgPrivate" ENTERPRISE_TOKEN = "hf_enterprise_admin_token" ENDPOINT_PRODUCTION = "https://huggingface.co" ENDPOINT_STAGING = "https://hub-ci.huggingface.co" ENDPOINT_PRODUCTION_URL_SCHEME = ENDPOINT_PRODUCTION + "/{repo_id}/resolve/{revision}/{filename}" huggingface_hub-0.31.1/tests/testing_utils.py000066400000000000000000000367601500667546600213450ustar00rootroot00000000000000import inspect import os import shutil import stat import time import unittest import uuid from contextlib import contextmanager from enum import Enum from functools import wraps from pathlib import Path from typing import Callable, Optional, Type, TypeVar, Union from unittest.mock import Mock, patch import pytest import requests from huggingface_hub.utils import is_package_available, logging, reset_sessions from tests.testing_constants import ENDPOINT_PRODUCTION, ENDPOINT_PRODUCTION_URL_SCHEME logger = logging.get_logger(__name__) SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer" # Example model ids # An actual model hosted on huggingface.co, # w/ more details. DUMMY_MODEL_ID = "julien-c/dummy-unknown" DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2" # One particular commit (not the top of `main`) DUMMY_MODEL_ID_REVISION_INVALID = "aaaaaaa" # This commit does not exist, so we should 404. DUMMY_MODEL_ID_PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684" # Sha-1 of config.json on the top of `main`, for checking purposes DUMMY_MODEL_ID_PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3" # Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes # "hf-internal-testing/dummy-will-be-renamed" has been renamed to "hf-internal-testing/dummy-renamed" DUMMY_RENAMED_OLD_MODEL_ID = "hf-internal-testing/dummy-will-be-renamed" DUMMY_RENAMED_NEW_MODEL_ID = "hf-internal-testing/dummy-renamed" SAMPLE_DATASET_IDENTIFIER = "lhoestq/custom_squad" # Example dataset ids DUMMY_DATASET_ID = "gaia-benchmark/GAIA" DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT = "c603981e170e9e333934a39781d2ae3a2677e81f" # on branch "test-branch" YES = ("y", "yes", "t", "true", "on", "1") NO = ("n", "no", "f", "false", "off", "0") # Xet testing DUMMY_XET_MODEL_ID = "celinah/dummy-xet-testing" DUMMY_XET_FILE = "dummy.safetensors" DUMMY_XET_REGULAR_FILE = "dummy.txt" # extra large file for testing on production DUMMY_EXTRA_LARGE_FILE_MODEL_ID = "brianronan/dummy-xet-edge-case-files" DUMMY_EXTRA_LARGE_FILE_NAME = "verylargemodel.safetensors" # > 50GB file DUMMY_TINY_FILE_NAME = "tiny.safetensors" # 45 byte file def repo_name(id: Optional[str] = None, prefix: str = "repo") -> str: """ Return a readable pseudo-unique repository name for tests. Example: ```py >>> repo_name() repo-2fe93f-16599646671840 >>> repo_name("my-space", prefix='space') space-my-space-16599481979701 """ if id is None: id = uuid.uuid4().hex[:6] ts = int(time.time() * 10e3) return f"{prefix}-{id}-{ts}" def parse_flag_from_env(key: str, default: bool = False) -> bool: try: value = os.environ[key] except KeyError: # KEY isn't set, default to `default`. return default # KEY is set, convert it to True or False. 
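    # Illustrative behaviour (hypothetical values): RUN_GIT_LFS_TESTS="yes" -> True,
    # RUN_GIT_LFS_TESTS="0" -> False, variable unset -> the provided `default`.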
    if value.lower() in YES:
        return True
    elif value.lower() in NO:
        return False
    else:
        # More values are supported, but let's keep the message simple.
        raise ValueError(f"If set, '{key}' must be one of {YES + NO}. Got '{value}'.")


def parse_int_from_env(key, default=None):
    try:
        value = os.environ[key]
    except KeyError:
        _value = default
    else:
        try:
            _value = int(value)
        except ValueError:
            raise ValueError("If set, {} must be an int.".format(key))
    return _value


_run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)


def require_git_lfs(test_case):
    """
    Decorator to mark tests that require git-lfs.

    git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS
    environment variable to a truthy value to run them.
    """
    if not _run_git_lfs_tests:
        return unittest.skip("test of git lfs workflow")(test_case)
    else:
        return test_case


def requires(package_name: str):
    """
    Decorator marking a test that requires the given package.

    These tests are skipped when the package isn't installed.
    """

    def _inner(test_case):
        if not is_package_available(package_name):
            return unittest.skip(f"Test requires '{package_name}'")(test_case)
        else:
            return test_case

    return _inner


class RequestWouldHangIndefinitelyError(Exception):
    pass


class OfflineSimulationMode(Enum):
    CONNECTION_FAILS = 0
    CONNECTION_TIMES_OUT = 1
    HF_HUB_OFFLINE_SET_TO_1 = 2


@contextmanager
def offline(mode=OfflineSimulationMode.CONNECTION_FAILS, timeout=1e-16):
    """
    Simulate offline mode.

    There are three offline simulation modes:

    CONNECTION_FAILS (default mode): a ConnectionError is raised for each network call.
        Connection errors are created by mocking socket.socket
    CONNECTION_TIMES_OUT: the connection hangs until it times out. The default timeout value is low (1e-16)
        to speed up the tests. Timeout errors are created by mocking requests.request
    HF_HUB_OFFLINE_SET_TO_1: the HF_HUB_OFFLINE_SET_TO_1 environment variable is set to 1. This makes the
        http/ftp calls of the library instantly fail and raise an OfflineModeEnabled error.
    """
    import socket

    from requests import request as online_request

    def timeout_request(method, url, **kwargs):
        # Change the url to an invalid url so that the connection hangs
        invalid_url = "https://10.255.255.1"
        if kwargs.get("timeout") is None:
            raise RequestWouldHangIndefinitelyError(
                f"Tried a call to {url} in offline mode with no timeout set. Please set a timeout."
) kwargs["timeout"] = timeout try: return online_request(method, invalid_url, **kwargs) except Exception as e: # The following changes in the error are just here to make the offline timeout error prettier e.request.url = url max_retry_error = e.args[0] max_retry_error.args = (max_retry_error.args[0].replace("10.255.255.1", f"OfflineMock[{url}]"),) e.args = (max_retry_error,) raise def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled.") if mode is OfflineSimulationMode.CONNECTION_FAILS: # inspired from https://stackoverflow.com/a/18601897 with patch("socket.socket", offline_socket): with patch("huggingface_hub.utils._http.get_session") as get_session_mock: get_session_mock.return_value = requests.Session() # not an existing one yield elif mode is OfflineSimulationMode.CONNECTION_TIMES_OUT: # inspired from https://stackoverflow.com/a/904609 with patch("requests.request", timeout_request): with patch("huggingface_hub.utils._http.get_session") as get_session_mock: get_session_mock().request = timeout_request yield elif mode is OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1: with patch("huggingface_hub.constants.HF_HUB_OFFLINE", True): reset_sessions() yield reset_sessions() else: raise ValueError("Please use a value from the OfflineSimulationMode enum.") def set_write_permission_and_retry(func, path, excinfo): os.chmod(path, stat.S_IWRITE) func(path) def rmtree_with_retry(path: Union[str, Path]) -> None: shutil.rmtree(path, onerror=set_write_permission_and_retry) def with_production_testing(func): file_download = patch("huggingface_hub.file_download.HUGGINGFACE_CO_URL_TEMPLATE", ENDPOINT_PRODUCTION_URL_SCHEME) hf_api = patch("huggingface_hub.constants.ENDPOINT", ENDPOINT_PRODUCTION) return hf_api(file_download(func)) def expect_deprecation(function_name: str): """ Decorator to flag tests that we expect to use deprecated arguments. Args: function_name (`str`): Name of the function that we expect to use in a deprecated way. NOTE: if a test is expected to warns FutureWarnings but is not, the test will fail. Context: over time, some arguments/methods become deprecated. In order to track deprecation in tests, we run pytest with flag `-Werror::FutureWarning`. In order to keep old tests during the deprecation phase (before removing the feature completely) without changing them internally, we can flag them with this decorator. See full discussion in https://github.com/huggingface/huggingface_hub/pull/952. This decorator works hand-in-hand with the `_deprecate_arguments` and `_deprecate_positional_args` decorators. Example ```py # in src/hub_mixins.py from .utils._deprecation import _deprecate_arguments @_deprecate_arguments(version="0.12", deprecated_args={"repo_url"}) def push_to_hub(...): (...) # in tests/test_something.py from .testing_utils import expect_deprecation class SomethingTest(unittest.TestCase): (...) @expect_deprecation("push_to_hub"): def test_push_to_hub_git_version(self): (...) push_to_hub(repo_url="something") <- Should warn with FutureWarnings (...) ``` """ def _inner_decorator(test_function: Callable) -> Callable: @wraps(test_function) def _inner_test_function(*args, **kwargs): with pytest.warns(FutureWarning, match=f".*'{function_name}'.*"): return test_function(*args, **kwargs) return _inner_test_function return _inner_decorator def xfail_on_windows(reason: str, raises: Optional[Type[Exception]] = None): """ Decorator to flag tests that we expect to fail on Windows. Will not raise an error if the expected error happens while running on Windows machine. 
If error is expected but does not happen, the test fails as well. Args: reason (`str`): Reason why it should fail. raises (`Type[Exception]`): The error type we except to happen. """ def _inner_decorator(test_function: Callable) -> Callable: return pytest.mark.xfail(os.name == "nt", reason=reason, raises=raises, strict=True, run=True)(test_function) return _inner_decorator T = TypeVar("T") def handle_injection(cls: T) -> T: """Handle mock injection for each test of a test class. When patching variables on a class level, only relevant mocks will be injected to the tests. This has 2 advantages: 1. There is no need to expect all mocks in test arguments when they are not needed. 2. Default mock injection append all mocks 1 by 1 to the test args. If the order of the patch calls or test argument is changed, it can lead to unexpected behavior. NOTE: `@handle_injection` has to be defined after the `@patch` calls. Example: ```py @patch("something.foo") @patch("something_else.foo.bar") # order doesn't matter @handle_injection # after @patch calls def TestHelloWorld(unittest.TestCase): def test_hello_foo(self, mock_foo: Mock) -> None: (...) def test_hello_bar(self, mock_bar: Mock) -> None (...) def test_hello_both(self, mock_foo: Mock, mock_bar: Mock) -> None: (...) ``` There are limitations with the current implementation: 1. All patched variables must have different names. Named injection will not work with both `@patch("something.foo")` and `@patch("something_else.foo")` patches. 2. Tests are expected to take only `self` and mock arguments. If it's not the case, this helper will fail. 3. Tests arguments must follow the `mock_{variable_name}` naming. Example: `@patch("something._foo")` -> `"mock__foo"`. 4. Tests arguments must be typed as `Mock`. If required, we can improve the current implementation in the future to mitigate those limitations. Based on: - https://stackoverflow.com/a/3467879 - https://stackoverflow.com/a/30764825 - https://stackoverflow.com/a/57115876 NOTE: this decorator is inspired from the fixture system from pytest. """ # Iterate over class functions and decorate tests # Taken from https://stackoverflow.com/a/3467879 # and https://stackoverflow.com/a/30764825 for name, fn in inspect.getmembers(cls): if name.startswith("test_"): setattr(cls, name, handle_injection_in_test(fn)) # Return decorated class return cls def handle_injection_in_test(fn: Callable) -> Callable: """ Handle injections at a test level. See `handle_injection` for more details. Example: ```py def TestHelloWorld(unittest.TestCase): @patch("something.foo") @patch("something_else.foo.bar") # order doesn't matter @handle_injection_in_test # after @patch calls def test_hello_foo(self, mock_foo: Mock) -> None: (...) ``` """ signature = inspect.signature(fn) parameters = signature.parameters @wraps(fn) def _inner(*args, **kwargs): assert kwargs == {} # Initialize new dict at least with `self`. assert len(args) > 0 assert len(parameters) > 0 new_kwargs = {"self": args[0]} # Check which mocks have been injected mocks = {} for value in args[1:]: assert isinstance(value, Mock) mock_name = "mock_" + value._extract_mock_name() mocks[mock_name] = value # Check which mocks are expected for name, parameter in parameters.items(): if name == "self": continue assert parameter.annotation is Mock assert name in mocks, ( f"Mock `{name}` not found for test `{fn.__name__}`. 
Available: {', '.join(sorted(mocks.keys()))}" ) new_kwargs[name] = mocks[name] # Run test only with a subset of mocks return fn(**new_kwargs) return _inner def use_tmp_repo(repo_type: str = "model") -> Callable[[T], T]: """ Test decorator to create a repo for the test and properly delete it afterward. TODO: could we make `_api`, `_user` and `_token` cleaner ? Example: ```py from huggingface_hub import RepoUrl from .testing_utils import use_tmp_repo class HfApiCommonTest(unittest.TestCase): _api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) @use_tmp_repo() def test_create_tag_on_model(self, repo_url: RepoUrl) -> None: (...) @use_tmp_repo("dataset") def test_create_tag_on_dataset(self, repo_url: RepoUrl) -> None: (...) ``` """ def _inner_use_tmp_repo(test_fn: T) -> T: @wraps(test_fn) def _inner(*args, **kwargs): self = args[0] assert isinstance(self, unittest.TestCase) create_repo_kwargs = {} if repo_type == "space": create_repo_kwargs["space_sdk"] = "gradio" repo_url = self._api.create_repo( repo_id=repo_name(prefix=repo_type), repo_type=repo_type, **create_repo_kwargs ) try: return test_fn(*args, **kwargs, repo_url=repo_url) finally: self._api.delete_repo(repo_id=repo_url.repo_id, repo_type=repo_type) return _inner return _inner_use_tmp_repo def assert_in_logs(caplog: pytest.LogCaptureFixture, expected_output): """Helper to check if a message appears in logs.""" log_text = "\n".join(record.message for record in caplog.records) assert expected_output in log_text, f"Expected '{expected_output}' not found in logs" huggingface_hub-0.31.1/utils/000077500000000000000000000000001500667546600160605ustar00rootroot00000000000000huggingface_hub-0.31.1/utils/_legacy_check_future_compatible_signatures.py000066400000000000000000000232451500667546600272350ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains a tool to add/check the definition of "async" methods of `HfApi` in `huggingface_hub.hf_api.py`. WARNING: this is a script kept to help with `@future_compatible` methods of `HfApi` but it is not 100% correct. Keeping it here for reference but it is not used in the CI/Makefile. What is done correctly: 1. Add "as_future" as argument to the method signature 2. Set Union[T, Future[T]] as return type to the method signature 3. Document "as_future" argument in the docstring of the method What is NOT done correctly: 1. Generated stubs are grouped at the top of the `HfApi` class. They must be copy-pasted (overload definition must be just before the method implementation) 2. `#type: ignore` must be adjusted in the first stub (if multiline definition) """ import argparse import inspect import os import re import tempfile from pathlib import Path from typing import Callable, NoReturn from ruff.__main__ import find_ruff_bin from huggingface_hub.hf_api import HfApi STUBS_SECTION_TEMPLATE = """ ### Stubs section start ### # This section contains stubs for the methods that are marked as `@future_compatible`. 
Those methods have a # different return type depending on the `as_future: bool` value. For better integrations with IDEs, we provide # stubs for both return types. The actual implementation of those methods is written below. # WARNING: this section have been generated automatically. Do not modify it manually. If you modify it manually, your # changes will be overwritten. To re-generate this section, run `make style` (or `python utils/check_future_compatible_signatures.py` # directly). # FAQ: # 1. Why should we have these? For better type annotation which helps with IDE features like autocompletion. # 2. Why not a separate `hf_api.pyi` file? Would require to re-defined all the existing annotations from `hf_api.py`. # 3. Why not at the end of the module? Because `@overload` methods must be defined first. # 4. Why not another solution? I'd be glad, but this is the "less worse" I could find. # For more details, see https://github.com/huggingface/huggingface_hub/pull/1458 {stubs} # WARNING: this section have been generated automatically. Do not modify it manually. If you modify it manually, your # changes will be overwritten. To re-generate this section, run `make style` (or `python utils/check_future_compatible_signatures.py` # directly). ### Stubs section end ### """ STUBS_SECTION_TEMPLATE_REGEX = re.compile(r"### Stubs section start ###.*### Stubs section end ###", re.DOTALL) AS_FUTURE_SIGNATURE_TEMPLATE = "as_future: bool = False" AS_FUTURE_DOCSTRING_TEMPLATE = """ as_future (`bool`, *optional*): Whether or not to run this method in the background. Background jobs are run sequentially without blocking the main thread. Passing `as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) object. Defaults to `False`.""" ARGS_DOCSTRING_REGEX = re.compile( """ ^[ ]{8}Args: # Match args section ... (.*?) # ... everything ... ^[ ]{8}\\S # ... until next section or end of docstring """, re.MULTILINE | re.IGNORECASE | re.VERBOSE | re.DOTALL, ) SIGNATURE_REGEX_FULL = re.compile(r"^\s*def.*?-> (.*?):", re.DOTALL | re.MULTILINE) SIGNATURE_REGEX_RETURN_TYPE = re.compile(r"-> (.*?):") SIGNATURE_REGEX_RETURN_TYPE_WITH_FUTURE = re.compile(r"-> Union\[(.*?), (.*?)\]:") HF_API_FILE_PATH = Path(__file__).parents[1] / "src" / "huggingface_hub" / "hf_api.py" HF_API_FILE_CONTENT = HF_API_FILE_PATH.read_text() def generate_future_compatible_method(method: Callable, method_source: str) -> str: # 1. Document `as_future` parameter if AS_FUTURE_DOCSTRING_TEMPLATE not in method_source: match = ARGS_DOCSTRING_REGEX.search(method_source) if match is None: raise ValueError(f"Could not find `Args` section in docstring of {method}.") args_docs = match.group(1).strip() method_source = method_source.replace(args_docs, args_docs + AS_FUTURE_DOCSTRING_TEMPLATE) # 2. Update signature # 2.a. Add `as_future` parameter if AS_FUTURE_SIGNATURE_TEMPLATE not in method_source: match = SIGNATURE_REGEX_FULL.search(method_source) if match is None: raise ValueError(f"Could not find signature of {method} in source.") method_source = method_source.replace( match.group(), match.group().replace(") ->", f" {AS_FUTURE_SIGNATURE_TEMPLATE}) ->"), 1 ) # 2.b. 
Update return value if "Future[" not in method_source: match = SIGNATURE_REGEX_RETURN_TYPE.search(method_source) if match is None: raise ValueError(f"Could not find return type of {method} in source.") base_type = match.group(1).strip() return_type = f"Union[{base_type}, Future[{base_type}]]" return_value_replaced = match.group().replace(match.group(1), return_type) method_source = method_source.replace(match.group(), return_value_replaced) # 3. Generate @overload stubs match = SIGNATURE_REGEX_FULL.search(method_source) if match is None: raise ValueError(f"Could not find signature of {method} in source.") method_sig = match.group() match = SIGNATURE_REGEX_RETURN_TYPE_WITH_FUTURE.search(method_sig) if match is None: raise ValueError(f"Could not find return type (with Future) of {method} in source.") no_future_return_type = match.group(1).strip() with_future_return_type = match.group(2).strip() # 3.a. Stub when `as_future=False` no_future_stub = " @overload\n" + method_sig no_future_stub = no_future_stub.replace(AS_FUTURE_SIGNATURE_TEMPLATE, "as_future: Literal[False] = ...") no_future_stub = SIGNATURE_REGEX_RETURN_TYPE.sub(rf"-> {no_future_return_type}:", no_future_stub) no_future_stub += " # type: ignore\n ..." # only the first stub requires "type: ignore" # 3.b. Stub when `as_future=True` with_future_stub = " @overload\n" + method_sig with_future_stub = with_future_stub.replace(AS_FUTURE_SIGNATURE_TEMPLATE, "as_future: Literal[True] = ...") with_future_stub = SIGNATURE_REGEX_RETURN_TYPE.sub(rf"-> {with_future_return_type}:", with_future_stub) with_future_stub += "\n ..." stubs_source = no_future_stub + "\n\n" + with_future_stub + "\n\n" # 4. All good! return method_source, stubs_source def generate_hf_api_module() -> str: raw_code = HF_API_FILE_CONTENT # Process all Future-compatible methods all_stubs_source = "" for _, method in inspect.getmembers(HfApi, predicate=inspect.isfunction): if not getattr(method, "is_future_compatible", False): continue source = inspect.getsource(method) method_source, stubs_source = generate_future_compatible_method(method, source) raw_code = raw_code.replace(source, method_source) all_stubs_source += "\n\n" + stubs_source # Generate code with stubs generated_code = STUBS_SECTION_TEMPLATE_REGEX.sub(STUBS_SECTION_TEMPLATE.format(stubs=all_stubs_source), raw_code) # Format (ruff) return format_generated_code(generated_code) def format_generated_code(code: str) -> str: """ Format some code with ruff. Cannot be done "on the fly" so we first save the code in a temporary file. """ # Format with ruff with tempfile.TemporaryDirectory() as tmpdir: filepath = Path(tmpdir) / "__init__.py" filepath.write_text(code) ruff_bin = find_ruff_bin() os.spawnv(os.P_WAIT, ruff_bin, ["ruff", str(filepath), "--fix", "--quiet"]) return filepath.read_text() def check_future_compatible_hf_api(update: bool) -> NoReturn: """Check that the code defining the threaded version of HfApi is up-to-date.""" # If expected `__init__.py` content is different, test fails. If '--update-init-file' # is used, `__init__.py` file is updated before the test fails. expected_content = generate_hf_api_module() if expected_content != HF_API_FILE_CONTENT: if update: with HF_API_FILE_PATH.open("w") as f: f.write(expected_content) print( "✅ Signature/docstring/annotations for Future-compatible methods have been updated in" " `./src/huggingface_hub/hf_api.py`.\n Please make sure the changes are accurate and commit them." 
) exit(0) else: print( "❌ Expected content mismatch for Future compatible methods in `./src/huggingface_hub/hf_api.py`.\n " " Please run `make style` or `python utils/check_future_compatible_signatures.py --update`." ) exit(1) print("✅ All good! (Future-compatible methods)") exit(0) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--update", action="store_true", help="Whether to override `./src/huggingface_hub/hf_api.py` if a change is detected.", ) args = parser.parse_args() check_future_compatible_hf_api(update=args.update) huggingface_hub-0.31.1/utils/check_all_variable.py000066400000000000000000000101071500667546600222030ustar00rootroot00000000000000# coding=utf-8 # Copyright 2025-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Script to check and update the __all__ variable for huggingface_hub/__init__.py.""" import argparse import re from pathlib import Path from typing import Dict, List, NoReturn from huggingface_hub import _SUBMOD_ATTRS INIT_FILE_PATH = Path(__file__).parents[1] / "src" / "huggingface_hub" / "__init__.py" def format_all_definition(submod_attrs: Dict[str, List[str]]) -> str: """ Generate a formatted static __all__ definition with grouped comments. """ all_attrs = sorted(attr for attrs in submod_attrs.values() for attr in attrs) lines = ["__all__ = ["] lines.extend(f' "{attr}",' for attr in all_attrs) lines.append("]") return "\n".join(lines) def parse_all_definition(content: str) -> List[str]: """ Extract the current __all__ contents from file content. This is preferred over "from huggingface_hub import __all__ as current_items" to handle case where __all__ is not defined or malformed in the file we want to be able to fix such issues rather than crash also, we are interested in the file content. """ match = re.search(r"__all__\s*=\s*\[(.*?)\]", content, re.DOTALL) if not match: return [] # Extract items while preserving order, properly cleaning whitespace and quotes return [ line.strip().strip("\",'") for line in match.group(1).split("\n") if line.strip() and not line.strip().startswith("#") ] def check_static_all(update: bool) -> NoReturn: """Check if __all__ is aligned with _SUBMOD_ATTRS or update it.""" content = INIT_FILE_PATH.read_text() new_all = format_all_definition(_SUBMOD_ATTRS) expected_items = sorted(attr for attrs in _SUBMOD_ATTRS.values() for attr in attrs) current_items = list(parse_all_definition(content)) if current_items == expected_items: print("✅ All good! 
the __all__ variable is up to date") exit(0) if update: all_pattern = re.compile(r"__all__\s*=\s*\[[^\]]*\]", re.MULTILINE | re.DOTALL) if all_pattern.search(content): new_content = all_pattern.sub(new_all, content) else: submod_attrs_pattern = re.compile(r"_SUBMOD_ATTRS\s*=\s*{[^}]*}", re.MULTILINE | re.DOTALL) match = submod_attrs_pattern.search(content) if not match: print("Error: _SUBMOD_ATTRS dictionary not found in `./src/huggingface_hub/__init__.py`.") exit(1) dict_end = match.end() new_content = content[:dict_end] + "\n\n\n" + new_all + "\n\n" + content[dict_end:] INIT_FILE_PATH.write_text(new_content) print( "✅ __all__ variable has been updated in `./src/huggingface_hub/__init__.py`." "\n Please make sure the changes are accurate and commit them." ) exit(0) else: print( "❌ Expected content mismatch in" " `./src/huggingface_hub/__init__.py`.\n It is most likely that" " a module was added to the `_SUBMOD_ATTRS` mapping and did not update the" " '__all__' variable.\n Please run `make style` or `python" " utils/check_all_variable.py --update`." ) exit(1) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--update", action="store_true", help="Whether to fix `./src/huggingface_hub/__init__.py` if a change is detected.", ) args = parser.parse_args() check_static_all(update=args.update) huggingface_hub-0.31.1/utils/check_contrib_list.py000066400000000000000000000075121500667546600222670ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains a tool to list contrib test suites automatically.""" import argparse import re from pathlib import Path from typing import NoReturn ROOT_DIR = Path(__file__).parent.parent CONTRIB_PATH = ROOT_DIR / "contrib" MAKEFILE_PATH = ROOT_DIR / "Makefile" WORKFLOW_PATH = ROOT_DIR / ".github" / "workflows" / "contrib-tests.yml" MAKEFILE_REGEX = re.compile(r"^CONTRIB_LIBS := .*$", flags=re.MULTILINE) WORKFLOW_REGEX = re.compile( r""" # First: match "contrib: [" (?P^\s{8}contrib:\s\[\n) # Match list of libs (\s{10}\".*\",\n)* # Finally: match trailing "]" (?P^\s{8}\]) """, flags=re.MULTILINE | re.VERBOSE, ) def check_contrib_list(update: bool) -> NoReturn: """List `contrib` test suites. 
Make sure `Makefile` and `.github/workflows/contrib-tests.yml` are consistent with the list.""" # List contrib test suites contrib_list = sorted( path.name for path in CONTRIB_PATH.glob("*") if path.is_dir() and not path.name.startswith("_") ) # Check Makefile is consistent with list makefile_content = MAKEFILE_PATH.read_text() makefile_expected_content = MAKEFILE_REGEX.sub(f"CONTRIB_LIBS := {' '.join(contrib_list)}", makefile_content) # Check workflow is consistent with list workflow_content = WORKFLOW_PATH.read_text() _substitute = "\n".join(f'{" " * 10}"{lib}",' for lib in contrib_list) workflow_content_expected = WORKFLOW_REGEX.sub(rf"\g<before>{_substitute}\n\g<after>", workflow_content) # failed = False if makefile_content != makefile_expected_content: if update: print( "✅ Contrib libs have been updated in `Makefile`." "\n Please make sure the changes are accurate and commit them." ) MAKEFILE_PATH.write_text(makefile_expected_content) else: print( "❌ Expected content mismatch in `Makefile`.\n It is most likely that" " you added a contrib test and did not update the Makefile.\n Please" " run `make style` or `python utils/check_contrib_list.py --update`." ) failed = True if workflow_content != workflow_content_expected: if update: print( f"✅ Contrib libs have been updated in `{WORKFLOW_PATH}`." "\n Please make sure the changes are accurate and commit them." ) WORKFLOW_PATH.write_text(workflow_content_expected) else: print( f"❌ Expected content mismatch in `{WORKFLOW_PATH}`.\n It is most" " likely that you added a contrib test and did not update the github" " workflow file.\n Please run `make style` or `python" " utils/check_contrib_list.py --update`." ) failed = True if failed: exit(1) print("✅ All good! (contrib list)") exit(0) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--update", action="store_true", help="Whether to fix Makefile and github workflow if a new lib is detected.", ) args = parser.parse_args() check_contrib_list(update=args.update) huggingface_hub-0.31.1/utils/check_inference_input_params.py000066400000000000000000000076521500667546600243170ustar00rootroot00000000000000# coding=utf-8 # Copyright 2024-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utility script to check consistency between input parameters of InferenceClient methods and generated types.
TODO: check all methods TODO: check parameters types TODO: check parameters default values TODO: check parameters (type, description) are consistent in the docstrings TODO: (low priority) automatically generate the input types from the methods """ import inspect from dataclasses import is_dataclass from typing import Any, get_args from huggingface_hub import InferenceClient from huggingface_hub.inference._generated import types METHODS_TO_SKIP = [ # + all private methods "post", "conversational", ] PARAMETERS_TO_SKIP = { "chat_completion": {}, "text_generation": { "stop", # stop_sequence instead (for legacy reasons) }, } UNDOCUMENTED_PARAMETERS = { "self", } def check_method(method_name: str, method: Any): input_type_name = "".join(part.capitalize() for part in method_name.split("_")) + "Input" if not hasattr(types, input_type_name): return [f"Missing input type for method {method_name}"] input_type = getattr(types, input_type_name) docstring = method.__doc__ if method_name == "chat_completion": # Special case for chat_completion parameters_type = input_type else: parameters_field = input_type.__dataclass_fields__.get("parameters", None) if parameters_field is None: return [f"Missing 'parameters' field for type {input_type}"] parameters_type = get_args(parameters_field.type)[0] if not is_dataclass(parameters_type): return [f"'parameters' field is not a dataclass for type {input_type} ({parameters_type})"] # For each expected parameter, check it is defined logs = [] method_params = inspect.signature(method).parameters for param_name in parameters_type.__dataclass_fields__: if param_name in PARAMETERS_TO_SKIP[method_name]: continue if param_name not in method_params: logs.append(f"Missing parameter {param_name} in method signature") # Check parameter is documented in docstring for param_name in method_params: if param_name in UNDOCUMENTED_PARAMETERS: continue if param_name in PARAMETERS_TO_SKIP[method_name]: continue if f" {param_name} (" not in docstring: logs.append(f"Parameter {param_name} is not documented") return logs # Inspect InferenceClient methods individually exit_code = 0 all_logs = [] # print details only if errors are found for method_name, method in inspect.getmembers(InferenceClient, predicate=inspect.isfunction): if method_name.startswith("_") or method_name in METHODS_TO_SKIP: continue if method_name not in PARAMETERS_TO_SKIP: all_logs.append(f" ⏩️ {method_name}: skipped") continue logs = check_method(method_name, method) if len(logs) > 0: exit_code = 1 all_logs.append(f" ❌ {method_name}: errors found") all_logs.append("\n".join(" " * 4 + log for log in logs)) continue else: all_logs.append(f" ✅ {method_name}: success!") continue if exit_code == 0: print("✅ All good! (inference inputs)") else: print("❌ Inconsistency found in inference inputs.") for log in all_logs: print(log) exit(exit_code) huggingface_hub-0.31.1/utils/check_static_imports.py000066400000000000000000000112571500667546600226410ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Contains a tool to reformat static imports in `huggingface_hub.__init__.py`.""" import argparse import os import re import tempfile from pathlib import Path from typing import NoReturn from ruff.__main__ import find_ruff_bin from huggingface_hub import _SUBMOD_ATTRS INIT_FILE_PATH = Path(__file__).parents[1] / "src" / "huggingface_hub" / "__init__.py" IF_TYPE_CHECKING_LINE = "\nif TYPE_CHECKING: # pragma: no cover\n" SUBMOD_ATTRS_PATTERN = re.compile("_SUBMOD_ATTRS = {[^}]+}") # match the all dict def check_static_imports(update: bool) -> NoReturn: """Check all imports are made twice (1 in lazy-loading and 1 in static checks). For more explanations, see `./src/huggingface_hub/__init__.py`. This script is used in the `make style` and `make quality` checks. """ with INIT_FILE_PATH.open() as f: init_content = f.read() # Get first half of the `__init__.py` file. # WARNING: Content after this part will be entirely re-generated which means # human-edited changes will be lost ! init_content_before_static_checks = init_content.split(IF_TYPE_CHECKING_LINE)[0] # Search and replace `_SUBMOD_ATTRS` dictionary definition. This ensures modules # and functions that can be lazy-loaded are alphabetically ordered for readability. if SUBMOD_ATTRS_PATTERN.search(init_content_before_static_checks) is None: print("Error: _SUBMOD_ATTRS dictionary definition not found in `./src/huggingface_hub/__init__.py`.") exit(1) _submod_attrs_definition = ( "_SUBMOD_ATTRS = {\n" + "\n".join( f' "{module}": [\n' + "\n".join(f' "{attr}",' for attr in sorted(set(_SUBMOD_ATTRS[module]))) + "\n ]," for module in sorted(set(_SUBMOD_ATTRS.keys())) ) + "\n}" ) reordered_content_before_static_checks = SUBMOD_ATTRS_PATTERN.sub( _submod_attrs_definition, init_content_before_static_checks ) # Generate the static imports given the `_SUBMOD_ATTRS` dictionary. static_imports = [ f" from .{module} import {attr} # noqa: F401" for module, attributes in _SUBMOD_ATTRS.items() for attr in attributes ] # Generate the expected `__init__.py` file content and apply formatter on it. with tempfile.TemporaryDirectory() as tmpdir: filepath = Path(tmpdir) / "__init__.py" filepath.write_text( reordered_content_before_static_checks + IF_TYPE_CHECKING_LINE + "\n".join(static_imports) + "\n" ) ruff_bin = find_ruff_bin() os.spawnv(os.P_WAIT, ruff_bin, ["ruff", "check", str(filepath), "--fix", "--quiet"]) os.spawnv(os.P_WAIT, ruff_bin, ["ruff", "format", str(filepath), "--quiet"]) expected_init_content = filepath.read_text() # If expected `__init__.py` content is different, test fails. If '--update-init-file' # is used, `__init__.py` file is updated before the test fails. if init_content != expected_init_content: if update: with INIT_FILE_PATH.open("w") as f: f.write(expected_init_content) print( "✅ Imports have been updated in `./src/huggingface_hub/__init__.py`." "\n Please make sure the changes are accurate and commit them." ) exit(0) else: print( "❌ Expected content mismatch in" " `./src/huggingface_hub/__init__.py`.\n It is most likely that you" " added a module/function to `_SUBMOD_ATTRS` and did not update the" " 'static import'-part.\n Please run `make style` or `python" " utils/check_static_imports.py --update`." ) exit(1) print("✅ All good! 
(static imports)") exit(0) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--update", action="store_true", help="Whether to fix `./src/huggingface_hub/__init__.py` if a change is detected.", ) args = parser.parse_args() check_static_imports(update=args.update) huggingface_hub-0.31.1/utils/check_task_parameters.py000066400000000000000000001035771500667546600227710ustar00rootroot00000000000000# coding=utf-8 # Copyright 2024-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Utility script to check and update the InferenceClient task methods arguments and docstrings based on the tasks input parameters. What this script does: - [x] detect missing parameters in method signature - [x] add missing parameters to methods signature - [x] detect missing parameters in method docstrings - [x] add missing parameters to methods docstrings - [x] detect outdated parameters in method signature - [x] update outdated parameters in method signature - [x] detect outdated parameters in method docstrings - [x] update outdated parameters in method docstrings - [ ] detect when parameter not used in method implementation - [ ] update method implementation when parameter not used Related resources: - https://github.com/huggingface/huggingface_hub/issues/2063 - https://github.com/huggingface/huggingface_hub/issues/2557 - https://github.com/huggingface/huggingface_hub/pull/2561 """ import argparse import builtins import inspect import re import textwrap from collections import defaultdict from pathlib import Path from typing import Any, Dict, List, NoReturn, Optional, Set, Tuple import libcst as cst from helpers import format_source_code from libcst.codemod import CodemodContext from libcst.codemod.visitors import GatherImportsVisitor from huggingface_hub import InferenceClient # Paths to project files BASE_DIR = Path(__file__).parents[1] / "src" / "huggingface_hub" INFERENCE_TYPES_PATH = BASE_DIR / "inference" / "_generated" / "types" INFERENCE_CLIENT_FILE = BASE_DIR / "inference" / "_client.py" DEFAULT_MODULE = "huggingface_hub.inference._generated.types" # Temporary solution to skip tasks where there is no Parameters dataclass or the schema needs to be updated TASKS_TO_SKIP = [ "chat_completion", "text_generation", "depth_estimation", "audio_to_audio", "feature_extraction", "sentence_similarity", "automatic_speech_recognition", "image_to_text", ] PARAMETERS_DATACLASS_REGEX = re.compile( r""" ^@dataclass_with_extra \nclass\s(\w+Parameters)\(BaseInferenceType\): """, re.VERBOSE | re.MULTILINE, ) CORE_PARAMETERS = { "model", # Model identifier "text", # Text input "image", # Image input "audio", # Audio input "inputs", # Generic inputs "input", # Generic input "prompt", # For generation tasks "question", # For QA tasks "context", # For QA tasks "labels", # For classification tasks "extra_body", # For extra parameters } #### NODE VISITORS (READING THE CODE) class DataclassFieldCollector(cst.CSTVisitor): """A visitor that collects fields (parameters) from a 
dataclass.""" def __init__(self, dataclass_name: str): self.dataclass_name = dataclass_name self.parameters: Dict[str, Dict[str, str]] = {} def visit_ClassDef(self, node: cst.ClassDef) -> None: """Visit class definitions to find the target dataclass.""" if node.name.value == self.dataclass_name: body_statements = node.body.body for index, field in enumerate(body_statements): # Check if the statement is a simple statement (like a variable declaration) if isinstance(field, cst.SimpleStatementLine): for stmt in field.body: # Check if it's an annotated assignment (typical for dataclass fields) if isinstance(stmt, cst.AnnAssign) and isinstance(stmt.target, cst.Name): param_name = stmt.target.value param_type = cst.Module([]).code_for_node(stmt.annotation.annotation) docstring = self._extract_docstring(body_statements, index) # Check if there's a default value has_default = stmt.value is not None default_value = cst.Module([]).code_for_node(stmt.value) if has_default else None self.parameters[param_name] = { "type": param_type, "docstring": docstring, "has_default": has_default, "default_value": default_value, } @staticmethod def _extract_docstring( body_statements: List[cst.CSTNode], field_index: int, ) -> str: """Extract the docstring following a field definition.""" if field_index + 1 < len(body_statements): # Check if the next statement is a simple statement (like a string) next_stmt = body_statements[field_index + 1] if isinstance(next_stmt, cst.SimpleStatementLine): for stmt in next_stmt.body: # Check if the statement is a string expression (potential docstring) if isinstance(stmt, cst.Expr) and isinstance(stmt.value, cst.SimpleString): return stmt.value.evaluated_value.strip() # No docstring found or there's no statement after the field return "" class ModulesCollector(cst.CSTVisitor): """Visitor that maps type names to their defining modules.""" def __init__(self): self.type_to_module = {} def visit_ClassDef(self, node: cst.ClassDef): """Map class definitions to the current module.""" self.type_to_module[node.name.value] = DEFAULT_MODULE def visit_ImportFrom(self, node: cst.ImportFrom): """Map imported types to their modules.""" if node.module: module_name = node.module.value for alias in node.names: self.type_to_module[alias.name.value] = module_name class MethodArgumentsCollector(cst.CSTVisitor): """Collects parameter types and docstrings from a method.""" def __init__(self, method_name: str): self.method_name = method_name self.parameters: Dict[str, Dict[str, str]] = {} def visit_FunctionDef(self, node: cst.FunctionDef) -> None: if node.name.value != self.method_name: return # Extract docstring docstring = self._extract_docstring(node) param_docs = self._parse_docstring_params(docstring) # Collect parameters for param in node.params.params + node.params.kwonly_params: if param.name.value == "self" or param.name.value in CORE_PARAMETERS: continue param_type = cst.Module([]).code_for_node(param.annotation.annotation) if param.annotation else "Any" self.parameters[param.name.value] = {"type": param_type, "docstring": param_docs.get(param.name.value, "")} def _extract_docstring(self, node: cst.FunctionDef) -> str: """Extract docstring from function node.""" if ( isinstance(node.body.body[0], cst.SimpleStatementLine) and isinstance(node.body.body[0].body[0], cst.Expr) and isinstance(node.body.body[0].body[0].value, cst.SimpleString) ): return node.body.body[0].body[0].value.evaluated_value return "" def _parse_docstring_params(self, docstring: str) -> Dict[str, str]: """Parse parameter 
descriptions from docstring.""" param_docs = {} lines = docstring.split("\n") # Find Args section args_idx = next((i for i, line in enumerate(lines) if line.strip().lower() == "args:"), None) if args_idx is None: return param_docs # Parse parameter descriptions current_param = None current_desc = [] for line in lines[args_idx + 1 :]: stripped_line = line.strip() if not stripped_line or stripped_line.lower() in ("returns:", "raises:", "example:", "examples:"): break if stripped_line.endswith(":"): # Parameter line if current_param: param_docs[current_param] = " ".join(current_desc) current_desc = [] # Extract only the parameter name before the first space or parenthesis current_param = re.split(r"\s|\(", stripped_line[:-1], 1)[0].strip() else: # Description line current_desc.append(stripped_line) if current_param: # Save last parameter param_docs[current_param] = " ".join(current_desc) return param_docs #### TREE TRANSFORMERS (UPDATING THE CODE) class AddImports(cst.CSTTransformer): """Transformer that adds import statements to the module.""" def __init__(self, imports_to_add: List[cst.BaseStatement]): self.imports_to_add = imports_to_add self.added = False def leave_Module( self, original_node: cst.Module, updated_node: cst.Module, ) -> cst.Module: """Insert the import statements into the module.""" # If imports were already added, don't add them again if self.added: return updated_node insertion_index = 0 # Find the index where to insert the imports: make sure the imports are inserted before any code and after all imports (not necessary, we can remove/simplify this part) for idx, stmt in enumerate(updated_node.body): if not isinstance(stmt, cst.SimpleStatementLine): insertion_index = idx break elif not isinstance(stmt.body[0], (cst.Import, cst.ImportFrom)): insertion_index = idx break # Insert the imports new_body = ( list(updated_node.body[:insertion_index]) + list(self.imports_to_add) + list(updated_node.body[insertion_index:]) ) self.added = True return updated_node.with_changes(body=new_body) class UpdateParameters(cst.CSTTransformer): """Updates a method's parameters, types, and docstrings.""" def __init__(self, method_name: str, param_updates: Dict[str, Dict[str, str]]): self.method_name = method_name self.param_updates = param_updates self.found_method = False # Flag to check if the method is found def leave_FunctionDef( self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef, ) -> cst.FunctionDef: # Only proceed if the current function is the target method if original_node.name.value != self.method_name: return updated_node self.found_method = True # Set the flag as the method is found # Update the parameters and docstring of the method new_params = self._update_parameters(updated_node.params) updated_body = self._update_docstring(updated_node.body) # Return the updated function definition return updated_node.with_changes(params=new_params, body=updated_body) def _update_parameters(self, params: cst.Parameters) -> cst.Parameters: """Update parameter types and add new parameters.""" new_params = list(params.params) # Copy regular parameters (e.g., 'self') new_kwonly_params = [] # Collect existing parameter names to avoid duplicates existing_params = {p.name.value for p in params.params + params.kwonly_params} # Update existing keyword-only parameters for param in params.kwonly_params: param_name = param.name.value if param_name in self.param_updates: # Update the type annotation for the parameter new_annotation = cst.Annotation( 
annotation=cst.parse_expression(self.param_updates[param_name]["type"]) ) new_kwonly_params.append(param.with_changes(annotation=new_annotation)) else: # Keep the parameter as is if no update is needed new_kwonly_params.append(param) # Add new parameters that are not already present for param_name, param_info in self.param_updates.items(): if param_name not in existing_params: # Create a new parameter with the provided type and a default value of None annotation = cst.Annotation(annotation=cst.parse_expression(param_info["type"])) new_param = cst.Param( name=cst.Name(param_name), annotation=annotation, default=cst.Name(param_info["default_value"]), ) new_kwonly_params.append(new_param) # Return the updated parameters object with new and updated parameters return params.with_changes(params=new_params, kwonly_params=new_kwonly_params) def _update_docstring(self, body: cst.IndentedBlock) -> cst.IndentedBlock: """Update parameter descriptions in the docstring.""" # Check if the first statement is a docstring if not ( isinstance(body.body[0], cst.SimpleStatementLine) and isinstance(body.body[0].body[0], cst.Expr) and isinstance(body.body[0].body[0].value, cst.SimpleString) ): # Return the body unchanged if no docstring is found return body docstring_expr = body.body[0].body[0] docstring = docstring_expr.value.evaluated_value # Get the docstring content # Update the docstring content with new and updated parameters updated_docstring = self._update_docstring_content(docstring) new_docstring = cst.SimpleString(f'"""{updated_docstring}"""') # Replace the old docstring with the updated one new_body = [body.body[0].with_changes(body=[docstring_expr.with_changes(value=new_docstring)])] + list( body.body[1:] ) # Return the updated function body return body.with_changes(body=new_body) def _update_docstring_content(self, docstring: str) -> str: """Update parameter descriptions in the docstring content.""" # Split parameters into new and updated ones based on their status new_params = {name: info for name, info in self.param_updates.items() if info["status"] == "new"} update_params = { name: info for name, info in self.param_updates.items() if info["status"] in ("update_type", "update_doc") } # Split the docstring into lines for processing docstring_lines = docstring.split("\n") # Find or create the "Args:" section and compute indentation levels args_index = next((i for i, line in enumerate(docstring_lines) if line.strip().lower() == "args:"), None) if args_index is None: # If 'Args:' section is not found, insert it before 'Returns:' or at the end insertion_index = next( ( i for i, line in enumerate(docstring_lines) if line.strip().lower() in ("returns:", "raises:", "examples:", "example:") ), len(docstring_lines), ) docstring_lines.insert(insertion_index, "Args:") args_index = insertion_index # Update the args_index with the new section base_indent = docstring_lines[args_index][: -len(docstring_lines[args_index].lstrip())] param_indent = base_indent + " " # Indentation for parameter lines desc_indent = param_indent + " " # Indentation for description lines # Update existing parameters in the docstring if update_params: docstring_lines, params_updated = self._process_existing_params( docstring_lines, update_params, args_index, param_indent, desc_indent ) # When params_updated is still not empty, it means there are new parameters that are not in the docstring # but are in the method signature new_params = {**new_params, **params_updated} # Add new parameters to the docstring if new_params: docstring_lines 
= self._add_new_params(docstring_lines, new_params, args_index, param_indent, desc_indent) # Join the docstring lines back into a single string return "\n".join(docstring_lines) def _format_param_docstring( self, param_name: str, param_info: Dict[str, str], param_indent: str, desc_indent: str, ) -> List[str]: """Format the docstring lines for a single parameter.""" # Extract and format the parameter type param_type = param_info["type"] if param_type.startswith("Optional["): param_type = param_type[len("Optional[") : -1] # Remove Optional[ and closing ] optional_str = ", *optional*" else: optional_str = "" # Create the parameter line with type and optionality param_line = f"{param_indent}{param_name} (`{param_type}`{optional_str}):" # Get and clean up the parameter description param_desc = (param_info.get("docstring") or "").strip() param_desc = " ".join(param_desc.split()) if param_desc: # Wrap the description text to maintain line width and indentation wrapped_desc = textwrap.fill( param_desc, width=119, initial_indent=desc_indent, subsequent_indent=desc_indent, ) return [param_line, wrapped_desc] else: # Return only the parameter line if there's no description return [param_line] def _process_existing_params( self, docstring_lines: List[str], params_to_update: Dict[str, Dict[str, str]], args_index: int, param_indent: str, desc_indent: str, ) -> Tuple[List[str], Dict[str, Dict[str, str]]]: """Update existing parameters in the docstring.""" # track the params that are updated params_updated = params_to_update.copy() i = args_index + 1 # Start after the 'Args:' section while i < len(docstring_lines): line = docstring_lines[i] stripped_line = line.strip() if not stripped_line: # Skip empty lines i += 1 continue if stripped_line.lower() in ("returns:", "raises:", "example:", "examples:"): # Stop processing if another section starts break if stripped_line.endswith(":"): # Check if the line is a parameter line param_line = stripped_line param_name = param_line.strip().split()[0] # Extract parameter name if param_name in params_updated: # Get the updated parameter info param_info = params_updated.pop(param_name) # Format the new parameter docstring param_doc_lines = self._format_param_docstring(param_name, param_info, param_indent, desc_indent) # Find the end of the current parameter's description start_idx = i end_idx = i + 1 while end_idx < len(docstring_lines): next_line = docstring_lines[end_idx] # Next parameter or section starts or another section starts or empty line if ( (next_line.strip().endswith(":") and not next_line.startswith(desc_indent)) or next_line.lower() in ("returns:", "raises:", "example:", "examples:") or not next_line ): break end_idx += 1 # Insert new param docs and preserve the rest of the docstring docstring_lines = ( docstring_lines[:start_idx] # Keep everything before + param_doc_lines # Insert new parameter docs + docstring_lines[end_idx:] # Keep everything after ) i = start_idx + len(param_doc_lines) # Update index to after inserted lines i += 1 else: i += 1 # Move to the next line if not a parameter line return docstring_lines, params_updated def _add_new_params( self, docstring_lines: List[str], new_params: Dict[str, Dict[str, str]], args_index: int, param_indent: str, desc_indent: str, ) -> List[str]: """Add new parameters to the docstring.""" # Find the insertion point after existing parameters insertion_index = args_index + 1 empty_line_index = None while insertion_index < len(docstring_lines): line = docstring_lines[insertion_index] stripped_line = 
line.strip() # Track empty line at the end of Args section if not stripped_line: if empty_line_index is None: # Remember first empty line empty_line_index = insertion_index insertion_index += 1 continue if stripped_line.lower() in ("returns:", "raises:", "example:", "examples:"): break empty_line_index = None # Reset if we find more content if stripped_line.endswith(":") and not line.startswith(desc_indent.strip()): insertion_index += 1 else: insertion_index += 1 # If we found an empty line at the end of the Args section, insert before it if empty_line_index is not None: insertion_index = empty_line_index # Prepare the new parameter documentation lines param_docs = [] for param_name, param_info in new_params.items(): param_doc_lines = self._format_param_docstring(param_name, param_info, param_indent, desc_indent) param_docs.extend(param_doc_lines) # Insert the new parameters into the docstring docstring_lines[insertion_index:insertion_index] = param_docs return docstring_lines #### UTILS def _check_parameters( inference_client_module: cst.Module, parameters_module: cst.Module, method_name: str, parameter_type_name: str, ) -> Dict[str, Dict[str, Any]]: """ Check for missing parameters and outdated types/docstrings. Args: inference_client_module: Module containing the InferenceClient parameters_module: Module containing the parameters dataclass method_name: Name of the method to check parameter_type_name: Name of the parameters dataclass Returns: Dict mapping parameter names to their updates: {param_name: { "type": str, # Type annotation "docstring": str, # Parameter documentation "status": "new"|"update_type"|"update_doc" # Whether parameter is new or needs update }} """ # Get parameters from the dataclass params_collector = DataclassFieldCollector(parameter_type_name) parameters_module.visit(params_collector) dataclass_params = params_collector.parameters # Get existing parameters from the method method_collector = MethodArgumentsCollector(method_name) inference_client_module.visit(method_collector) existing_params = method_collector.parameters updates = {} # Check for new and updated parameters for param_name, param_info in dataclass_params.items(): if param_name in CORE_PARAMETERS: continue if param_name not in existing_params: # New parameter updates[param_name] = {**param_info, "status": "new"} else: # Check for type/docstring changes current = existing_params[param_name] normalized_current_doc = _normalize_docstring(current["docstring"]) normalized_new_doc = _normalize_docstring(param_info["docstring"]) if current["type"] != param_info["type"]: updates[param_name] = {**param_info, "status": "update_type"} if normalized_current_doc != normalized_new_doc: updates[param_name] = {**param_info, "status": "update_doc"} return updates def _update_parameters( module: cst.Module, method_name: str, param_updates: Dict[str, Dict[str, str]], ) -> cst.Module: """ Update method parameters, types and docstrings. Args: module: The module to update method_name: Name of the method to update param_updates: Dictionary of parameter updates with their type and docstring Format: {param_name: {"type": str, "docstring": str, "status": "new"|"update_type"|"update_doc"}} Returns: Updated module """ transformer = UpdateParameters(method_name, param_updates) return module.visit(transformer) def _get_imports_to_add( parameters: Dict[str, Dict[str, str]], parameters_module: cst.Module, inference_client_module: cst.Module, ) -> Dict[str, List[str]]: """ Get the needed imports for missing parameters. 
Args: parameters (Dict[str, Dict[str, str]]): Dictionary of parameters with their type and docstring. eg: {"function_to_apply": {"type": "ClassificationOutputTransform", "docstring": "Function to apply to the input."}} parameters_module (cst.Module): The module where the parameters are defined. inference_client_module (cst.Module): The module of the inference client. Returns: Dict[str, List[str]]: A dictionary mapping modules to list of types to import. eg: {"huggingface_hub.inference._generated.types": ["ClassificationOutputTransform"]} """ # Collect all type names from parameter annotations types_to_import = set() for param_info in parameters.values(): types_to_import.update(_collect_type_hints_from_annotation(param_info["type"])) # Gather existing imports in the inference client module context = CodemodContext() gather_visitor = GatherImportsVisitor(context) inference_client_module.visit(gather_visitor) # Map types to their defining modules in the parameters module module_collector = ModulesCollector() parameters_module.visit(module_collector) # Determine which imports are needed needed_imports = {} for type_name in types_to_import: types_to_modules = module_collector.type_to_module module = types_to_modules.get(type_name, DEFAULT_MODULE) # Maybe no need to check that since the code formatter will handle duplicate imports? if module not in gather_visitor.object_mapping or type_name not in gather_visitor.object_mapping[module]: needed_imports.setdefault(module, []).append(type_name) return needed_imports def _generate_import_statements(import_dict: Dict[str, List[str]]) -> str: """ Generate import statements from a dictionary of needed imports. Args: import_dict (Dict[str, List[str]]): Dictionary mapping modules to list of types to import. eg: {"typing": ["List", "Dict"], "huggingface_hub.inference._generated.types": ["ClassificationOutputTransform"]} Returns: str: The import statements as a string. """ import_statements = [] for module, imports in import_dict.items(): if imports: import_list = ", ".join(imports) import_statements.append(f"from {module} import {import_list}") else: import_statements.append(f"import {module}") return "\n".join(import_statements) def _normalize_docstring(docstring: str) -> str: """Normalize a docstring by removing extra whitespace, newlines and indentation.""" # Split into lines, strip whitespace from each line, and join back return " ".join(line.strip() for line in docstring.split("\n")).strip() # TODO: Needs to be improved, maybe using `typing.get_type_hints` instead (we gonna need to access the method though)? def _collect_type_hints_from_annotation(annotation_str: str) -> Set[str]: """ Collect type hints from an annotation string. Args: annotation_str (str): The annotation string. Returns: Set[str]: A set of type hints. 
""" type_string = annotation_str.replace(" ", "") builtin_types = {d for d in dir(builtins) if isinstance(getattr(builtins, d), type)} types = re.findall(r"\w+|'[^']+'|\"[^\"]+\"", type_string) extracted_types = {t.strip("\"'") for t in types if t.strip("\"'") not in builtin_types} return extracted_types def _get_parameter_type_name(method_name: str) -> Optional[str]: file_path = INFERENCE_TYPES_PATH / f"{method_name}.py" if not file_path.is_file(): print(f"File not found: {file_path}") return None content = file_path.read_text(encoding="utf-8") match = PARAMETERS_DATACLASS_REGEX.search(content) return match.group(1) if match else None def _parse_module_from_file(filepath: Path) -> Optional[cst.Module]: try: code = filepath.read_text(encoding="utf-8") return cst.parse_module(code) except FileNotFoundError: print(f"File not found: {filepath}") except cst.ParserSyntaxError as e: print(f"Syntax error while parsing {filepath}: {e}") return None def _check_and_update_parameters( method_params: Dict[str, str], update: bool, ) -> NoReturn: """ Check if task methods have missing parameters and update the InferenceClient source code if needed. """ merged_imports = defaultdict(set) logs = [] inference_client_filename = INFERENCE_CLIENT_FILE # Read and parse the inference client module inference_client_module = _parse_module_from_file(inference_client_filename) modified_module = inference_client_module has_changes = False for method_name, parameter_type_name in method_params.items(): parameters_filename = INFERENCE_TYPES_PATH / f"{method_name}.py" parameters_module = _parse_module_from_file(parameters_filename) # Check for missing parameters updates = _check_parameters( modified_module, parameters_module, method_name, parameter_type_name, ) if not updates: continue if update: ## Get missing imports to add needed_imports = _get_imports_to_add(updates, parameters_module, modified_module) for module, imports_to_add in needed_imports.items(): merged_imports[module].update(imports_to_add) modified_module = _update_parameters(modified_module, method_name, updates) has_changes = True else: logs.append(f"\n🔧 Updates needed in method `{method_name}`:") new_params = [p for p, i in updates.items() if i["status"] == "new"] updated_params = { p: "type" if i["status"] == "update_type" else "docstring" for p, i in updates.items() if i["status"] in ("update_type", "update_doc") } if new_params: for param in sorted(new_params): logs.append(f" • {param} (missing)") if updated_params: for param, update_type in sorted(updated_params.items()): logs.append(f" • {param} (outdated {update_type})") if has_changes: if merged_imports: import_statements = _generate_import_statements(merged_imports) imports_to_add = cst.parse_module(import_statements).body # Update inference client module with the missing imports modified_module = modified_module.visit(AddImports(imports_to_add)) # Format the updated source code formatted_source_code = format_source_code(modified_module.code) INFERENCE_CLIENT_FILE.write_text(formatted_source_code) if len(logs) > 0: for log in logs: print(log) print( "❌ Mismatch between between parameters defined in tasks methods signature in " "`./src/huggingface_hub/inference/_client.py` and parameters defined in " "`./src/huggingface_hub/inference/_generated/types.py \n" "Please run `make inference_update` or `python utils/check_task_parameters.py --update" ) exit(1) else: if update: print( "✅ InferenceClient source code has been updated in" " `./src/huggingface_hub/inference/_client.py`.\n Please make sure 
the changes are" " accurate and commit them." ) else: print("✅ All good!") exit(0) def update_inference_client(update: bool): print(f"🙈 Skipping the following tasks: {TASKS_TO_SKIP}") # Get all tasks from the ./src/huggingface_hub/inference/_generated/types/ tasks = set() for file in INFERENCE_TYPES_PATH.glob("*.py"): if file.stem not in TASKS_TO_SKIP: tasks.add(file.stem) # Construct a mapping between method names and their parameters dataclass names method_params = {} for method_name, _ in inspect.getmembers(InferenceClient, predicate=inspect.isfunction): if method_name.startswith("_") or method_name not in tasks: continue parameter_type_name = _get_parameter_type_name(method_name) if parameter_type_name is not None: method_params[method_name] = parameter_type_name _check_and_update_parameters(method_params, update=update) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--update", action="store_true", help=("Whether to update `./src/huggingface_hub/inference/_client.py` if parameters are missing."), ) args = parser.parse_args() update_inference_client(update=args.update) huggingface_hub-0.31.1/utils/generate_async_inference_client.py000066400000000000000000000476221500667546600250100ustar00rootroot00000000000000# coding=utf-8 # Copyright 2023-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
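# Typical invocation (inferred from the `--update` flag and error messages defined later in
# this file; the exact Makefile wiring is an assumption):
#
#   python utils/generate_async_inference_client.py            # check only, exits 1 on mismatch
#   python utils/generate_async_inference_client.py --update   # regenerate _async_client.py from _client.py
#
# Both modes are normally run through `make style` / `make quality`.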
"""Contains a tool to generate `src/huggingface_hub/inference/_generated/_async_client.py`.""" import argparse import re from pathlib import Path from typing import NoReturn from helpers import format_source_code ASYNC_CLIENT_FILE_PATH = ( Path(__file__).parents[1] / "src" / "huggingface_hub" / "inference" / "_generated" / "_async_client.py" ) SYNC_CLIENT_FILE_PATH = Path(__file__).parents[1] / "src" / "huggingface_hub" / "inference" / "_client.py" def generate_async_client_code(code: str) -> str: """Generate AsyncInferenceClient source code.""" # Warning message "this is an automatically generated file" code = _add_warning_to_file_header(code) # Imports specific to asyncio code = _add_imports(code) # Define `AsyncInferenceClient` code = _rename_to_AsyncInferenceClient(code) # Refactor `.post` method to be async + adapt calls code = _make_inner_post_async(code) code = _await_inner_post_method_call(code) code = _use_async_streaming_util(code) # Make all tasks-method async code = _make_tasks_methods_async(code) # Adapt text_generation to async code = _adapt_text_generation_to_async(code) # Adapt chat_completion to async code = _adapt_chat_completion_to_async(code) # Update some docstrings code = _rename_HTTPError_to_ClientResponseError_in_docstring(code) code = _update_examples_in_public_methods(code) # Adapt get_model_status code = _adapt_get_model_status(code) # Adapt list_deployed_models code = _adapt_list_deployed_models(code) # Adapt /info and /health endpoints code = _adapt_info_and_health_endpoints(code) # Add _get_client_session code = _add_get_client_session(code) # Adapt the proxy client (for client.chat.completions.create) code = _adapt_proxy_client(code) return code def check_async_client(update: bool) -> NoReturn: """Check AsyncInferenceClient is correctly defined and consistent with InferenceClient. This script is used in the `make style` and `make quality` checks. """ sync_client_code = SYNC_CLIENT_FILE_PATH.read_text() current_async_client_code = ASYNC_CLIENT_FILE_PATH.read_text() raw_async_client_code = generate_async_client_code(sync_client_code) formatted_async_client_code = format_source_code(raw_async_client_code) # If expected `__init__.py` content is different, test fails. If '--update-init-file' # is used, `__init__.py` file is updated before the test fails. if current_async_client_code != formatted_async_client_code: if update: ASYNC_CLIENT_FILE_PATH.write_text(formatted_async_client_code) print( "✅ AsyncInferenceClient source code has been updated in" " `./src/huggingface_hub/inference/_generated/_async_client.py`.\n Please make sure the changes are" " accurate and commit them." ) exit(0) else: print( "❌ Expected content mismatch in `./src/huggingface_hub/inference/_generated/_async_client.py`.\n It" " is most likely that you modified some InferenceClient code and did not update the" " AsyncInferenceClient one.\n Please run `make style` or `python" " utils/generate_async_inference_client.py --update`." ) exit(1) print("✅ All good! (AsyncInferenceClient)") exit(0) def _add_warning_to_file_header(code: str) -> str: warning_message = ( "#\n# WARNING\n# This entire file has been adapted from the sync-client code in" " `src/huggingface_hub/inference/_client.py`.\n# Any change in InferenceClient will be automatically reflected" " in AsyncInferenceClient.\n# To re-generate the code, run `make style` or `python" " ./utils/generate_async_inference_client.py --update`.\n# WARNING" ) return re.sub( r""" ( # Group1: license (end) \n \#\ limitations\ under\ the\ License. 
\n ) (.*?) # Group2 : all notes and comments (to be replaced) (\nimport[ ]) # Group3: import section (start) """, repl=rf"\1{warning_message}\3", string=code, count=1, flags=re.DOTALL | re.VERBOSE, ) def _add_imports(code: str) -> str: # global imports code = re.sub( r"(\nimport .*?\n)", repl=( r"\1" + "from .._common import _async_yield_from, _import_aiohttp\n" + "from typing import AsyncIterable\n" + "from typing import Set\n" + "import asyncio\n" ), string=code, count=1, flags=re.DOTALL, ) # type-checking imports code = re.sub( r"(\nif TYPE_CHECKING:\n)", repl=r"\1 from aiohttp import ClientResponse, ClientSession\n", string=code, count=1, flags=re.DOTALL, ) return code def _rename_to_AsyncInferenceClient(code: str) -> str: return code.replace("class InferenceClient:", "class AsyncInferenceClient:", 1) ASYNC_INNER_POST_CODE = """ aiohttp = _import_aiohttp() # TODO: this should be handled in provider helpers directly if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers: request_parameters.headers["Accept"] = "image/png" with _open_as_binary(request_parameters.data) as data_as_binary: # Do not use context manager as we don't want to close the connection immediately when returning # a stream session = self._get_client_session(headers=request_parameters.headers) try: response = await session.post(request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies) response_error_payload = None if response.status != 200: try: response_error_payload = await response.json() # get payload before connection closed except Exception: pass response.raise_for_status() if stream: return _async_yield_from(session, response) else: content = await response.read() await session.close() return content except asyncio.TimeoutError as error: await session.close() # Convert any `TimeoutError` to a `InferenceTimeoutError` raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore except aiohttp.ClientResponseError as error: error.response_error_payload = response_error_payload await session.close() raise error except Exception: await session.close() raise async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_value, traceback): await self.close() def __del__(self): if len(self._sessions) > 0: warnings.warn( "Deleting 'AsyncInferenceClient' client but some sessions are still open. " "This can happen if you've stopped streaming data from the server before the stream was complete. " "To close the client properly, you must call `await client.close()` " "or use an async context (e.g. `async with AsyncInferenceClient(): ...`." ) async def close(self): \"""Close all open sessions. By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you are streaming data from the server and you stop before the stream is complete, you must call this method to close the session properly. Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`). \""" await asyncio.gather(*[session.close() for session in self._sessions.keys()])""" def _make_inner_post_async(code: str) -> str: # Update AsyncInferenceClient._inner_post() implementation (use aiohttp instead of requests) code = re.sub( r""" def[ ]_inner_post\( # definition (\n.*?\"\"\".*?\"\"\"\n) # Group1: docstring .*? 
# implementation (to be overwritten) (\n\W*def ) # Group2: next method """, repl=rf"async def _inner_post(\1{ASYNC_INNER_POST_CODE}\2", string=code, count=1, flags=re.DOTALL | re.VERBOSE, ) # Update `post`'s type annotations code = code.replace(" def _inner_post(", " async def _inner_post(") return code.replace("Iterable[bytes]", "AsyncIterable[bytes]") def _rename_HTTPError_to_ClientResponseError_in_docstring(code: str) -> str: # Update `raises`-part in docstrings return code.replace("`HTTPError`:", "`aiohttp.ClientResponseError`:") def _make_tasks_methods_async(code: str) -> str: # Add `async` keyword in front of public methods (of AsyncClientInference) return re.sub( r""" # Group 1: newline + 4-spaces indent (\n\ {4}) # Group 2: def + method name + parenthesis + optionally type: ignore + self ( def[ ] # def [a-z]\w*? # method name (not starting by _) \( # parenthesis (\s*\#[ ]type:[ ]ignore(\[misc\])?)? # optionally 'type: ignore' or 'type: ignore[misc]' \s*self, # expect self, as first arg )""", repl=r"\1async \2", # insert "async" keyword string=code, flags=re.DOTALL | re.VERBOSE, ) def _adapt_text_generation_to_async(code: str) -> str: # Text-generation task has to be handled specifically since it has a recursive call mechanism (to retry on non-tgi # servers) # Catch `aiohttp` error instead of `requests` error code = code.replace( """ except HTTPError as e: match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e)) if isinstance(e, BadRequestError) and match: """, """ except _import_aiohttp().ClientResponseError as e: match = MODEL_KWARGS_NOT_USED_REGEX.search(e.response_error_payload["error"]) if e.status == 400 and match: """, ) # Await recursive call code = code.replace( "return self.text_generation", "return await self.text_generation", ) code = code.replace( "return self.chat_completion", "return await self.chat_completion", ) # Update return types: Iterable -> AsyncIterable code = code.replace( ") -> Iterable[str]:", ") -> AsyncIterable[str]:", ) code = code.replace( ") -> Union[bytes, Iterable[bytes]]:", ") -> Union[bytes, AsyncIterable[bytes]]:", ) code = code.replace( ") -> Iterable[TextGenerationStreamOutput]:", ") -> AsyncIterable[TextGenerationStreamOutput]:", ) code = code.replace( ") -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]:", ") -> Union[TextGenerationOutput, AsyncIterable[TextGenerationStreamOutput]]:", ) code = code.replace( ") -> Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]:", ") -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:", ) return code def _adapt_chat_completion_to_async(code: str) -> str: # Await text-generation call code = code.replace( "text_generation_output = self.text_generation(", "text_generation_output = await self.text_generation(", ) # Update return types: Iterable -> AsyncIterable code = code.replace( ") -> Iterable[ChatCompletionStreamOutput]:", ") -> AsyncIterable[ChatCompletionStreamOutput]:", ) code = code.replace( ") -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]:", ") -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]:", ) return code def _await_inner_post_method_call(code: str) -> str: return code.replace("self._inner_post(", "await self._inner_post(") def _update_example_code_block(code_block: str) -> str: """Update an atomic code block example from a docstring.""" code_block = "\n # Must be run in an async context" + code_block code_block = code_block.replace("InferenceClient", 
"AsyncInferenceClient") code_block = code_block.replace("client.", "await client.") code_block = code_block.replace(">>> for ", ">>> async for ") return code_block def _update_examples_in_public_methods(code: str) -> str: for match in re.finditer( r""" \n\s* Example.*?:\n\s* # example section ```py # start (.*?) # code block ``` # end \n """, string=code, flags=re.DOTALL | re.VERBOSE, ): # Example, including code block full_match = match.group() # Code block alone code_block = match.group(1) # Update code block in example updated_match = full_match.replace(code_block, _update_example_code_block(code_block)) # Update example in full script code = code.replace(full_match, updated_match) return code def _use_async_streaming_util(code: str) -> str: code = code.replace( "_stream_text_generation_response", "_async_stream_text_generation_response", ) code = code.replace("_stream_chat_completion_response", "_async_stream_chat_completion_response") return code def _adapt_get_model_status(code: str) -> str: sync_snippet = """ response = get_session().get(url, headers=build_hf_headers(token=self.token)) hf_raise_for_status(response) response_data = response.json()""" async_snippet = """ async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: response = await client.get(url, proxy=self.proxies) response.raise_for_status() response_data = await response.json()""" return code.replace(sync_snippet, async_snippet) def _adapt_list_deployed_models(code: str) -> str: sync_snippet = """ for framework in frameworks: response = get_session().get(f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)) hf_raise_for_status(response) _unpack_response(framework, response.json())""".strip() async_snippet = """ async def _fetch_framework(framework: str) -> None: async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}", proxy=self.proxies) response.raise_for_status() _unpack_response(framework, await response.json()) import asyncio await asyncio.gather(*[_fetch_framework(framework) for framework in frameworks])""".strip() return code.replace(sync_snippet, async_snippet) def _adapt_info_and_health_endpoints(code: str) -> str: info_sync_snippet = """ response = get_session().get(url, headers=build_hf_headers(token=self.token)) hf_raise_for_status(response) return response.json()""" info_async_snippet = """ async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: response = await client.get(url, proxy=self.proxies) response.raise_for_status() return await response.json()""" code = code.replace(info_sync_snippet, info_async_snippet) health_sync_snippet = """ response = get_session().get(url, headers=build_hf_headers(token=self.token)) return response.status_code == 200""" health_async_snippet = """ async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: response = await client.get(url, proxy=self.proxies) return response.status == 200""" return code.replace(health_sync_snippet, health_async_snippet) def _add_get_client_session(code: str) -> str: # Add trust_env as parameter code = _add_before(code, "proxies: Optional[Any] = None,", "trust_env: bool = False,") code = _add_before(code, "\n self.proxies = proxies\n", "\n self.trust_env = trust_env") # Document `trust_env` parameter code = _add_before( code, "\n proxies (`Any`, `optional`):", """ trust_env ('bool', 'optional'): Trust 
environment settings for proxy configuration if the parameter is `True` (`False` by default).""", ) # insert `_get_client_session` before `get_endpoint_info` method client_session_code = """ def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession": aiohttp = _import_aiohttp() client_headers = self.headers.copy() if headers is not None: client_headers.update(headers) # Return a new aiohttp ClientSession with correct settings. session = aiohttp.ClientSession( headers=client_headers, cookies=self.cookies, timeout=aiohttp.ClientTimeout(self.timeout), trust_env=self.trust_env, ) # Keep track of sessions to close them later self._sessions[session] = set() # Override the `._request` method to register responses to be closed session._wrapped_request = session._request async def _request(method, url, **kwargs): response = await session._wrapped_request(method, url, **kwargs) self._sessions[session].add(response) return response session._request = _request # Override the 'close' method to # 1. close ongoing responses # 2. deregister the session when closed session._close = session.close async def close_session(): for response in self._sessions[session]: response.close() await session._close() self._sessions.pop(session, None) session.close = close_session return session """ code = _add_before(code, "\n async def get_endpoint_info(", client_session_code) # Add self._sessions attribute in __init__ code = _add_before( code, "\n def __repr__(self):\n", "\n # Keep track of the sessions to close them properly" "\n self._sessions: Dict['ClientSession', Set['ClientResponse']] = dict()", ) return code def _adapt_proxy_client(code: str) -> str: return code.replace( "def __init__(self, client: InferenceClient):", "def __init__(self, client: AsyncInferenceClient):", ) def _add_before(code: str, pattern: str, addition: str) -> str: index = code.find(pattern) assert index != -1, f"Pattern '{pattern}' not found in code." return code[:index] + addition + code[index:] if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--update", action="store_true", help=( "Whether to re-generate `./src/huggingface_hub/inference/_generated/_async_client.py` if a change is" " detected." ), ) args = parser.parse_args() check_async_client(update=args.update) huggingface_hub-0.31.1/utils/generate_inference_types.py000066400000000000000000000333331500667546600234730ustar00rootroot00000000000000# coding=utf-8 # Copyright 2024-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
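# Typical invocation (assuming the same `--update` convention as the other scripts in utils/;
# see the `check_inference_types` entry point below):
#
#   python utils/generate_inference_types.py            # check that the generated type modules, exports and docs are up to date
#   python utils/generate_inference_types.py --update   # rewrite them in place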
"""Contains a tool to generate `src/huggingface_hub/inference/_generated/types`.""" import argparse import re from pathlib import Path from typing import Dict, List, Literal, NoReturn, Optional import libcst as cst from helpers import check_and_update_file_content, format_source_code huggingface_hub_folder_path = Path(__file__).parents[1] / "src" / "huggingface_hub" INFERENCE_TYPES_FOLDER_PATH = huggingface_hub_folder_path / "inference" / "_generated" / "types" MAIN_INIT_PY_FILE = huggingface_hub_folder_path / "__init__.py" REFERENCE_PACKAGE_EN_PATH = ( Path(__file__).parents[1] / "docs" / "source" / "en" / "package_reference" / "inference_types.md" ) REFERENCE_PACKAGE_KO_PATH = ( Path(__file__).parents[1] / "docs" / "source" / "ko" / "package_reference" / "inference_types.md" ) IGNORE_FILES = [ "__init__.py", "base.py", ] BASE_DATACLASS_REGEX = re.compile( r""" ^@dataclass \nclass\s(\w+):\n """, re.VERBOSE | re.MULTILINE, ) INHERITED_DATACLASS_REGEX = re.compile( r""" ^@dataclass_with_extra \nclass\s(\w+)\(BaseInferenceType\): """, re.VERBOSE | re.MULTILINE, ) TYPE_ALIAS_REGEX = re.compile( r""" ^(?!\s) # to make sure the line does not start with whitespace (top-level) (\w+) \s*=\s* (.+) $ """, re.VERBOSE | re.MULTILINE, ) OPTIONAL_FIELD_REGEX = re.compile(r": Optional\[(.+)\]$", re.MULTILINE) INIT_PY_HEADER = """ # This file is auto-generated by `utils/generate_inference_types.py`. # Do not modify it manually. # # ruff: noqa: F401 from .base import BaseInferenceType """ # Regex to add all dataclasses to ./src/huggingface_hub/__init__.py MAIN_INIT_PY_REGEX = re.compile( r""" \"inference\._generated\.types\":\s*\[ # module name (.*?) # all dataclasses listed \] # closing bracket """, re.MULTILINE | re.VERBOSE | re.DOTALL, ) # List of classes that are shared across multiple modules # This is used to fix the naming of the classes (to make them unique by task) SHARED_CLASSES = [ "BoundingBox", "ClassificationOutputTransform", "ClassificationOutput", "GenerationParameters", "TargetSize", "EarlyStoppingEnum", ] REFERENCE_PACKAGE_EN_CONTENT = """ # Inference types This page lists the types (e.g. dataclasses) available for each task supported on the Hugging Face Hub. Each task is specified using a JSON schema, and the types are generated from these schemas - with some customization due to Python requirements. Visit [@huggingface.js/tasks](https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks) to find the JSON schemas for each task. This part of the lib is still under development and will be improved in future releases. {types} """ REFERENCE_PACKAGE_KO_CONTENT = """ # 추론 타입[[inference-types]] 이 페이지에는 Hugging Face Hub에서 지원하는 타입(예: 데이터 클래스)이 나열되어 있습니다. 각 작업은 JSON 스키마를 사용하여 지정되며, 이러한 스키마에 의해서 타입이 생성됩니다. 이때 Python 요구 사항으로 인해 일부 사용자 정의가 있을 수 있습니다. 각 작업의 JSON 스키마를 확인하려면 [@huggingface.js/tasks](https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks)를 확인하세요. 라이브러리에서 이 부분은 아직 개발 중이며, 향후 릴리즈에서 개선될 예정입니다. {types} """ def _replace_class_name(content: str, cls: str, new_cls: str) -> str: """ Replace the class name `cls` with the new class name `new_cls` in the content. """ pattern = rf""" (? 
str: content = content.replace( "\nfrom dataclasses import", "\nfrom .base import BaseInferenceType, dataclass_with_extra\nfrom dataclasses import", ) content = BASE_DATACLASS_REGEX.sub(r"@dataclass_with_extra\nclass \1(BaseInferenceType):\n", content) return content def _delete_empty_lines(content: str) -> str: return "\n".join([line for line in content.split("\n") if line.strip()]) def _fix_naming_for_shared_classes(content: str, module_name: str) -> str: for cls in SHARED_CLASSES: # No need to fix the naming of a shared class if it's not used in the module if cls not in content: continue # Update class definition # Very hacky way to build "AudioClassificationOutputElement" instead of "ClassificationOutput" new_cls = "".join(part.capitalize() for part in module_name.split("_")) if "Classification" in new_cls: # to avoid "ClassificationClassificationOutput" new_cls += cls.removeprefix("Classification") else: new_cls += cls if new_cls.endswith("ClassificationOutput"): # to get "AudioClassificationOutputElement" new_cls += "Element" content = _replace_class_name(content, cls, new_cls) return content def _fix_text2text_shared_parameters(content: str, module_name: str) -> str: if module_name in ("summarization", "translation"): content = content.replace( "Text2TextGenerationParameters", f"{module_name.capitalize()}GenerationParameters", ) content = content.replace( "Text2TextGenerationTruncationStrategy", f"{module_name.capitalize()}GenerationTruncationStrategy", ) return content def _make_optional_fields_default_to_none(content: str): lines = [] for line in content.split("\n"): if "Optional[" in line and not line.endswith("None"): line += " = None" lines.append(line) return "\n".join(lines) def _list_dataclasses(content: str) -> List[str]: """List all dataclasses defined in the module.""" return INHERITED_DATACLASS_REGEX.findall(content) def _list_type_aliases(content: str) -> List[str]: """List all type aliases defined in the module.""" return [alias_class for alias_class, _ in TYPE_ALIAS_REGEX.findall(content)] class DeprecatedRemover(cst.CSTTransformer): def is_deprecated(self, docstring: Optional[str]) -> bool: """Check if a docstring contains @deprecated.""" return docstring is not None and "@deprecated" in docstring.lower() def get_docstring(self, body: List[cst.BaseStatement]) -> Optional[str]: """Extract docstring from a body of statements.""" if not body: return None first = body[0] if isinstance(first, cst.SimpleStatementLine): expr = first.body[0] if isinstance(expr, cst.Expr) and isinstance(expr.value, cst.SimpleString): return expr.value.evaluated_value return None def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> Optional[cst.ClassDef]: """Handle class definitions - remove if deprecated.""" docstring = self.get_docstring(original_node.body.body) if self.is_deprecated(docstring): return cst.RemoveFromParent() new_body = [] statements = list(updated_node.body.body) i = 0 while i < len(statements): stmt = statements[i] # Check if this is a field (AnnAssign) if isinstance(stmt, cst.SimpleStatementLine) and isinstance(stmt.body[0], cst.AnnAssign): # Look ahead for docstring next_docstring = None if i + 1 < len(statements): next_docstring = self.get_docstring([statements[i + 1]]) if self.is_deprecated(next_docstring): i += 2 # Skip both the field and its docstring continue new_body.append(stmt) i += 1 if not new_body: return cst.RemoveFromParent() return updated_node.with_changes(body=updated_node.body.with_changes(body=new_body)) def 
_clean_deprecated_fields(content: str) -> str: """Remove deprecated classes and fields using libcst.""" source_tree = cst.parse_module(content) transformer = DeprecatedRemover() modified_tree = source_tree.visit(transformer) return modified_tree.code def fix_inference_classes(content: str, module_name: str) -> str: content = _inherit_from_base(content) content = _delete_empty_lines(content) content = _fix_naming_for_shared_classes(content, module_name) content = _fix_text2text_shared_parameters(content, module_name) content = _make_optional_fields_default_to_none(content) return content def create_init_py(dataclasses: Dict[str, List[str]]): """Create __init__.py file with all dataclasses.""" content = INIT_PY_HEADER content += "\n" content += "\n".join( [f"from .{module} import {', '.join(dataclasses_list)}" for module, dataclasses_list in dataclasses.items()] ) return content def add_dataclasses_to_main_init(content: str, dataclasses: Dict[str, List[str]]): dataclasses_list = sorted({cls for classes in dataclasses.values() for cls in classes}) dataclasses_str = ", ".join(f"'{cls}'" for cls in dataclasses_list) return MAIN_INIT_PY_REGEX.sub(f'"inference._generated.types": [{dataclasses_str}]', content) def generate_reference_package(dataclasses: Dict[str, List[str]], language: Literal["en", "ko"]) -> str: """Generate the reference package content.""" per_task_docs = [] for task in sorted(dataclasses.keys()): lines = [f"[[autodoc]] huggingface_hub.{cls}" for cls in sorted(dataclasses[task])] lines_str = "\n\n".join(lines) if language == "en": # e.g. '## audio_classification' per_task_docs.append(f"\n## {task}\n\n{lines_str}\n\n") elif language == "ko": # e.g. '## audio_classification[[huggingface_hub.AudioClassificationInput]]' per_task_docs.append(f"\n## {task}[[huggingface_hub.{sorted(dataclasses[task])[0]}]]\n\n{lines_str}\n\n") else: raise ValueError(f"Language {language} is not supported.") template = REFERENCE_PACKAGE_EN_CONTENT if language == "en" else REFERENCE_PACKAGE_KO_CONTENT return template.format(types="\n".join(per_task_docs)) def check_inference_types(update: bool) -> NoReturn: """Check and update inference types. This script is used in the `make style` and `make quality` checks. 
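    Typical invocations (a sketch, assuming the script is run from the repository root):

        python utils/generate_inference_types.py            # check only; exits with an error on mismatch
        python utils/generate_inference_types.py --update    # regenerate the generated files in place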
""" dataclasses = {} aliases = {} for file in INFERENCE_TYPES_FOLDER_PATH.glob("*.py"): if file.name in IGNORE_FILES: continue content = file.read_text() content = _clean_deprecated_fields(content) fixed_content = fix_inference_classes(content, module_name=file.stem) formatted_content = format_source_code(fixed_content) dataclasses[file.stem] = _list_dataclasses(formatted_content) aliases[file.stem] = _list_type_aliases(formatted_content) check_and_update_file_content(file, formatted_content, update) all_classes = {module: dataclasses[module] + aliases[module] for module in dataclasses.keys()} init_py_content = create_init_py(all_classes) init_py_content = format_source_code(init_py_content) init_py_file = INFERENCE_TYPES_FOLDER_PATH / "__init__.py" check_and_update_file_content(init_py_file, init_py_content, update) main_init_py_content = MAIN_INIT_PY_FILE.read_text() updated_main_init_py_content = add_dataclasses_to_main_init(main_init_py_content, all_classes) updated_main_init_py_content = format_source_code(updated_main_init_py_content) check_and_update_file_content(MAIN_INIT_PY_FILE, updated_main_init_py_content, update) reference_package_content_en = generate_reference_package(dataclasses, "en") check_and_update_file_content(REFERENCE_PACKAGE_EN_PATH, reference_package_content_en, update) reference_package_content_ko = generate_reference_package(dataclasses, "ko") check_and_update_file_content(REFERENCE_PACKAGE_KO_PATH, reference_package_content_ko, update) print("✅ All good! (inference types)") exit(0) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--update", action="store_true", help=( "Whether to re-generate files in `./src/huggingface_hub/inference/_generated/types/` if a change is detected." ), ) args = parser.parse_args() check_inference_types(update=args.update) huggingface_hub-0.31.1/utils/helpers.py000066400000000000000000000040171500667546600200760ustar00rootroot00000000000000# coding=utf-8 # Copyright 2024-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains helpers used by the scripts in `./utils`.""" import subprocess import tempfile from pathlib import Path from ruff.__main__ import find_ruff_bin def check_and_update_file_content(file: Path, expected_content: str, update: bool): # Ensure the expected content ends with a newline to satisfy end-of-file-fixer hook expected_content = expected_content.rstrip("\n") + "\n" content = file.read_text() if file.exists() else None if content != expected_content: if update: file.write_text(expected_content) print(f" {file} has been updated. 
Please make sure the changes are accurate and commit them.") else: print(f"❌ Expected content mismatch in {file}.") exit(1) def format_source_code(code: str) -> str: """Format the generated source code using Ruff.""" with tempfile.TemporaryDirectory() as tmpdir: filepath = Path(tmpdir) / "tmp.py" filepath.write_text(code) ruff_bin = find_ruff_bin() if not ruff_bin: raise FileNotFoundError("Ruff executable not found.") try: subprocess.run([ruff_bin, "check", str(filepath), "--fix", "--quiet"], check=True) subprocess.run([ruff_bin, "format", str(filepath), "--quiet"], check=True) except subprocess.CalledProcessError as e: raise RuntimeError(f"Error running Ruff: {e}") return filepath.read_text() huggingface_hub-0.31.1/utils/push_repocard_examples.py000066400000000000000000000106201500667546600231650ustar00rootroot00000000000000# coding=utf-8 # Copyright 2022-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Generate and push an empty ModelCard and DatasetCard to the Hub as examples.""" import argparse from pathlib import Path import jinja2 from huggingface_hub import DatasetCard, ModelCard, hf_hub_download, upload_file, whoami from huggingface_hub.constants import REPOCARD_NAME ORG_NAME = "templates" MODEL_CARD_REPO_ID = "templates/model-card-example" DATASET_CARD_REPO_ID = "templates/dataset-card-example" def check_can_push(): """Check the user can push to the `templates/` folder with its credentials.""" try: me = whoami() except EnvironmentError: print("You must be logged in to push repo card examples.") if all(org["name"] != ORG_NAME for org in me.get("orgs", [])): print(f"❌ You must have access to organization '{ORG_NAME}' to push repo card examples.") exit(1) def push_model_card_example(overwrite: bool) -> None: """Generate an empty model card from template for documentation purposes. Do not push if content has not changed. Script is triggered in CI on main branch. Card is pushed to https://huggingface.co/templates/model-card-example. """ # Not using ModelCard directly to preserve comments in metadata part template = jinja2.Template(ModelCard.default_template_path.read_text()) content = template.render( card_data="{}", model_summary=( "This modelcard aims to be a base template for new models. " "It has been generated using [this raw template]" "(https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md?plain=1)." ), ) if not overwrite: existing_content = Path(hf_hub_download(MODEL_CARD_REPO_ID, REPOCARD_NAME, repo_type="model")).read_text() if content == existing_content: print("Model Card not pushed: did not change.") return print(f"Pushing empty Model Card to Hub: {MODEL_CARD_REPO_ID}") upload_file( path_or_fileobj=content.encode(), path_in_repo=REPOCARD_NAME, repo_id=MODEL_CARD_REPO_ID, repo_type="model", ) def push_dataset_card_example(overwrite: bool) -> None: """Generate an empty dataset card from template for documentation purposes. Do not push if content has not changed. 
Script is triggered in CI on main branch. Card is pushed to https://huggingface.co/datasets/templates/dataset-card-example. """ # Not using DatasetCard directly to preserve comments in metadata part template = jinja2.Template(DatasetCard.default_template_path.read_text()) content = template.render( card_data="{}", dataset_summary=( "This dataset card aims to be a base template for new datasets. " "It has been generated using [this raw template]" "(https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/datasetcard_template.md?plain=1)." ), ) if not overwrite: existing_content = Path(hf_hub_download(DATASET_CARD_REPO_ID, REPOCARD_NAME, repo_type="dataset")).read_text() if content == existing_content: print("Dataset Card not pushed: did not change.") return print(f"Pushing empty Dataset Card to Hub: {DATASET_CARD_REPO_ID}") upload_file( path_or_fileobj=content.encode(), path_in_repo=REPOCARD_NAME, repo_id=DATASET_CARD_REPO_ID, repo_type="dataset", ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--overwrite", action="store_true", help="Whether to force updating examples. By default, push to hub only if card is updated.", ) args = parser.parse_args() check_can_push() push_model_card_example(args.overwrite) push_dataset_card_example(args.overwrite)
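# Example invocations (a sketch, assuming you are logged in with a token that has access
# to the `templates` organization, as enforced by `check_can_push` above):
#   python utils/push_repocard_examples.py              # push only if the rendered cards changed
#   python utils/push_repocard_examples.py --overwrite  # force-push both example cards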