microsoft/wham
katja-hofmann committed
Commit f1c8ee5 · 1 Parent(s): e199e7c

Initial commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. CODE_OF_CONDUCT.md +10 -0
  3. CONTRIBUTING.md +14 -0
  4. LICENSE.md +96 -0
  5. SECURITY.md +37 -0
  6. assets/Demonstrator/Fig_01.png +0 -0
  7. assets/Demonstrator/Fig_02.png +0 -0
  8. assets/Demonstrator/Fig_03.png +0 -0
  9. assets/Demonstrator/Fig_04.png +0 -0
  10. assets/Demonstrator/Fig_05.png +0 -0
  11. assets/Demonstrator/Fig_06.png +0 -0
  12. assets/Demonstrator/Fig_07.png +0 -0
  13. assets/Demonstrator/Fig_08.png +0 -0
  14. assets/Demonstrator/Fig_09.png +0 -0
  15. assets/Demonstrator/Fig_10.png +0 -0
  16. assets/Demonstrator/Fig_11.png +0 -0
  17. assets/Demonstrator/Fig_12.png +0 -0
  18. assets/Demonstrator/Fig_13.png +0 -0
  19. assets/Demonstrator/Fig_14.png +0 -0
  20. assets/Demonstrator/Fig_15.png +0 -0
  21. assets/Readme/model_capabilities.gif +3 -0
  22. assets/Readme/wham_gen_1.gif +3 -0
  23. assets/Readme/wham_gen_2.gif +3 -0
  24. assets/Readme/wham_gen_3.gif +3 -0
  25. assets/Readme/wham_gen_4.gif +3 -0
  26. assets/Readme/wham_gen_5.gif +3 -0
  27. assets/Readme/wham_gen_6.gif +3 -0
  28. assets/Readme/wham_gen_7.gif +3 -0
  29. assets/Readme/wham_gen_8.gif +3 -0
  30. assets/Readme/wham_gen_9.gif +3 -0
  31. configs/metadata_custom_tag.config +5 -0
  32. models/WHAM_1.6B_v1.ckpt +3 -0
  33. models/WHAM_200M.ckpt +3 -0
  34. requirements.txt +48 -0
  35. run_dreaming.py +264 -0
  36. run_server.py +519 -0
  37. setup_local.sh +21 -0
  38. wham/models/nn/model_blocks.py +49 -0
  39. wham/models/nn/nanoGPT.py +665 -0
  40. wham/models/pl/__init__.py +0 -0
  41. wham/models/pl/pl_base_model.py +5 -0
  42. wham/models/vqgan/taming/LICENSE +24 -0
  43. wham/models/vqgan/taming/model.py +696 -0
  44. wham/models/vqgan/taming/quantize.py +146 -0
  45. wham/models/vqgan/taming_vq_model.py +264 -0
  46. wham/models/vqgan/vqgan.py +236 -0
  47. wham/models/vqgan/vqgan_models.py +311 -0
  48. wham/models/vqvae/vqvae_utils.py +154 -0
  49. wham/models/wham_base/__init__.py +0 -0
  50. wham/models/wham_base/encode_predict_decode_base.py +256 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ fonts/arial.ttf filter=lfs diff=lfs merge=lfs -text
37
+ *.gif filter=lfs diff=lfs merge=lfs -text
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,10 @@
1
+ # Microsoft Open Source Code of Conduct
2
+
3
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4
+
5
+ Resources:
6
+
7
+ - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8
+ - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9
+ - Contact [[email protected]](mailto:[email protected]) with questions or concerns
10
+ - Employees can reach out at [aka.ms/opensource/moderation-support](https://aka.ms/opensource/moderation-support)
CONTRIBUTING.md ADDED
@@ -0,0 +1,14 @@
1
+ # Contributing
2
+
3
+ This project welcomes contributions and suggestions. Most contributions require you to
4
+ agree to a Contributor License Agreement (CLA) declaring that you have the right to,
5
+ and actually do, grant us the rights to use your contribution. For details, visit
6
+ https://cla.microsoft.com.
7
+
8
+ When you submit a pull request, a CLA-bot will automatically determine whether you need
9
+ to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
10
+ instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
11
+
12
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
13
+ For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
14
+ or contact [[email protected]](mailto:[email protected]) with any additional questions or comments.
LICENSE.md ADDED
@@ -0,0 +1,96 @@
1
+ # MICROSOFT RESEARCH LICENSE TERMS
2
+
3
+ **IF YOU LIVE IN THE UNITED STATES, PLEASE READ THE “BINDING ARBITRATION AND CLASS ACTION WAIVER” SECTION BELOW. IT AFFECTS HOW DISPUTES ARE RESOLVED.**
4
+
5
+ These license terms are an agreement between you and Microsoft Corporation (or one of its affiliates). They apply to the source code, object code, machine learning models, or data (collectively “Materials”) that accompany this license. IF YOU COMPLY WITH THESE LICENSE TERMS, YOU HAVE THE RIGHTS BELOW. BY USING THE MATERIALS, YOU ACCEPT THESE TERMS.
6
+
7
+ ## 1) INSTALLATION AND USE RIGHTS TO THE MATERIALS.
8
+
9
+ Subject to the terms of this agreement, you have the below rights, if applicable, to use the Materials solely for non-commercial, non-revenue generating, research purposes:
10
+
11
+ a) **Source Code.** If source code is included, you may use and modify the source code, but you may not distribute the source code.
12
+
13
+ b) **Object Code.** If object code is included, you may use the object code, but you may not distribute the object code.
14
+
15
+ c) **Models.** If machine learning model(s) are included, you may use the model(s), but you may not distribute the models.
16
+
17
+ d) **Data.** If data is included, you may use the data, but your use must be consistent with the consent under which the data was provided and/or gathered and you may not modify or distribute the data.
18
+
19
+ ## 2) SCOPE OF LICENSE.
20
+
21
+ The Materials are licensed, not sold. Microsoft reserves all other rights. Unless applicable law gives you more rights despite this limitation, you will not (and have no right to):
22
+
23
+ a) Work around any technical limitations in the Materials that only allow you to use it in certain ways;
24
+
25
+ b) Reverse engineer, decompile or disassemble the Materials;
26
+
27
+ c) Remove, minimize, block, or modify any notices of Microsoft or its suppliers in the Materials;
28
+
29
+ d) Use the Materials in any way that is against the law or to create or propagate malware; or
30
+
31
+ e) Share, publish, distribute or lend the Materials, provide the Materials as a stand-alone hosted solution for others to use, or transfer the Materials or this agreement to any third party.
32
+
33
+ ## 3) PERSONAL DATA.
34
+
35
+ If the data (set forth in Section 1(d) above) includes or is found to include any data that enables any ability to identify an individual ("Personal Data"), you will not use such Personal Data for any purpose other than was authorized and consented to by the data subject/research participant. You will not use Personal Data to contact any person. You will keep Personal Data in strict confidence. You will not share any Personal Data that is collected or in your possession with any third party for any reason and as required under the original consent agreement. Further, you will destroy the Personal Data and any backup or copies, **immediately upon the completion of your research.**
36
+
37
+ ## 4) LICENSE TO MICROSOFT.
38
+
39
+ Notwithstanding the limitations in Section 1, you may distribute your modifications back to Microsoft, and if you do provide Microsoft with modifications of the Materials, you hereby grant Microsoft, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, post, distribute, make and have made, sell and transfer such modifications and derivatives for any purpose.
40
+
41
+ ## 5) PUBLICATION.
42
+
43
+ You may publish (or present papers or articles) on your results from using the Materials provided that no material or substantial portion of the Materials is included in any such publication or presentation.
44
+
45
+ ## 6) FEEDBACK.
46
+
47
+ Any feedback about the Materials provided by you to us is voluntarily given, and Microsoft shall be free to use the feedback as it sees fit without obligation or restriction of any kind, even if the feedback is designated by you as confidential. Such feedback shall be considered a contribution and licensed to Microsoft under the terms of Section 4 above.
48
+
49
+ ## 7) COMPLIANCE WITH TRADE LAWS.
50
+
51
+ You acknowledge that the Materials may be subject to applicable trade laws in one or more countries. You will comply with all relevant laws and regulations applicable to the import or export of the Materials, including but not limited to, trade laws such as the U.S. Export Administration Regulations or other end-user, end use, and destination restrictions by the U.S. and other governments, as well as sanctions regulations administered by the U.S. Office of Foreign Assets Control. Microsoft may suspend or terminate the agreement immediately to the extent that Microsoft reasonably concludes that continued performance would violate trade laws or put it at risk of becoming subject to sanctions or penalties under trade laws. For additional information, see www.microsoft.com/exporting.
52
+
53
+ ## 8) SUPPORT SERVICES.
54
+
55
+ Microsoft is not obligated under this agreement to provide any support services for the Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
56
+
57
+ ## 9) BINDING ARBITRATION AND CLASS ACTION WAIVER.
58
+
59
+ **This Section applies if you live in (or, if a business, your principal place of business is in) the United States.** If you and Microsoft have a dispute, you and Microsoft agree to try for 60 days to resolve it informally. If you and Microsoft can’t, you and Microsoft agree to **binding individual arbitration before the American Arbitration Association** under the Federal Arbitration Act ("FAA"), and not to **sue in court in front of a judge or jury.** Instead, a neutral arbitrator will decide. **Class action lawsuits, class-wide arbitrations, private attorney-general actions,** and any other proceeding where someone acts in a representative capacity **are not allowed;** nor is combining individual proceedings without the consent of all parties. The complete Arbitration Agreement contains more terms and is at aka.ms/arb-agreement-1. You and Microsoft agree to these terms.
60
+
61
+ ## 10) ENTIRE AGREEMENT.
62
+
63
+ This agreement, and any other terms Microsoft may provide for supplements, updates, or third-party applications, is the entire agreement for the Materials.
64
+
65
+ ## 11) APPLICABLE LAW AND PLACE TO RESOLVE DISPUTES.
66
+
67
+ If you acquired the Materials in the United States or Canada, the laws of the state or province where you live (or, if a business, where your principal place of business is located) govern the interpretation of this agreement, claims for its breach, and all other claims (including consumer protection, unfair competition, and tort claims), regardless of conflict of laws principles, except that the FAA governs everything related to arbitration. If you acquired the Materials in any other country, its laws apply, except that the FAA governs everything related to arbitration. If U.S. federal jurisdiction exists, you and Microsoft consent to exclusive jurisdiction and venue in the federal court in King County, Washington for all disputes heard in court (excluding arbitration). If not, you and Microsoft consent to exclusive jurisdiction and venue in the Superior Court of King County, Washington for all disputes heard in court (excluding arbitration).
68
+
69
+ ## 12) CONSUMER RIGHTS; REGIONAL VARIATIONS.
70
+
71
+ This agreement describes certain legal rights. You may have other rights, including consumer rights, under the laws of your state, province, or country. Separate and apart from your relationship with Microsoft, you may also have rights with respect to the party from which you acquired the Materials. This agreement does not change those other rights if the laws of your state, province, or country do not permit it to do so. For example, if you acquired the Materials in one of the below regions, or mandatory country law applies, then the following provisions apply to you:
72
+
73
+ a) **Australia.** You have statutory guarantees under the Australian Consumer Law and nothing in this agreement is intended to affect those rights.
74
+
75
+ b) **Canada.** If you acquired this software in Canada, you may stop receiving updates by turning off the automatic update feature, disconnecting your device from the Internet (if and when you re-connect to the Internet, however, the Materials will resume checking for and installing updates), or uninstalling the Materials. The product documentation, if any, may also specify how to turn off updates for your specific device or software.
76
+
77
+ c) **Germany and Austria.**
78
+ i. **Warranty.** The properly licensed software will perform substantially as described in any Microsoft materials that accompany the Materials. However, Microsoft gives no contractual guarantee in relation to the licensed software.
79
+ ii. **Limitation of Liability.** In case of intentional conduct, gross negligence, claims based on the Product Liability Act, as well as, in case of death or personal or physical injury, Microsoft is liable according to the statutory law.
80
+
81
+ Subject to the foregoing clause (ii), Microsoft will only be liable for slight negligence if Microsoft is in breach of such material contractual obligations, the fulfillment of which facilitate the due performance of this agreement, the breach of which would endanger the purpose of this agreement and the compliance with which a party may constantly trust in (so-called "cardinal obligations"). In other cases of slight negligence, Microsoft will not be liable for slight negligence.
82
+
83
+ ## 13) DISCLAIMER OF WARRANTY.
84
+
85
+ THE MATERIALS ARE LICENSED "AS IS." YOU BEAR THE RISK OF USING THEM. MICROSOFT GIVES NO EXPRESS WARRANTIES, GUARANTEES, OR CONDITIONS. TO THE EXTENT PERMITTED UNDER APPLICABLE LAWS, MICROSOFT EXCLUDES ALL IMPLIED WARRANTIES, INCLUDING MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT.
86
+
87
+ ## 14) LIMITATION ON AND EXCLUSION OF DAMAGES.
88
+
89
+ IF YOU HAVE ANY BASIS FOR RECOVERING DAMAGES DESPITE THE PRECEDING DISCLAIMER OF WARRANTY, YOU CAN RECOVER FROM MICROSOFT AND ITS SUPPLIERS ONLY DIRECT DAMAGES UP TO U.S. $5.00. YOU CANNOT RECOVER ANY OTHER DAMAGES, INCLUDING CONSEQUENTIAL, LOST PROFITS, SPECIAL, INDIRECT OR INCIDENTAL DAMAGES.
90
+
91
+ This limitation applies to:
92
+ - (a) anything related to the Materials, services, content (including code) on third party Internet sites, or third party applications; and
93
+ - (b) claims for breach of contract, warranty, guarantee, or condition; strict liability, negligence, or other tort; or any other claim; in each case to the extent permitted by applicable law.
94
+
95
+ It also applies even if Microsoft knew or should have known about the possibility of the damages. The above limitation or exclusion may not apply to you because your state, province, or country may not allow the exclusion or limitation of incidental, consequential, or other damages.
96
+
SECURITY.md ADDED
@@ -0,0 +1,37 @@
1
+ ## Security
2
+
3
+ Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
4
+
5
+ If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
6
+
7
+ ## Reporting Security Issues
8
+
9
+ **Please do not report security vulnerabilities through public GitHub issues.**
10
+
11
+ Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
12
+
13
+ If you prefer to submit without logging in, send email to [[email protected]](mailto:[email protected]). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
14
+
15
+ You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
16
+
17
+ Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
18
+
19
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
20
+ * Full paths of source file(s) related to the manifestation of the issue
21
+ * The location of the affected source code (tag/branch/commit or direct URL)
22
+ * Any special configuration required to reproduce the issue
23
+ * Step-by-step instructions to reproduce the issue
24
+ * Proof-of-concept or exploit code (if possible)
25
+ * Impact of the issue, including how an attacker might exploit the issue
26
+
27
+ This information will help us triage your report more quickly.
28
+
29
+ If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
30
+
31
+ ## Preferred Languages
32
+
33
+ We prefer all communications to be in English.
34
+
35
+ ## Policy
36
+
37
+ Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
assets/Demonstrator/Fig_01.png ADDED
assets/Demonstrator/Fig_02.png ADDED
assets/Demonstrator/Fig_03.png ADDED
assets/Demonstrator/Fig_04.png ADDED
assets/Demonstrator/Fig_05.png ADDED
assets/Demonstrator/Fig_06.png ADDED
assets/Demonstrator/Fig_07.png ADDED
assets/Demonstrator/Fig_08.png ADDED
assets/Demonstrator/Fig_09.png ADDED
assets/Demonstrator/Fig_10.png ADDED
assets/Demonstrator/Fig_11.png ADDED
assets/Demonstrator/Fig_12.png ADDED
assets/Demonstrator/Fig_13.png ADDED
assets/Demonstrator/Fig_14.png ADDED
assets/Demonstrator/Fig_15.png ADDED
assets/Readme/model_capabilities.gif ADDED

Git LFS Details

  • SHA256: 87cf1460b2779a1c85b70e2229a7e1e256c501a5e3db26ea74e445b9dc75e965
  • Pointer size: 132 Bytes
  • Size of remote file: 8.63 MB
assets/Readme/wham_gen_1.gif ADDED

Git LFS Details

  • SHA256: 96558d0ad8084eafaf60ee360f13fe8decfbc5ac737b0c2788c01310e81750d1
  • Pointer size: 132 Bytes
  • Size of remote file: 4.42 MB
assets/Readme/wham_gen_2.gif ADDED

Git LFS Details

  • SHA256: 1296bb4ccdac5c7d3a1e7e9adfc48a6ec255933ff252a31d4e45cd117a28aee7
  • Pointer size: 132 Bytes
  • Size of remote file: 4.15 MB
assets/Readme/wham_gen_3.gif ADDED

Git LFS Details

  • SHA256: cb8ea8b3d6c8ec737a9b03f4cd93aeb36ddddc33695849b9b83543a8c2242b6f
  • Pointer size: 132 Bytes
  • Size of remote file: 4.27 MB
assets/Readme/wham_gen_4.gif ADDED

Git LFS Details

  • SHA256: 45e895599dddae5e6d2eb31f66957726fb82662f41b149f4de206466083f5a42
  • Pointer size: 132 Bytes
  • Size of remote file: 4.3 MB
assets/Readme/wham_gen_5.gif ADDED

Git LFS Details

  • SHA256: e7e7675c737bf5cbdfb54dfcc568eeda4c4212dbe5726741205610ab29cfcabb
  • Pointer size: 132 Bytes
  • Size of remote file: 4.24 MB
assets/Readme/wham_gen_6.gif ADDED

Git LFS Details

  • SHA256: e536b1f88a92de4e116a6acd022987778f63ed5a841517758c14a0d7f2a3c2bd
  • Pointer size: 132 Bytes
  • Size of remote file: 4.09 MB
assets/Readme/wham_gen_7.gif ADDED

Git LFS Details

  • SHA256: eb7e6c63eb8c46fc8c824d93406550082b6532ea9473cd021bae72a7d6cbe7db
  • Pointer size: 132 Bytes
  • Size of remote file: 4.13 MB
assets/Readme/wham_gen_8.gif ADDED

Git LFS Details

  • SHA256: 366f3f92310f3cfa55c9f4da719b01c8399c42f7d7bb860c5f7153568e4991d5
  • Pointer size: 132 Bytes
  • Size of remote file: 3.98 MB
assets/Readme/wham_gen_9.gif ADDED

Git LFS Details

  • SHA256: 931713a1d9a9dbdef7b4a1821ef78d490282bf8475e65b39948f8b5f42dc9982
  • Pointer size: 132 Bytes
  • Size of remote file: 4.53 MB
configs/metadata_custom_tag.config ADDED
@@ -0,0 +1,5 @@
1
+ %Image::ExifTool::UserDefined = (
2
+ 'Image::ExifTool::XMP::xmp' => {
3
+ 'ProgramName' => { Name => 'ProgramName', Writable => 'string' }
4
+ }
5
+ );
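The config above defines a user-defined XMP tag, ProgramName, that ExifTool can write. A minimal sketch of how it might be used to stamp that tag on an output file, in the same way run_dreaming.py drives exiftool (assumes exiftool is installed, as in setup_local.sh; the path and tag value below are hypothetical):

import subprocess

def tag_program_name(file_path, program_name, config="configs/metadata_custom_tag.config"):
    # Load the user-defined ProgramName tag and write it in place (no *_original backup file).
    subprocess.run(
        ["exiftool", "-config", config, f"-ProgramName={program_name}", "-overwrite_original", file_path],
        check=True,
    )

# tag_program_name("dreaming_output/clip_0001.mp4", "WHAM")  # hypothetical path and value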
models/WHAM_1.6B_v1.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c4997074883aa1a39a5994a7dea91fb62b2382fc039523458827adb777af8e9
3
+ size 20339650059
models/WHAM_200M.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ddb8e03a33f0849a63da030fea3de4994d95e16888993b8ab92faa904f3b31f
3
+ size 3980245067
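Both checkpoints are stored via Git LFS. A minimal sketch of loading one of them the way run_dreaming.py and run_server.py do, assuming the LFS objects have been pulled and a CUDA device is available:

import torch as th
from wham.utils import load_model_from_checkpoint

model = load_model_from_checkpoint("models/WHAM_200M.ckpt").cuda()
th.set_float32_matmul_precision("high")      # run_server.py sets these before serving
th.backends.cuda.matmul.allow_tf32 = True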
requirements.txt ADDED
@@ -0,0 +1,48 @@
1
+ --find-links https://download.pytorch.org/whl/torch_stable.html
2
+ aiohttp==3.9.3
3
+ aiosignal==1.3.1
4
+ async-timeout==4.0.3
5
+ attrs==23.2.0
6
+ blinker==1.7.0
7
+ certifi==2024.2.2
8
+ charset-normalizer==3.3.2
9
+ click==8.1.7
10
+ cloudpickle==3.0.0
11
+ cmake==3.28.3
12
+ einops==0.6.0
13
+ ffmpegcv==0.3.10
14
+ filelock==3.13.1
15
+ Flask==3.0.2
16
+ frozenlist==1.4.1
17
+ fsspec==2024.2.0
18
+ idna==3.6
19
+ importlib_metadata==7.0.2
20
+ itsdangerous==2.1.2
21
+ Jinja2==3.1.3
22
+ lightning-utilities==0.10.1
23
+ lit==17.0.6
24
+ MarkupSafe==2.1.5
25
+ mpmath==1.3.0
26
+ multidict==6.0.5
27
+ networkx==3.2.1
28
+ numpy==1.25.2
29
+ opencv-python==4.6.0.66
30
+ opencv-python-headless==4.9.0.80
31
+ packaging==23.2
32
+ pillow==10.2.0
33
+ pytorch-lightning==1.9.4
34
+ PyYAML==6.0.1
35
+ requests==2.31.0
36
+ sympy==1.12
37
+ tensordict==0.1.2
38
+ torch==2.0.1+cu118
39
+ torchinfo==1.7.1
40
+ torchmetrics==0.11.4
41
+ torchvision==0.15.2+cu118
42
+ tqdm==4.66.2
43
+ triton==2.0.0
44
+ typing_extensions==4.10.0
45
+ urllib3==2.2.1
46
+ Werkzeug==3.0.1
47
+ yarl==1.9.4
48
+ zipp==3.17.0
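The pins target the cu118 PyTorch wheels pulled in through the --find-links line above. A quick post-install sanity check might look like this:

import torch
import pytorch_lightning as pl

print(torch.__version__)           # expected: 2.0.1+cu118
print(torch.cuda.is_available())   # run_dreaming.py and run_server.py assume a CUDA device
print(pl.__version__)              # expected: 1.9.4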
run_dreaming.py ADDED
@@ -0,0 +1,264 @@
1
+ """
2
+ Example script for running dreaming on a dataset.
3
+ The idea is that there are ground_truth ("reference") video clips, and we dream the same clips given some initial context.
4
+
5
+ After dreaming, we have two sets of videos which, barring the intrinsic noise of the game environment (e.g., randomness of other players),
6
+ should be identical if the model were ideal.
7
+ """
8
+
9
+ import argparse
10
+ from pathlib import Path
11
+ import os
12
+ import subprocess
13
+
14
+ import cv2
15
+ from tensordict import TensorDict
16
+ import torch as th
17
+ from tqdm import tqdm
18
+ import numpy as np
19
+ import ffmpegcv
20
+ from PIL import Image
21
+
22
+ import wham.utils as utils
23
+
24
+
25
+ parser = argparse.ArgumentParser(description="Run dreaming.")
26
+ parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint.")
27
+ parser.add_argument("--data_path", type=str, required=True, help="Path to the directory that contains the ground truth data to dream for.")
28
+ parser.add_argument("--output", type=str, default="dreaming_output", help="Path to the directory where output should be put.")
29
+ parser.add_argument("--max_files", type=int, default=None, help="Maximum number of files to process.")
30
+ parser.add_argument("--metadata_config", type=str, default="configs/metadata_custom_tag.config", help="Path to metadata tag config for origin field.")
31
+
32
+
33
+ parser.add_argument(
34
+ "--protocol",
35
+ type=str,
36
+ default="base",
37
+ choices=["base", "comprehensive"],
38
+ help="What protocol to use for the dreaming. base = action conditioned, comprehensive = dream actions as well.",
39
+ )
40
+ parser.add_argument("--batch_size", type=int, default=1, help="Batch size for dreaming. Higher batch_size uses more VRAM but overall is faster.")
41
+ parser.add_argument("--context_length", type=int, default=10, help="Number of frames to use an initial context.")
42
+ parser.add_argument("--steps_to_dream", type=int, default=10, help="Batch size for dreaming.")
43
+
44
+ parser.add_argument("--sampling_temperature", type=float, default=0.9, help="Temperature for sampling from the model.")
45
+ parser.add_argument("--sampling_top_k", type=int, default=None, help="Top-k for sampling from the model.")
46
+ parser.add_argument("--sampling_top_p", type=float, default=None, help="Top-p for sampling from the model.")
47
+
48
+
49
+ def get_context_data(image_context, action_context, action_sequences):
50
+ # Make sure we have CHW images:
51
+ assert image_context.shape[-3] == 3, "Image context should be CHW"
52
+
53
+ image_context = th.from_numpy(image_context).cuda()
54
+ action_data = th.from_numpy(action_context).float().cuda()
55
+ action_sequences = th.from_numpy(action_sequences).float().cuda() if action_sequences is not None else None
56
+
57
+ return TensorDict({"images": image_context, "actions_output": action_data}, batch_size=image_context.shape[:2])
58
+
59
+
60
+ def add_video_metadata(file_path, metadata_config):
61
+ # Construct the exiftool command
62
+ cmd = [
63
+ 'exiftool',
64
+ '-config', metadata_config,
65
+ f'-ProgramName=\"{utils.PROGRAM_NAME}\"',
66
+ '-overwrite_original',
67
+ file_path
68
+ ]
69
+
70
+ try:
71
+ # Execute the exiftool command
72
+ subprocess.run(cmd, check=True)
73
+ print(f"Metadata modified successfully.")
74
+ # Print the new file metadata
75
+ cmd_output = [
76
+ 'exiftool',
77
+ file_path
78
+ ]
79
+ subprocess.run(cmd_output, check=True)
80
+ except subprocess.CalledProcessError as e:
81
+ print(f"Error modifying metadata: {e}")
82
+
83
+
84
+ @th.no_grad()
85
+ def do_dreaming(model, image_context, action_context, args, action_sequences=None):
86
+ """
87
+ image_context and action_context provide the initial context for the model to dream from.
88
+
89
+ If action_sequences (batch_size, args.steps_to_dream, action_dim) is provided, then model will be prompted with these actions.
90
+ """
91
+ context_data = get_context_data(image_context, action_context, action_sequences)
92
+ encoded_context_data = model.encode_context(context_data)
93
+
94
+ encoded_action_sequences = None
95
+ if action_sequences is not None:
96
+ assert action_sequences.shape[1] == args.steps_to_dream, "action_sequences should have shape (batch_size, args.steps_to_dream, action_dim)"
97
+ action_sequences = TensorDict({"actions_output": action_sequences}, batch_size=action_sequences.shape[:2]).cuda()
98
+ encoded_action_sequences = model.encode_context(action_sequences)
99
+
100
+ encoded_dreamt_steps = []
101
+
102
+ for dream_step in range(args.steps_to_dream):
103
+ encoded_predicted_step, _ = model.predictor.predict_next_step(
104
+ encoded_context_data, temperature=args.sampling_temperature, top_k=args.sampling_top_k, top_p=args.sampling_top_p, min_tokens_to_keep=1
105
+ )
106
+
107
+ # Remove first step from context if we are at the max context length:
108
+ if encoded_context_data.shape[1] == args.context_length:
109
+ encoded_context_data = encoded_context_data[:, 1:]
110
+
111
+ # Add predicted image + action to the context
112
+ append_step = encoded_predicted_step
113
+ if encoded_action_sequences is not None:
114
+ # Replace predicted action with real action
115
+ append_step["actions_output"] = encoded_action_sequences["actions_output"][:, [dream_step], :]
116
+ encoded_context_data = th.cat((encoded_context_data, append_step), dim=1)
117
+
118
+ encoded_dreamt_steps.append(encoded_predicted_step)
119
+
120
+ # Decode everything
121
+ dreamed_images = []
122
+ actions_during_dream = []
123
+ for seq_i in range(args.steps_to_dream):
124
+ decoded_step = model.decode_context(encoded_dreamt_steps[seq_i])
125
+ dreamed_images.append(decoded_step["images"][:, [0]].cpu().numpy())
126
+ actions_during_dream.append(decoded_step["actions_output"][:, [0]].cpu().numpy())
127
+
128
+ dreamed_images = np.concatenate(dreamed_images, axis=1)
129
+ actions_during_dream = np.concatenate(actions_during_dream, axis=1)
130
+
131
+ return dreamed_images, actions_during_dream
132
+
133
+
134
+ @th.no_grad()
135
+ def encode_decode_images(model, images):
136
+ """
137
+ Pass ground_truth images through the encoding/decoding process of the model.
138
+ """
139
+ context = TensorDict({"images": th.from_numpy(images).cuda()}, batch_size=images.shape[:2])
140
+ output_images = []
141
+ for seq_i in range(images.shape[1]):
142
+ encoded_images = model.encode_context(context[:, [seq_i]])
143
+ decoded_images = model.decode_context(encoded_images)
144
+ output_images.append(decoded_images["images"].cpu().numpy())
145
+ return np.concatenate(output_images, axis=1)
146
+
147
+
148
+ def main(args):
149
+ total_video_length = args.context_length + args.steps_to_dream
150
+
151
+ # Now, load the model:
152
+ model_path = Path(args.model_path)
153
+ assert model_path.is_file(), "Could not find the model!"
154
+ model = utils.load_model_from_checkpoint(model_path).cuda()
155
+
156
+ # Glob the dataset to find all the ground truth segments we want to construct a dream for:
157
+ data_path = Path(args.data_path)
158
+ ground_truth_files = list(data_path.rglob("*.npz"))
159
+ num_dreams = len(ground_truth_files)
160
+
161
+ if args.max_files is not None:
162
+ # Sort to make sure we always get the same files
163
+ ground_truth_files = sorted(ground_truth_files)
164
+ ground_truth_files = ground_truth_files[: args.max_files]
165
+ num_dreams = len(ground_truth_files)
166
+
167
+ output_path = Path(args.output)
168
+ os.makedirs(output_path, exist_ok=True)
169
+
170
+ print("=" * 100)
171
+ print(f"GENERATING DREAMS OF {num_dreams} SEGMENTS")
172
+ print(f"WRITING TO {args.output}")
173
+ print("=" * 100)
174
+
175
+ dreams_created = 0
176
+ with tqdm(total=num_dreams, desc="Dreams") as pbar:
177
+ while ground_truth_files:
178
+ # Load batch_size headers:
179
+ batches = min(args.batch_size, len(ground_truth_files))
180
+ batched_image_context = []
181
+ batched_image_sequence = []
182
+ batched_action_context = []
183
+ batched_action_sequence = []
184
+ episode_names = []
185
+ for i in range(batches):
186
+ episode = ground_truth_files.pop()
187
+ episode_names.append(episode)
188
+ try:
189
+ data = np.load(episode)
190
+ images = data["images"]
191
+ actions = data["actions"]
192
+ except Exception:
193
+ print(f"Failed to load episode {episode} - skipping.")
194
+ continue
195
+
196
+ if actions.shape[0] < total_video_length:
197
+ # We want to make sure we have ground_truth comparisons for the entire dream, so we ensure the episode is long enough
198
+ raise ValueError(f"Episode {episode} is too short to dream from. It has {actions.shape[0]} steps, but we need at least {total_video_length}.")
199
+ batched_image_context.append(images[: args.context_length])
200
+ batched_image_sequence.append(images[args.context_length: total_video_length])
201
+ batched_action_context.append(actions[: args.context_length])
202
+ batched_action_sequence.append(actions[args.context_length: total_video_length])
203
+
204
+ image_context = np.array(batched_image_context)
205
+ image_sequences = np.array(batched_image_sequence)
206
+ action_context = np.array(batched_action_context)
207
+ action_sequences = np.array(batched_action_sequence)
208
+
209
+ if args.protocol == "comprehensive":
210
+ # We do not need to pass in the action sequences for the comprehensive protocol
211
+ action_sequences = None
212
+
213
+ full_image_sequence = np.concatenate((image_context, image_sequences), axis=1)
214
+
215
+ dreamt_images, actions_during_dream = do_dreaming(model, image_context, action_context, args, action_sequences=action_sequences)
216
+ encoded_decoded_images_batch = encode_decode_images(model, full_image_sequence)
217
+
218
+ pbar.update(batches)
219
+ dreams_created += batches
220
+
221
+ # Save the dreams:
222
+ # We are aiming to mimic the folder structure of the ground truth dataset, so use the episode names
223
+ # but make them relative to our output folder:
224
+ for i, dream in enumerate(dreamt_images):
225
+ episode = episode_names[i]
226
+ output_file = output_path / episode.relative_to(data_path)
227
+ output_file.parent.mkdir(parents=True, exist_ok=True)
228
+ np.savez(
229
+ output_file,
230
+ context_length=args.context_length,
231
+ steps_to_dream=args.steps_to_dream,
232
+ raw_context=image_context[i],
233
+ dreamt_images=dream,
234
+ all_actions=np.concatenate((action_context[i], actions_during_dream[i])),
235
+ encoded_decoded_ground_truth_images=encoded_decoded_images_batch[i],
236
+ )
237
+
238
+ video_file = str(output_file.with_suffix(".mp4"))
239
+ writer = ffmpegcv.VideoWriter(video_file, None, utils.DREAMING_FPS)
240
+ full_sequence = np.concatenate((image_context[i], dream), axis=0)
241
+ for frame in full_sequence:
242
+ img = frame.transpose(1, 2, 0).astype(np.uint8).copy()
243
+ # Please DO NOT remove this watermark. This will infringe upon the repo's license agreement
244
+ (text_width, _), _ = cv2.getTextSize(utils.WATERMARK_TEXT, utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_THICKNESS)
245
+ x = img.shape[1] - text_width - 10 # 10 pixels from the right edge
246
+ y = img.shape[0] - 10 # 10 pixels from the bottom edge
247
+ cv2.putText(img, utils.WATERMARK_TEXT, (x, y), utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_COLOR, utils.WATERMARK_FONT_THICKNESS)
248
+
249
+ # Add image metadata
250
+ pil_image = Image.fromarray(img)
251
+ pil_image.info['Id'] = 0x0131
252
+ pil_image.info['Type'] = 2
253
+ pil_image.info['Value'] = utils.PROGRAM_NAME.encode("utf-8")
254
+ pil_image.info['Len'] = len(utils.PROGRAM_NAME) + 1
255
+
256
+ # Convert pil_image to a CV2 format for the video writer
257
+ cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
258
+ writer.write(cv_image)
259
+ writer.release()
260
+ add_video_metadata(video_file, args.metadata_config)
261
+
262
+ if __name__ == "__main__":
263
+ args = parser.parse_args()
264
+ main(args)
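run_dreaming.py is driven entirely by the argparse flags defined at the top of the file; it expects a directory of .npz episodes containing "images" and "actions" arrays and writes matching .npz/.mp4 outputs. A sketch of an invocation (the checkpoint and data paths below are placeholders):

import subprocess

subprocess.run(
    [
        "python", "run_dreaming.py",
        "--model_path", "models/WHAM_200M.ckpt",
        "--data_path", "data/ground_truth_episodes",   # hypothetical directory of .npz episodes
        "--output", "dreaming_output",
        "--protocol", "base",           # "base" = action conditioned, "comprehensive" = dream actions too
        "--context_length", "10",
        "--steps_to_dream", "10",
        "--batch_size", "1",
    ],
    check=True,
)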
run_server.py ADDED
@@ -0,0 +1,519 @@
1
+ import argparse
2
+ from dataclasses import dataclass, field
3
+ import json
4
+ import copy
5
+ import multiprocessing as mp
6
+ import uuid
7
+ from datetime import datetime, timedelta
8
+ from collections import defaultdict, deque
9
+ import io
10
+ import zipfile
11
+ import queue
12
+ import time
13
+ import random
14
+ import logging
15
+
16
+ from tensordict import TensorDict
17
+ import cv2
18
+ from flask import Flask, request, make_response, send_file
19
+ from PIL import Image
20
+ import torchvision.transforms as T
21
+ import numpy as np
22
+ import torch as th
23
+
24
+ from wham.utils import load_model_from_checkpoint, POS_BINS_BOUNDARIES, POS_BINS_MIDDLE
25
+
26
+ logging.basicConfig(level=logging.INFO)
27
+
28
+ parser = argparse.ArgumentParser(description="Simple Dreamer")
29
+ parser.add_argument("--model", type=str, required=True, help="Path to the model file for the local runs")
30
+ parser.add_argument("--debug", action="store_true", help="Enable flask debug mode.")
31
+ parser.add_argument("--random_model", action="store_true", help="Use randomly initialized model instead of the provided one")
32
+ parser.add_argument("--port", type=int, default=5000)
33
+
34
+ parser.add_argument("--max_concurrent_jobs", type=int, default=30, help="Maximum number of jobs that can be run concurrently on this server.")
35
+ parser.add_argument("--max_dream_steps_per_job", type=int, default=10, help="Maximum number of dream steps each job can request.")
36
+ parser.add_argument("--max_job_lifespan", type=int, default=60 * 10, help="Maximum number of seconds we keep run around if not polled.")
37
+
38
+ parser.add_argument("--image_width", type=int, default=300, help="Width of the image")
39
+ parser.add_argument("--image_height", type=int, default=180, help="Height of the image")
40
+
41
+ parser.add_argument("--max_batch_size", type=int, default=3, help="Maximum batch size for the dreamer workers")
42
+
43
+ PREDICTION_JSON_FILENAME = "predictions.json"
44
+ # Minimum time between times we check when to delete jobs. We do this when adding new jobs.
45
+ JOB_CLEANUP_CHECK_RATE = timedelta(seconds=10)
46
+
47
+ MAX_CANCELLED_ID_QUEUE_SIZE = 100
48
+
49
+ DEFAULT_SAMPLING_SETTINGS = {
50
+ "temperature": 0.9,
51
+ "top_k": None,
52
+ "top_p": 1.0,
53
+ "max_context_length": 10,
54
+ }
55
+
56
+
57
+ def float_or_none(string):
58
+ if string.lower() == "none":
59
+ return None
60
+ return float(string)
61
+
62
+
63
+ def be_image_preprocess(image, target_width, target_height):
64
+ # If target_width and target_height are specified, resize the image.
65
+ if target_width is not None and target_height is not None:
66
+ # Make sure we do not try to resize if the image is already the correct size.
67
+ if image.shape[1] != target_width or image.shape[0] != target_height:
68
+ image = cv2.resize(image, (target_width, target_height))
69
+ return np.transpose(image, (2, 0, 1))
70
+
71
+
72
+ def action_vector_to_be_action_vector(action):
73
+ # Preprocess a BE action vector from 16 numbers with:
74
+ # 12 buttons [0, 1] and 4 stick directions [-1, 1]
75
+ # to discrete actions valid for the token model
76
+ # 12 buttons [0, 1] and 4 stick directions {discrete bin}
77
+ action[-4:] = np.digitize(action[-4:], bins=POS_BINS_BOUNDARIES) - 1
78
+ return action
79
+
80
+
81
+ def be_action_vector_to_action_vector(action):
82
+ # Preprocess a BE action vector into unified space
83
+ for stick_index in range(-4, 0):
84
+ action[stick_index] = POS_BINS_MIDDLE[int(action[stick_index])]
85
+ return action
86
+
87
+
88
+
89
+ @dataclass
90
+ class DreamJob:
91
+ job_id: str
92
+ sampling_settings: dict
93
+ num_predictions_remaining: int
94
+ num_predictions_done: int
95
+ # (B, T, C, H, W)
96
+ context_images: th.Tensor
97
+ context_actions: th.Tensor
98
+ # Tokens that will replace the context_images if they are provided
99
+ context_tokens: list
100
+ # This will replace the dreamed action if provided.
101
+ # For every step, we remove the first action until exhausted
102
+ actions_to_take: th.Tensor = None
103
+
104
+
105
+ @dataclass
106
+ class DreamJobResult:
107
+ job_id: str
108
+ dream_step_index: int
109
+ # (B, 1, C, H, W)
110
+ dreamt_image: th.Tensor
111
+ dreamt_action: th.Tensor
112
+ dreamt_tokens: th.Tensor
113
+ result_creation_time: datetime = field(default_factory=datetime.now)
114
+
115
+
116
+
117
+ def setup_and_load_model_be_model(args):
118
+ model = load_model_from_checkpoint(args.model)
119
+ th.set_float32_matmul_precision("high")
120
+ th.backends.cuda.matmul.allow_tf32 = True
121
+ return model
122
+
123
+
124
+ def get_job_batchable_information(job):
125
+ """Return comparable object of job information. Used for batching"""
126
+ context_length = job.context_images.shape[1]
127
+ return (context_length, job.sampling_settings)
128
+
129
+
130
+ def fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size, timeout=1):
131
+ """Return a list of jobs (or empty list) that can be batched together"""
132
+ batchable_jobs = []
133
+ required_job_info = None
134
+ while len(batchable_jobs) < max_batch_size:
135
+ try:
136
+ job = job_queue.get(timeout=timeout)
137
+ except queue.Empty:
138
+ break
139
+ # If pipe breaks, also gracefully return
140
+ except OSError:
141
+ break
142
+ if job.job_id in cancelled_ids_set:
143
+ # This job was cancelled, so discard it completely
144
+ continue
145
+ job_info = get_job_batchable_information(job)
146
+ if required_job_info is None:
147
+ required_job_info = job_info
148
+ elif required_job_info != job_info:
149
+ # This job is not batchable, put it back
150
+ job_queue.put(job)
151
+ # we assume here that, generally, the other jobs would also be
152
+ # invalid. So we just return the batchable jobs we have instead
153
+ # of going through more.
154
+ break
155
+ batchable_jobs.append(job)
156
+ return batchable_jobs
157
+
158
+
159
+ def update_cancelled_jobs(cancelled_ids_queue, cancelled_ids_deque, cancelled_ids_set):
160
+ """IN-PLACE Update cancelled_ids_set with new ids from the queue"""
161
+ has_changed = False
162
+ while not cancelled_ids_queue.empty():
163
+ try:
164
+ cancelled_id = cancelled_ids_queue.get_nowait()
165
+ except queue.Empty:
166
+ break
167
+ cancelled_ids_deque.append(cancelled_id)
168
+ has_changed = True
169
+
170
+ if has_changed:
171
+ cancelled_ids_set.clear()
172
+ cancelled_ids_set.update(cancelled_ids_deque)
173
+
174
+
175
+ def predict_step(context_data, sampling_settings, model, tokens=None):
176
+ with th.no_grad():
177
+ predicted_step = model.predict_next_step(context_data, min_tokens_to_keep=1, tokens=tokens, **sampling_settings)
178
+ return predicted_step
179
+
180
+
181
+ def dreamer_worker(job_queue, result_queue, cancelled_jobs_queue, quit_flag, device_to_use, args):
182
+ logger = logging.getLogger(f"dreamer_worker {device_to_use}")
183
+ logger.info("Loading up model...")
184
+ model = setup_and_load_model_be_model(args)
185
+ model = model.to(device_to_use)
186
+ logger.info("Model loaded. Fetching results")
187
+
188
+ cancelled_ids_deque = deque(maxlen=MAX_CANCELLED_ID_QUEUE_SIZE)
189
+ cancelled_ids_set = set()
190
+
191
+ while not quit_flag.is_set():
192
+ update_cancelled_jobs(cancelled_jobs_queue, cancelled_ids_deque, cancelled_ids_set)
193
+ batchable_jobs = fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size=args.max_batch_size)
194
+ if len(batchable_jobs) == 0:
195
+ continue
196
+ sampling_settings = batchable_jobs[0].sampling_settings
197
+ # TODO: find a better way to pass these arguments around. sampling_settings
198
+ # is passed as kwargs to the prediction step, but max_context_length is not a valid
199
+ # key there, so we need to pop it out.
200
+ max_context_length = sampling_settings.pop("max_context_length")
201
+
202
+ images = [job.context_images[:, :max_context_length] for job in batchable_jobs]
203
+ actions = [job.context_actions[:, :max_context_length] for job in batchable_jobs]
204
+ tokens = [job.context_tokens for job in batchable_jobs]
205
+
206
+ images = th.concat(images, dim=0).to(device_to_use)
207
+ actions = th.concat(actions, dim=0).to(device_to_use)
208
+
209
+ context_data = TensorDict({
210
+ "images": images,
211
+ "actions_output": actions
212
+ }, batch_size=images.shape[:2])
213
+
214
+ predicted_step, predicted_image_tokens = predict_step(context_data, sampling_settings, model, tokens)
215
+
216
+ predicted_step = predicted_step.cpu()
217
+ predicted_images = predicted_step["images"]
218
+ predicted_actions = predicted_step["actions_output"]
219
+ predicted_image_tokens = predicted_image_tokens.cpu()
220
+
221
+ for job_i, job in enumerate(batchable_jobs):
222
+ image_context = job.context_images
223
+ action_context = job.context_actions
224
+ token_context = job.context_tokens
225
+ # Keep batch dimension
226
+ dreamt_image = predicted_images[job_i].unsqueeze(0)
227
+ dreamt_action = predicted_actions[job_i].unsqueeze(0)
228
+ dreamt_tokens = predicted_image_tokens[job_i].unsqueeze(0)
229
+
230
+ # Replace the dreamed action if provided
231
+ actions_to_take = job.actions_to_take
232
+ if actions_to_take is not None and actions_to_take.shape[1] > 0:
233
+ dreamt_action = actions_to_take[:, 0:1]
234
+ # Remove the action we took
235
+ actions_to_take = actions_to_take[:, 1:]
236
+ if actions_to_take.shape[1] == 0:
237
+ actions_to_take = None
238
+
239
+ result_queue.put(DreamJobResult(
240
+ job_id=job.job_id,
241
+ dream_step_index=job.num_predictions_done,
242
+ dreamt_image=dreamt_image,
243
+ dreamt_action=dreamt_action,
244
+ dreamt_tokens=dreamt_tokens
245
+ ))
246
+
247
+ # Add job back in the queue if we have more steps to do
248
+ if job.num_predictions_remaining > 0:
249
+ # Stack the dreamt image and action to the context
250
+ if image_context.shape[1] >= max_context_length:
251
+ image_context = image_context[:, 1:]
252
+ action_context = action_context[:, 1:]
253
+ token_context = token_context[1:]
254
+ image_context = th.cat([image_context, dreamt_image], dim=1)
255
+ action_context = th.cat([action_context, dreamt_action], dim=1)
256
+ token_context.append(dreamt_tokens[0, 0].tolist())
257
+ # We need to add context length back to sampling settings...
258
+ # TODO: find a better way of passing these settings around
259
+ job.sampling_settings["max_context_length"] = max_context_length
260
+ job_queue.put(DreamJob(
261
+ job_id=job.job_id,
262
+ sampling_settings=job.sampling_settings,
263
+ num_predictions_remaining=job.num_predictions_remaining - 1,
264
+ num_predictions_done=job.num_predictions_done + 1,
265
+ context_images=image_context,
266
+ context_actions=action_context,
267
+ context_tokens=token_context,
268
+ actions_to_take=actions_to_take
269
+ ))
270
+
271
+
272
+ class DreamerServer:
273
+ def __init__(self, num_workers, args):
274
+ self.num_workers = num_workers
275
+ self.args = args
276
+ self.model = None
277
+ self.jobs = mp.Queue(maxsize=args.max_concurrent_jobs)
278
+ self.results_queue = mp.Queue()
279
+ self.cancelled_jobs = set()
280
+ self.cancelled_jobs_queues = [mp.Queue() for _ in range(num_workers)]
281
+ # job_id -> results
282
+ self._last_result_cleanup = datetime.now()
283
+ self._max_job_lifespan_datetime = timedelta(seconds=args.max_job_lifespan)
284
+ self.local_results = defaultdict(list)
285
+ self.logger = logging.getLogger("DreamerServer")
286
+
287
+ def get_details(self):
288
+ details = {
289
+ "model_file": self.args.model,
290
+ "max_concurrent_jobs": self.args.max_concurrent_jobs,
291
+ "max_dream_steps_per_job": self.args.max_dream_steps_per_job,
292
+ "max_job_lifespan": self.args.max_job_lifespan,
293
+ }
294
+ return json.dumps(details)
295
+
296
+ def _check_if_should_remove_old_jobs(self):
297
+ time_now = datetime.now()
298
+ # Only cleanup every JOB_CLEANUP_CHECK_RATE seconds at most
299
+ if time_now - self._last_result_cleanup < JOB_CLEANUP_CHECK_RATE:
300
+ return
301
+
302
+ self._last_result_cleanup = time_now
303
+ # First add existing results to the local results
304
+ self._gather_new_results()
305
+ # Check if we should remove old jobs
306
+ job_ids = list(self.local_results.keys())
307
+ for job_id in job_ids:
308
+ results = self.local_results[job_id]
309
+ # If newest result is older than max_job_lifespan, remove the job
310
+ if time_now - results[-1].result_creation_time > self._max_job_lifespan_datetime:
311
+ self.logger.info(f"Deleted job {job_id} because it was too old. Last result was {results[-1].result_creation_time}")
312
+ del self.local_results[job_id]
313
+
314
+ def add_new_job(self, request, request_json):
315
+ """
316
+ Add new dreaming job to the queues.
317
+ Request should have:
318
+
319
+
320
+ Returns: json object with new job id
321
+ """
322
+ self._check_if_should_remove_old_jobs()
323
+
324
+ sampling_settings = copy.deepcopy(DEFAULT_SAMPLING_SETTINGS)
325
+ if "num_steps_to_predict" not in request_json:
326
+ return make_response("num_steps_to_predict not in request", 400)
327
+ num_steps_to_predict = request_json['num_steps_to_predict']
328
+ if num_steps_to_predict > self.args.max_dream_steps_per_job:
329
+ return make_response(f"num_steps_to_predict too large. Max {self.args.max_dream_steps_per_job}", 400)
330
+
331
+ num_parallel_predictions = int(request_json['num_parallel_predictions']) if 'num_parallel_predictions' in request_json else 1
332
+
333
+ if (self.jobs.qsize() + num_parallel_predictions) >= self.args.max_concurrent_jobs:
334
+ return make_response(f"Too many jobs already running. Max {self.args.max_concurrent_jobs}", 400)
335
+
336
+ for key in sampling_settings:
337
+ sampling_settings[key] = float_or_none(request_json[key]) if key in request_json else sampling_settings[key]
338
+
339
+ context_images = []
340
+ context_actions = []
341
+ context_tokens = []
342
+ future_actions = []
343
+
344
+ for step in request_json["steps"]:
345
+ image_path = step["image_name"]
346
+ image = np.array(Image.open(request.files[image_path].stream))
347
+ image = be_image_preprocess(image, target_width=self.args.image_width, target_height=self.args.image_height)
348
+ context_images.append(th.from_numpy(image))
349
+
350
+ action = step["action"]
351
+ action = action_vector_to_be_action_vector(action)
352
+ context_actions.append(th.tensor(action))
353
+
354
+ tokens = step["tokens"]
355
+ context_tokens.append(tokens)
356
+
357
+ future_actions = None
358
+ if "future_actions" in request_json:
359
+ future_actions = []
360
+ for step in request_json["future_actions"]:
361
+ # The rest is the action vector
362
+ action = step["action"]
363
+ action = action_vector_to_be_action_vector(action)
364
+ # Add sequence and batch dimensions
365
+ future_actions.append(th.tensor(action))
366
+
367
+ # Add batch dimensions
368
+ context_images = th.stack(context_images).unsqueeze(0)
369
+ context_actions = th.stack(context_actions).unsqueeze(0)
370
+ future_actions = th.stack(future_actions).unsqueeze(0) if future_actions is not None else None
371
+
372
+ list_of_job_ids = []
373
+ for _ in range(num_parallel_predictions):
374
+ job_id = uuid.uuid4().hex
375
+ self.jobs.put(DreamJob(
376
+ job_id=job_id,
377
+ sampling_settings=sampling_settings,
378
+ num_predictions_remaining=num_steps_to_predict,
379
+ num_predictions_done=0,
380
+ context_images=context_images,
381
+ context_actions=context_actions,
382
+ context_tokens=context_tokens,
383
+ actions_to_take=future_actions
384
+ ))
385
+ list_of_job_ids.append(job_id)
386
+
387
+ job_queue_size = self.jobs.qsize()
388
+ return json.dumps({"job_ids": list_of_job_ids, "current_jobs_in_queue": job_queue_size})
389
+
390
+ def _gather_new_results(self):
391
+ if not self.results_queue.empty():
392
+ for _ in range(self.results_queue.qsize()):
393
+ result = self.results_queue.get()
394
+ if result.job_id in self.cancelled_jobs:
395
+ # Discard result if job was cancelled
396
+ continue
397
+ self.local_results[result.job_id].append(result)
398
+
399
+ def get_new_results(self, request, request_json):
400
+ if "job_ids" not in request_json:
401
+ return make_response("job_ids not in request", 400)
402
+ self._gather_new_results()
403
+ job_ids = request_json["job_ids"]
404
+ if not isinstance(job_ids, list):
405
+ job_ids = [job_ids]
406
+ return_results = []
407
+ for job_id in job_ids:
408
+ if job_id in self.local_results:
409
+ return_results.append(self.local_results[job_id])
410
+ del self.local_results[job_id]
411
+
412
+ if len(return_results) == 0:
413
+ return make_response("No new responses", 204)
414
+
415
+ output_json = []
416
+ output_image_bytes = {}
417
+ for job_results in return_results:
418
+ for result in job_results:
419
+ action = result.dreamt_action.numpy()
420
+ # Remember to remove batch and sequence dimensions
421
+ action = be_action_vector_to_action_vector(action[0, 0].tolist())
422
+ dreamt_tokens = result.dreamt_tokens[0, 0].tolist()
423
+ image_filename = f"{result.job_id}_{result.dream_step_index}.png"
424
+ output_json.append({
425
+ "job_id": result.job_id,
426
+ "dream_step_index": result.dream_step_index,
427
+ "action": action,
428
+ "tokens": dreamt_tokens,
429
+ "image_filename": image_filename
430
+ })
431
+
432
+ image_bytes = io.BytesIO()
433
+ # this probably is not as smooth as it could be
434
+ T.ToPILImage()(result.dreamt_image[0, 0]).save(image_bytes, format="PNG")
435
+ output_image_bytes[image_filename] = image_bytes.getvalue()
436
+
437
+ # Write a zip file with all the images
438
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
439
+ zip_bytes = io.BytesIO()
440
+ with zipfile.ZipFile(zip_bytes, "w") as z:
441
+ for filename, bytes in output_image_bytes.items():
442
+ z.writestr(filename, bytes)
443
+ # Write the json
444
+ z.writestr(PREDICTION_JSON_FILENAME, json.dumps(output_json))
445
+
446
+ zip_bytes.seek(0)
447
+
448
+ return send_file(
449
+ zip_bytes,
450
+ mimetype="zip",
451
+ as_attachment=True,
452
+ download_name=f"dreaming_results_{timestamp}.zip"
453
+ )
454
+
455
+ def cancel_job(self, request, request_json):
456
+ if "job_id" not in request_json:
457
+ return make_response("job_id not in request", 400)
458
+ job_id = request_json["job_id"]
459
+ self.cancelled_jobs.add(job_id)
460
+ # Cancel all jobs in the queue with this id
461
+ for job_queue in self.cancelled_jobs_queues:
462
+ job_queue.put(job_id)
463
+ return make_response("OK", 200)
464
+
465
+
466
+ def main_run(args):
467
+ app = Flask(__name__)
468
+
469
+ num_workers = th.cuda.device_count()
470
+ if num_workers == 0:
471
+ raise RuntimeError("No CUDA devices found. Cannot run Dreamer.")
472
+
473
+ server = DreamerServer(num_workers, args)
474
+ quit_flag = mp.Event()
475
+
476
+ # Start the dreamer worker(s)
477
+ dreamer_worker_processes = []
478
+ for device_i in range(num_workers):
479
+ device = f"cuda:{device_i}"
480
+ dreamer_worker_process = mp.Process(
481
+ target=dreamer_worker,
482
+ args=(server.jobs, server.results_queue, server.cancelled_jobs_queues[device_i], quit_flag, device, args)
483
+ )
484
+ dreamer_worker_process.daemon = True
485
+ dreamer_worker_process.start()
486
+ dreamer_worker_processes.append(dreamer_worker_process)
487
+
488
+ # Add the API endpoints
489
+ @app.route('/')
490
+ def details():
491
+ return server.get_details()
492
+
493
+ @app.route('/new_job', methods=['POST'])
494
+ def new_job():
495
+ request_json = json.loads(request.form["json"])
496
+ return server.add_new_job(request, request_json)
497
+
498
+ @app.route('/get_job_results', methods=['GET'])
499
+ def get_results():
500
+ # the "Json" is now in regular GET payload/parameters
501
+ request_json = {"job_ids": request.args.getlist("job_ids")}
502
+ return server.get_new_results(request, request_json)
503
+
504
+ @app.route('/cancel_job', methods=['GET'])
505
+ def cancel_job():
506
+ request_json = request.args.to_dict()
507
+ return server.cancel_job(request, request_json)
508
+
509
+ app.run(host="0.0.0.0", port=args.port, debug=args.debug)
510
+
511
+ # Cleanup
512
+ quit_flag.set()
513
+ for dreamer_worker_process in dreamer_worker_processes:
514
+ dreamer_worker_process.join()
515
+
516
+
517
+ if __name__ == '__main__':
518
+ args = parser.parse_args()
519
+ main_run(args)
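run_server.py exposes POST /new_job (a multipart form with a "json" field plus the context images named in each step), GET /get_job_results (a zip of dreamt PNG frames plus predictions.json), and GET /cancel_job, with a details view at /. A minimal client sketch using the pinned requests package; the frame files, 16-element action vectors, and empty token lists below are illustrative only:

import io, json, zipfile, requests

SERVER = "http://localhost:5000"
steps = [{"image_name": f"frame_{i}.png", "action": [0.0] * 16, "tokens": []} for i in range(2)]
files = {step["image_name"]: open(step["image_name"], "rb") for step in steps}  # hypothetical local frames
payload = {"num_steps_to_predict": 5, "steps": steps}

resp = requests.post(f"{SERVER}/new_job", data={"json": json.dumps(payload)}, files=files)
job_ids = resp.json()["job_ids"]

# Poll for results; 204 means nothing is ready yet.
resp = requests.get(f"{SERVER}/get_job_results", params={"job_ids": job_ids})
if resp.status_code == 200:
    with zipfile.ZipFile(io.BytesIO(resp.content)) as archive:
        print(json.loads(archive.read("predictions.json")))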
setup_local.sh ADDED
@@ -0,0 +1,21 @@
1
+ # Tested using Python 3.9
2
+
3
+ echo "Creating a new virtual environment..."
4
+ python3.9 -m venv venv
5
+
6
+ echo "Activating the virtual environment..."
7
+ source venv/bin/activate
8
+
9
+ echo "Upgrading pip..."
10
+ pip install --upgrade pip
11
+
12
+ echo "Installing the required packages..."
13
+ pip install -r requirements.txt
14
+
15
+ echo "Installing the exiftool package for adding file metadata on Linux..."
16
+ sudo apt install -y exiftool
17
+
18
+ echo "Installing ffmpeg..."
19
+ sudo apt install -y ffmpeg
20
+
21
+ echo "All packages installed successfully!"
wham/models/nn/model_blocks.py ADDED
@@ -0,0 +1,49 @@
1
+ import torch.nn as nn
2
+
3
+ """
4
+ Some Utility blocks for ViT-VQGAN.
5
+
6
+ ConvNeXt blocks are based on:
7
+ Liu, Zhuang, et al. "A convnet for the 2020s."
8
+ Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.
9
+ """
10
+
11
+
12
+ class ConvNextDownsampleBig(nn.Module):
13
+ def __init__(self, c_in, c_out):
14
+ super().__init__()
15
+ self.group_norm = nn.GroupNorm(c_in, c_in)
16
+ self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=8, stride=4, padding=0)
17
+
18
+ def forward(self, x):
19
+ return self.conv1(self.group_norm(x))
20
+
21
+
22
+ class ConvNextBlock(nn.Module):
23
+ def __init__(self, channels):
24
+ super().__init__()
25
+ self.conv1 = nn.Conv2d(channels, channels, kernel_size=7, stride=1, padding=7 // 2, groups=channels) # 'Depthwise' conv
26
+ self.group_norm = nn.GroupNorm(channels, channels) # One group per channel (InstanceNorm-style); note the original ConvNeXt uses LayerNorm over channels
27
+
28
+ # Transformer-style non-linearity
29
+ self.conv2 = nn.Conv2d(channels, channels * 4, kernel_size=1, stride=1, padding=0)
30
+ self.activation = nn.GELU()
31
+ self.conv3 = nn.Conv2d(channels * 4, channels, kernel_size=1, stride=1, padding=0)
32
+
33
+ def forward(self, x):
34
+ y = self.conv1(x)
35
+ y = self.group_norm(y)
36
+ y = self.conv2(y)
37
+ y = self.activation(y)
38
+ y = self.conv3(y)
39
+ return x + y
40
+
41
+
42
+ class ConvNextDownsample(nn.Module):
43
+ def __init__(self, c_in, c_out):
44
+ super().__init__()
45
+ self.group_norm = nn.GroupNorm(c_in, c_in)
46
+ self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)
47
+
48
+ def forward(self, x):
49
+ return self.conv1(self.group_norm(x))
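
To illustrate how these blocks compose, here is a small, self-contained sketch (illustrative only; the actual ViT-VQGAN encoder configuration lives in the VQGAN model code, not here):

import torch
import torch.nn as nn

# Assumes ConvNextBlock and ConvNextDownsample from the module above are in scope.
stage = nn.Sequential(
    ConvNextBlock(64),            # residual depthwise block, keeps the (B, 64, H, W) shape
    ConvNextDownsample(64, 128),  # stride-2 conv, halves the spatial resolution
    ConvNextBlock(128),
)

x = torch.randn(1, 64, 32, 32)
print(stage(x).shape)  # torch.Size([1, 128, 16, 16])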
wham/models/nn/nanoGPT.py ADDED
@@ -0,0 +1,665 @@
1
+ # From https://github.com/karpathy/nanoGPT/blob/master/model.py - Thanks Andrej Karpathy
2
+
3
+ # MIT License
4
+ # Copyright (c) 2022 Andrej Karpathy
5
+ # 2023 Microsoft Research
6
+
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
+ # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
+ # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
20
+ # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
21
+ # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
22
+ # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
23
+ # OR OTHER DEALINGS IN THE SOFTWARE.
24
+
25
+
26
+ """
27
+ Full definition of a GPT Language Model, all of it in this single file.
28
+ References:
29
+ 1) the official GPT-2 TensorFlow implementation released by OpenAI:
30
+ https://github.com/openai/gpt-2/blob/master/src/model.py
31
+ 2) huggingface/transformers PyTorch implementation:
32
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
33
+ """
34
+
35
+ from dataclasses import dataclass
36
+ import inspect
37
+ import math
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+ from torch.nn import functional as F
42
+
43
+ NEGATIVE_INFINITE_FLOAT = -float("inf")
44
+ CROSS_ENTROPY_INVALID_CLASS_TARGET = -1
45
+
46
+ # @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
47
+ def new_gelu(x):
48
+ """
49
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
50
+ Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
51
+ """
52
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
53
+
54
+
55
+ def limit_logits_to_valid_range(logits, valid_token_range):
56
+ """
57
+ MODIFIES logits INPLACE.
58
+ Mask out invalid positions in the logits tensor with -inf so they are not considered by the softmax.
59
+
60
+ Args:
61
+ logits: logits tensor of shape (batch_size, vocab_size)
62
+ valid_token_range: tuple of (start, end) indices of valid positions in the logits tensor (inclusive).
63
+ Everything outside is masked out with -inf.
64
+ """
65
+ logits[:, : valid_token_range[0]] = NEGATIVE_INFINITE_FLOAT
66
+ logits[:, valid_token_range[1] + 1 :] = NEGATIVE_INFINITE_FLOAT
67
+
68
+
69
+ def default_sample_token(logits, valid_token_range=None, temperature=1.0, deterministic=False, top_k=None, top_p=None, min_tokens_to_keep=1):
70
+ """
71
+ Given a vector of logits, sample and return an index according to settings.
72
+
73
+ logits: tensor of shape (batch_size, vocab_size)
74
+
75
+ valid_token_range should be a tuple, specifying start and end indices we'd like to sample from (inclusive).
76
+ If None, we'll sample from the full vocab.
77
+
78
+ If deterministic is True, we take the argmax of the logits (equivalent to top-k sampling with top_k = 1), so any user-provided top_p and top_k values are ignored.
79
+
80
+ Otherwise, either top-p (float) value can be specified or top-k (int) value can be specified.
81
+ Top-p (float top_p) : only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
82
+ Top-k (int top_k) : selects top_k tokens for generation.
83
+ min_tokens_to_keep: Used with both top_p and top_k sampling.
84
+ """
85
+ assert top_k is None or top_p is None, "Can only specify one of top-k or top-p sampling."
86
+ if temperature < 0.1:
87
+ # Avoid too low a temp, especially 0
88
+ temperature = 0.1
89
+ logits = logits / temperature
90
+ if valid_token_range is not None:
91
+ limit_logits_to_valid_range(logits, valid_token_range)
92
+ if deterministic:
93
+ selected_logits = select_logits(logits, top_k=1)
94
+ else:
95
+ selected_logits = select_logits(logits, top_p=top_p, top_k=top_k, min_tokens_to_keep=min_tokens_to_keep)
96
+ probs = F.softmax(selected_logits, dim=-1)
97
+ # TODO: handle errors in the sampling here more robustly (e.g. invalid probability vectors)
98
+ sampled_idx = torch.multinomial(probs, num_samples=1).squeeze(-1)
99
+ return sampled_idx
100
+
101
+
102
+ def select_logits(logits, top_k=None, top_p=None, min_tokens_to_keep=1):
103
+ """
104
+ Select from original logits using top-k or top-p sampling.
105
+
106
+ Args:
107
+ logits (torch.Tensor): Logits to sample from.
108
+ top_k (int, optional): Number of top elements to consider in top-k sampling.
109
+ top_p (float, optional): Threshold probability for top-p sampling.
110
+ min_tokens_to_keep (int, optional): Minimum number of tokens to keep in the output.
111
+
112
+ Returns:
113
+ logits: Selected logits after top-k or top-p sampling. Sets all logits outside the selected ones to NEGATIVE_INFINITE_FLOAT.
114
+ """
115
+ assert top_k is None or top_p is None, "Can only specify one of top-k or top-p sampling."
116
+ min_tokens_to_keep = min(min_tokens_to_keep, logits.size(-1))
117
+ if top_k is not None:
118
+ if not isinstance(top_k, int) or top_k <= 0:
119
+ raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
120
+
121
+ # Top-k sampling
122
+ top_k = max(top_k, min_tokens_to_keep)
123
+ top_k = min(top_k, logits.size(-1))
124
+ top_k_logits, _ = torch.topk(logits, top_k)
125
+ indices_to_remove = logits < top_k_logits[..., -1:]
126
+ logits = torch.where(indices_to_remove, NEGATIVE_INFINITE_FLOAT, logits)
127
+
128
+ elif top_p is not None:
129
+ top_p = float(top_p)
130
+ if top_p < 0 or top_p > 1.0:
131
+ raise ValueError(f"`top_p` has to be a float between 0 and 1, but is {top_p}")
132
+
133
+ # Top-p sampling
134
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
135
+ sorted_probs = torch.softmax(sorted_logits, dim=-1)
136
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
137
+ sorted_indices_to_remove = cumulative_probs > top_p
138
+
139
+ # Remove tokens with cumulative probability above the threshold
140
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = False
141
+
142
+ # scatter sorted tensors to original indexing
143
+ indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
144
+ logits = torch.where(indices_to_remove, NEGATIVE_INFINITE_FLOAT, logits)
145
+
146
+ else:
147
+ # Return logits as is
148
+ pass
149
+
150
+ return logits
151
+
152
+
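+ # Illustration of the top-p branch above (hypothetical toy values, not part of the module's API).
+ # With logits [2.0, 1.0, 0.0, -1.0] the sorted softmax is roughly [0.64, 0.24, 0.09, 0.03], so the
+ # cumulative sums are roughly [0.64, 0.88, 0.97, 1.00]. With top_p=0.9 every entry whose cumulative
+ # probability exceeds 0.9 (here the last two) is masked to NEGATIVE_INFINITE_FLOAT; note that the
+ # first token to push the cumulative sum over top_p is masked as well, unless protected by
+ # min_tokens_to_keep.
+ #
+ #   toy_logits = torch.tensor([[2.0, 1.0, 0.0, -1.0]])
+ #   kept = select_logits(toy_logits, top_p=0.9)
+ #   F.softmax(kept, dim=-1)  # -> roughly [[0.73, 0.27, 0.00, 0.00]]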
153
+ class LayerNorm(nn.Module):
154
+ """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
155
+
156
+ def __init__(self, ndim, bias):
157
+ super().__init__()
158
+ self.weight = nn.Parameter(torch.ones(ndim))
159
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
160
+
161
+ def forward(self, input):
162
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
163
+
164
+ class LayerNormMinimal(nn.Module):
165
+ """LayerNorm like above, but without learnable parameters"""
166
+
167
+ def __init__(self, ndim, bias):  # bias is accepted for interface parity with LayerNorm but is unused
168
+ super().__init__()
169
+ self.ndim = (ndim,)
170
+
171
+ def forward(self, input):
172
+ return F.layer_norm(input, self.ndim, eps=1e-5)
173
+
174
+
175
+ class CausalSelfAttention(nn.Module):
176
+ def __init__(self, config):
177
+ super().__init__()
178
+ assert config.n_embd % config.n_head == 0
179
+ # key, query, value projections for all heads, but in a batch
180
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
181
+ # output projection
182
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
183
+ # regularization
184
+ self.attn_dropout = nn.Dropout(config.dropout)
185
+ self.resid_dropout = nn.Dropout(config.dropout)
186
+ self.n_head = config.n_head
187
+ self.n_embd = config.n_embd
188
+ self.dropout = config.dropout
189
+ # flash attention makes the GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
190
+ self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
191
+ # causal mask to ensure that attention is only applied to the left in the input sequence
192
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size), persistent=False)
193
+
194
+ self.cached_k = None
195
+ self.cached_v = None
196
+ self.current_cache_size = 0
197
+
198
+ def _manual_causal_attention(self, q, k, v, mask):
199
+ # q, k and v should be of shape (B, nh, T, hs)
200
+ token_len = q.size(-2)
201
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
202
+ att = att.masked_fill(mask[:, :, :token_len, :token_len] == 0, float("-inf"))
203
+ att = F.softmax(att, dim=-1)
204
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
205
+ return y
206
+
207
+ def forward(self, x, cache=False):
208
+ batch_size, token_len, n_embd = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
209
+
210
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
211
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
212
+ k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
213
+ q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
214
+ v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
215
+
216
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
217
+ if self.flash and not cache:
218
+ # efficient attention using Flash Attention CUDA kernels
219
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
220
+ elif cache:
221
+ # manual implementation of attention (as below), but cache arrays we can reuse
222
+ assert token_len == 1, "Cache only works for single step"
223
+ assert self.cached_k is not None, "Must call reset_cache() before using cache"
224
+ assert self.current_cache_size < self.cached_k.size(2), "Trying to generate more steps than provided in reset_cache() `num_steps_to_come`"
225
+ assert self.dropout == 0.0, "Dropout not supported with caching"
226
+ this_step_q = q
227
+ self.cached_k[:, :, self.current_cache_size, :] = k[:, :, 0, :]
228
+ self.cached_v[:, :, self.current_cache_size, :] = v[:, :, 0, :]
229
+ # Remove the zero parts
230
+ k = self.cached_k[:, :, : self.current_cache_size + 1, :]
231
+ # compute last row of the attention mask
232
+ this_step_att_row = (this_step_q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
233
+ this_step_att_row = F.softmax(this_step_att_row, dim=-1)
234
+ # We only need output for the current step
235
+ y = this_step_att_row @ self.cached_v[:, :, : self.current_cache_size + 1, :]
236
+ # Update cache
237
+ self.current_cache_size += 1
238
+ else:
239
+ y = self._manual_causal_attention(q, k, v, self.bias)
240
+ y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd) # re-assemble all head outputs side by side
241
+
242
+ # output projection
243
+ y = self.resid_dropout(self.c_proj(y))
244
+ return y
245
+
246
+ def reset_cache(self, x, num_steps_to_come):
247
+ """
248
+ Reset caches by doing initial pass with x data (returning same output as forward).
249
+ Also set the number of steps to come, which is used to initialize the buffers
250
+ """
251
+ batch_size, token_len, n_embd = x.size()
252
+
253
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
254
+ k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
255
+ q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
256
+ v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
257
+
258
+ # Use SDPA instead of a manual implementation
259
+ # y = self._manual_causal_attention(q, k, v, self.bias)
260
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
261
+
262
+ y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd)
263
+ # output projection
264
+ y = self.resid_dropout(self.c_proj(y))
265
+
266
+ # Create full k,q,v for predicting all future steps.
267
+ # Just null-out the last num_steps_to_come-1 steps
268
+ pad_size = num_steps_to_come
269
+ self.current_cache_size = token_len
270
+ self.cached_k = torch.cat([k, torch.zeros(batch_size, self.n_head, pad_size, n_embd // self.n_head, device=k.device)], dim=2)
271
+ self.cached_v = torch.cat([v, torch.zeros(batch_size, self.n_head, pad_size, n_embd // self.n_head, device=v.device)], dim=2)
272
+
273
+ return y
274
+
275
+ class SelfAttention(nn.Module):
276
+ """
277
+ Non-causal self-attention layer, the same as CausalSelfAttention but without the causal mask.
278
+ Duplicating the code to keep this separate for clarity.
279
+ """
280
+
281
+ def __init__(self, config):
282
+ super().__init__()
283
+ assert config.n_embd % config.n_head == 0
284
+ # key, query, value projections for all heads, but in a batch
285
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
286
+ # output projection
287
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
288
+ # regularization
289
+ self.attn_dropout = nn.Dropout(config.dropout)
290
+ self.resid_dropout = nn.Dropout(config.dropout)
291
+ self.n_head = config.n_head
292
+ self.n_embd = config.n_embd
293
+ self.dropout = config.dropout
294
+ # flash attention makes the GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
295
+ self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
296
+ assert self.flash, "SelfAttention only supports flash attention for now."
297
+
298
+ self.register_buffer("attn_mask", torch.ones((config.block_size, config.block_size)).bool().unsqueeze(0).unsqueeze(0))
299
+
300
+ def forward(self, x):
301
+ batch_size, token_len, n_embd = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
302
+
303
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
304
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
305
+ k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
306
+ q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
307
+ v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
308
+
309
+ # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
310
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout, is_causal=False)
311
+ y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd) # re-assemble all head outputs side by side
312
+
313
+ # output projection
314
+ y = self.resid_dropout(self.c_proj(y))
315
+ return y
316
+
317
+ class MLP(nn.Module):
318
+ def __init__(self, config):
319
+ super().__init__()
320
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
321
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
322
+ self.dropout = nn.Dropout(config.dropout)
323
+
324
+ def forward(self, x):
325
+ x = self.c_fc(x)
326
+ x = new_gelu(x)
327
+ x = self.c_proj(x)
328
+ x = self.dropout(x)
329
+ return x
330
+
331
+ class GELU_MLP(nn.Module):
332
+ """MLP Block using PyTorch's native GELU activation function"""
333
+ def __init__(self, config):
334
+ super().__init__()
335
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
336
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
337
+ self.dropout = nn.Dropout(config.dropout)
338
+
339
+ def forward(self, x):
340
+ x = self.c_fc(x)
341
+ x = F.gelu(x, approximate="tanh")
342
+ x = self.c_proj(x)
343
+ x = self.dropout(x)
344
+ return x
345
+
346
+
347
+ class Block(nn.Module):
348
+ def __init__(self, config):
349
+ super().__init__()
350
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
351
+ self.attn = CausalSelfAttention(config)
352
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
353
+ self.mlp = MLP(config)
354
+
355
+ def forward(self, x, cache=False, reset_cache_with_num_steps_to_come=None):
356
+ """
357
+ Args:
358
+ cache: If True, use the cache to predict the next token (assumes model was initialized with `reset_cache`).
359
+ reset_cache_with_num_steps_to_come:
360
+ If not None, reset and prepare the cache for cached prediction of the next `reset_cache_with_num_steps_to_come` tokens.
361
+ This is same as calling `reset_cache` with the same argument, but we include option here in `forward` to support torch hook functions (used to get embeddings from this module output).
362
+
363
+ Caching example:
364
+ ```
365
+ # Initialize model with reset_cache_with_num_steps_to_come=10
366
+ outputs[0] = model(inputs, reset_cache_with_num_steps_to_come=10)
367
+ # Predict next 10 tokens using cache
368
+ for i in range(10):
369
+ outputs[i+1] = model(next_token_embedding, cache=True)  # pass only the newest token; the cache path expects token_len == 1
370
+ ```
371
+ """
372
+ if reset_cache_with_num_steps_to_come:
373
+ return self.reset_cache(x, num_steps_to_come=reset_cache_with_num_steps_to_come)
374
+ x = x + self.attn(self.ln_1(x), cache=cache)
375
+ x = x + self.mlp(self.ln_2(x))
376
+ return x
377
+
378
+ def reset_cache(self, x, num_steps_to_come):
379
+ x = x + self.attn.reset_cache(self.ln_1(x), num_steps_to_come=num_steps_to_come)
380
+ x = x + self.mlp(self.ln_2(x))
381
+ return x
382
+
383
+ class BlockV2(nn.Module):
384
+ """
385
+ Compared to the Block in the original implementation, this one uses non-parametric LayerNorm and Pytorch's GELU.
386
+ These two changes save significant vram but are incompatible with previously trained models.
387
+ Hence the separate class.
388
+ """
389
+
390
+ def __init__(self, config):
391
+ super().__init__()
392
+ self.ln_1 = LayerNormMinimal(config.n_embd, bias=config.bias)
393
+ self.attn = CausalSelfAttention(config)
394
+ self.ln_2 = LayerNormMinimal(config.n_embd, bias=config.bias)
395
+ self.mlp = GELU_MLP(config)
396
+
397
+ def forward(self, x, cache=False, reset_cache_with_num_steps_to_come=None):
398
+ if reset_cache_with_num_steps_to_come:
399
+ return self.reset_cache(x, num_steps_to_come=reset_cache_with_num_steps_to_come)
400
+ x = x + self.attn(self.ln_1(x), cache=cache)
401
+ x = x + self.mlp(self.ln_2(x))
402
+ return x
403
+
404
+ def reset_cache(self, x, num_steps_to_come):
405
+ x = x + self.attn.reset_cache(self.ln_1(x), num_steps_to_come=num_steps_to_come)
406
+ x = x + self.mlp(self.ln_2(x))
407
+ return x
408
+
409
+ class SelfAttentionBlock(nn.Module):
410
+ def __init__(self, config):
411
+ super().__init__()
412
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
413
+ self.attn = SelfAttention(config)
414
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
415
+ self.mlp = MLP(config)
416
+
417
+ def forward(self, x):
418
+ x = x + self.attn(self.ln_1(x))
419
+ x = x + self.mlp(self.ln_2(x))
420
+ return x
421
+
422
+ @dataclass
423
+ class GPTConfig:
424
+ block_size: int = 1024
425
+ vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
426
+ n_layer: int = 12
427
+ n_head: int = 12
428
+ n_embd: int = 768
429
+ dropout: float = 0.0
430
+ bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
431
+ version: int = 1 # Version 1 is the original GPT, Version 2 is the one with non-parametric LayerNorm and Pytorch's GELU
432
+
433
+
434
+ class GPT(nn.Module):
435
+ def __init__(self, config):
436
+ super().__init__()
437
+ assert config.vocab_size is not None
438
+ assert config.block_size is not None
439
+ self.config = config
440
+
441
+ self.version = config.version
442
+
443
+ print(f"[nanoGPT] creating model with version {self.version}")
444
+
445
+ if self.version == 1:
446
+ transformer_dict = dict(
447
+ wpe=nn.Embedding(config.block_size, config.n_embd),
448
+ drop=nn.Dropout(config.dropout),
449
+ h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
450
+ ln_f=LayerNorm(config.n_embd, bias=config.bias),
451
+ )
452
+ elif self.version == 2:
453
+ transformer_dict = dict(
454
+ wpe=nn.Embedding(config.block_size, config.n_embd),
455
+ drop=nn.Dropout(config.dropout),
456
+ h=nn.ModuleList([BlockV2(config) for _ in range(config.n_layer)]),
457
+ ln_f=LayerNorm(config.n_embd, bias=config.bias), # This one is still parametric due to user error
458
+ )
459
+
460
+ transformer_dict["wte"] = nn.Embedding(config.vocab_size, config.n_embd)
461
+ self.transformer = nn.ModuleDict(transformer_dict)
462
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
463
+ # with weight tying when using torch.compile() some warnings get generated:
464
+ # "UserWarning: functional_call was passed multiple values for tied weights.
465
+ # This behavior is deprecated and will be an error in future versions"
466
+ # not 100% sure what this is, so far seems to be harmless.
467
+ self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
468
+
469
+ # init all weights
470
+ self.apply(self._init_weights)
471
+ # apply special scaled init to the residual projections, per GPT-2 paper
472
+ for pn, p in self.named_parameters():
473
+ if pn.endswith("c_proj.weight"):
474
+ torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
475
+
476
+ def get_num_params(self, non_embedding=True):
477
+ """
478
+ Return the number of parameters in the model.
479
+ For non-embedding count (default), the position embeddings get subtracted.
480
+ The token embeddings would too, except due to the parameter sharing these
481
+ params are actually used as weights in the final layer, so we include them.
482
+ """
483
+ n_params = sum(p.numel() for p in self.parameters())
484
+ if non_embedding:
485
+ n_params -= self.transformer.wpe.weight.numel()
486
+ return n_params
487
+
488
+ def _init_weights(self, module):
489
+ if isinstance(module, nn.Linear):
490
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
491
+ if module.bias is not None:
492
+ torch.nn.init.zeros_(module.bias)
493
+ elif isinstance(module, nn.Embedding):
494
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
495
+
496
+ def _apply_pos_encoding(self, x):
497
+ device = x.device
498
+ token_len = x.size(1)
499
+ pos = torch.arange(0, token_len, dtype=torch.long, device=device).unsqueeze(0)
500
+ pos_emb = self.transformer.wpe(pos)
501
+ x = x + pos_emb
502
+ return x
503
+
504
+ def original_forward(self, idx, targets=None, loss_mask=None, loss_reduction="mean"):
505
+ batch_size, seq_len = idx.shape[:2]
506
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
507
+ x = self.transformer.drop(self._apply_pos_encoding(tok_emb))
508
+ for block in self.transformer.h:
509
+ x = block(x)
510
+ x = self.transformer.ln_f(x)
511
+
512
+ if targets is not None:
513
+ # if we are given some desired targets also calculate the loss
514
+ logits = self.lm_head(x)
515
+ if loss_mask is not None:
516
+ # Feeding target = CROSS_ENTROPY_INVALID_CLASS_TARGET to cross_entropy will ignore the loss
517
+ # for that position. This is useful for padding tokens.
518
+ targets[loss_mask == 0] = CROSS_ENTROPY_INVALID_CLASS_TARGET
519
+ loss = F.cross_entropy(
520
+ logits.view(batch_size * seq_len, self.config.vocab_size), targets.view(-1), ignore_index=CROSS_ENTROPY_INVALID_CLASS_TARGET, reduction=loss_reduction
521
+ )
522
+ if loss_reduction == "none":
523
+ # Reshape back into batch_size and seq_len
524
+ loss = loss.view(batch_size, seq_len)
525
+ else:
526
+ # inference-time mini-optimization: only forward the lm_head on the very last position
527
+ logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
528
+ loss = None
529
+
530
+ return logits, loss
531
+
532
+ def forward(self, x, targets=None, loss_mask=None, loss_reduction="mean"):
533
+ token_len = x.size(1)
534
+ assert token_len <= self.config.block_size, f"Cannot forward sequence of length {token_len}, block size is only {self.config.block_size}"
535
+ return self.original_forward(x, targets, loss_mask, loss_reduction)
536
+
537
+ @torch.no_grad()
538
+ def generate(self, idx, max_new_tokens, valid_token_range=None, temperature=1.0, top_k=None, raise_cropping=False, deterministic=False):
539
+ """
540
+ valid_token_range should be a tuple, specifying start and end indices we'd like to sample from (inclusive).
541
+ if None, we'll sample from the full vocab.
542
+
543
+ If raise_cropping is True, we'll raise an error if we need to crop the sequence context.
544
+ """
545
+ if valid_token_range is None:
546
+ valid_token_range = (0, self.config.vocab_size - 1)
547
+ assert len(valid_token_range) == 2
548
+ assert valid_token_range[0] < valid_token_range[1]
549
+ for _ in range(max_new_tokens):
550
+ # if the sequence context is growing too long we must crop it at block_size
551
+ idx_cond = idx
552
+ if idx.size(1) > self.config.block_size:
553
+ if raise_cropping:
554
+ raise ValueError("Tried to crop idxs but flag told to raise this")
555
+ else:
556
+ idx_cond = idx[:, -self.config.block_size :]
557
+ # forward the model to get the logits for the index in the sequence
558
+ logits, _ = self(idx_cond)
559
+ # pluck the logits at the final step and scale by desired temperature
560
+ logits = logits[:, -1, :] / temperature # logits is B T Vocabsize -> B Vocabsize
561
+ # optionally crop the logits to only the top k options
562
+ if top_k is not None:
563
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
564
+ logits[logits < v[:, [-1]]] = NEGATIVE_INFINITE_FLOAT
565
+
566
+ # Crop out the logits we don't want to sample from
567
+ if valid_token_range is not None:
568
+ limit_logits_to_valid_range(logits, valid_token_range)
569
+
570
+ # apply softmax to convert logits to (normalized) probabilities
571
+ probs = F.softmax(logits, dim=-1)
572
+
573
+ if deterministic:
574
+ # Take max of the results
575
+ idx_next = torch.argmax(probs, dim=-1, keepdim=True)
576
+ else:
577
+ # sample from the distribution
578
+ idx_next = torch.multinomial(probs, num_samples=1)
579
+ # append sampled index to the running sequence and continue
580
+ idx = torch.cat((idx, idx_next), dim=1)
581
+
582
+ return idx
583
+
584
+ @torch.no_grad()
585
+ def optimized_generate(
586
+ self,
587
+ idx,
588
+ num_new_tokens,
589
+ valid_token_ranges=None,
590
+ temperature=1.0,
591
+ deterministic=False,
592
+ raise_cropping=False,
593
+ top_k=None,
594
+ top_p=None,
595
+ min_tokens_to_keep=1,
596
+ ):
597
+ """
598
+ Generate function but optimized by caching the results in transformer blocks (think this is referred to as "attention caching").
599
+ The higher the num_new_tokens, the more the speedup compared to original generate.
600
+
601
+ Caveat: the context length + num_new_tokens must be less than the block size. This means that the first
602
+ generated tokens do not have full context length.
603
+
604
+ valid_token_ranges should be None or list of length num_new_tokens, specifying valid range for tokens for every step
605
+ """
606
+ # Local aliases for the output head and token embedder (candidates for compilation/quantization for further speed).
607
+ logit_layer = self.lm_head
608
+ embedder_fn = self.transformer.wte
609
+
610
+ if valid_token_ranges is None:
611
+ valid_token_ranges = [[0, self.config.vocab_size] for _ in range(num_new_tokens)]
612
+ assert len(valid_token_ranges) == num_new_tokens, "valid_token_ranges should be list of length num_new_tokens or None"
613
+
614
+ _, token_len = idx.size()
615
+ if token_len + num_new_tokens > self.config.block_size:
616
+ raise ValueError("Can't use optimized generation with num_new_tokens + context_length > block_size")
617
+ new_idxs = torch.zeros(idx.size(0), num_new_tokens, dtype=torch.long, device=idx.device)
618
+ # First, we need to cull the sequence to the block size
619
+ # and remove first max_new_tokens so we can reuse same position embeddings
620
+ # and not have to recompute them
621
+ num_original_tokens = idx.size(1)
622
+ original_idx = idx
623
+ if (num_original_tokens + num_new_tokens) > self.config.block_size:
624
+ if raise_cropping:
625
+ raise ValueError("Tried to crop idxs but flag told to raise this")
626
+ original_idx = idx[:, -self.config.block_size + num_new_tokens :]
627
+ original_pos = torch.arange(0, original_idx.size(1), dtype=torch.long, device=idx.device).unsqueeze(0)
628
+ # Now cache results with the original context
629
+ original_tok_emb = embedder_fn(original_idx)
630
+ original_pos_emb = self.transformer.wpe(original_pos)
631
+ original_x = original_tok_emb + original_pos_emb
632
+ for block in self.transformer.h:
633
+ # Reset the cache for each block, and cache new result
634
+ original_x = block(original_x, reset_cache_with_num_steps_to_come=num_new_tokens)
635
+
636
+ # Sample the first token
637
+ original_x = self.transformer.ln_f(original_x)
638
+ last_logit = logit_layer(original_x[:, [-1], :])
639
+ new_idxs[:, 0] = default_sample_token(
640
+ last_logit[:, -1, :], valid_token_ranges[0], temperature, deterministic, top_k=top_k, top_p=top_p, min_tokens_to_keep=min_tokens_to_keep
641
+ )
642
+
643
+ # Generate rest of the steps
644
+ for generation_idx in range(1, num_new_tokens):
645
+ # forward the model to get the logits for the index in the sequence
646
+ # This is the position of the latest generated token, not the currently going-to-be-generated token
647
+ latest_token_pos = num_original_tokens + generation_idx - 1
648
+ # We only need to pass in the latest token
649
+ newest_idx = new_idxs[:, generation_idx - 1].unsqueeze(-1)
650
+ newest_tok_emb = embedder_fn(newest_idx)
651
+ newest_pos_emb = self.transformer.wpe(torch.tensor(latest_token_pos, dtype=torch.long, device=idx.device).unsqueeze(0))
652
+ newest_x = newest_tok_emb + newest_pos_emb
653
+ for block in self.transformer.h:
654
+ newest_x = block(newest_x, cache=True)
655
+
656
+ newest_x = self.transformer.ln_f(newest_x)
657
+ newest_logit = logit_layer(newest_x)
658
+ # TODO: check this sampling call isn't slowing things down noticeably
659
+ new_idxs[:, generation_idx] = default_sample_token(
660
+ newest_logit[:, -1, :], valid_token_ranges[generation_idx], temperature, deterministic, top_k=top_k, top_p=top_p, min_tokens_to_keep=min_tokens_to_keep
661
+ )
662
+
663
+ # Combine indices
664
+ new_idxs = torch.cat((idx, new_idxs), dim=1)
665
+ return new_idxs
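
As a quick orientation for the classes above, a minimal (hypothetical) way to instantiate the model and generate tokens would be the following sketch; the config values are arbitrary toy numbers, and the shipped WHAM checkpoints define their own configurations.

import torch

config = GPTConfig(block_size=128, vocab_size=512, n_layer=2, n_head=2, n_embd=64, dropout=0.0, version=2)
model = GPT(config).eval()
context = torch.randint(0, config.vocab_size, (1, 16))  # (batch, tokens)

# Plain generation: re-runs the full context for every new token.
tokens = model.generate(context, max_new_tokens=8, temperature=1.0)

# Cached generation: faster, but context length + new tokens must fit within block_size.
tokens_fast = model.optimized_generate(context, num_new_tokens=8, temperature=1.0)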
wham/models/pl/__init__.py ADDED
File without changes
wham/models/pl/pl_base_model.py ADDED
@@ -0,0 +1,5 @@
1
+ import pytorch_lightning as pl
2
+
3
+ class BaseTrainingModel(pl.LightningModule):
4
+ def __init__(self, **kwargs):
5
+ super().__init__(**kwargs)
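
For context, training modules presumably subclass this base class; a minimal (hypothetical) subclass following the standard PyTorch Lightning pattern would look like:

import torch
import torch.nn.functional as F

class ToyTrainingModel(BaseTrainingModel):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        loss = F.mse_loss(self.layer(inputs), targets)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)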
wham/models/vqgan/taming/LICENSE ADDED
@@ -0,0 +1,24 @@
1
+ All files under this directory are originally from the taming-transformers repository:
2
+ https://github.com/CompVis/taming-transformers
3
+
4
+ Below is a copy of the original license
5
+ ------------------------------------------------------------------------------
6
+ Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
19
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
21
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
22
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
23
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
24
+ OR OTHER DEALINGS IN THE SOFTWARE.
wham/models/vqgan/taming/model.py ADDED
@@ -0,0 +1,696 @@
1
+ # All files under this directory are originally from the taming-transformers repository:
2
+ # https://github.com/CompVis/taming-transformers
3
+
4
+ # MIT License
5
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
6
+ # 2023 Microsoft Research
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
19
+ # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20
+ # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
21
+ # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
22
+ # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
23
+ # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
24
+ # OR OTHER DEALINGS IN THE SOFTWARE.
25
+
26
+ import math
27
+ import torch
28
+ import torch.nn as nn
29
+ import numpy as np
30
+
31
+
32
+ def get_timestep_embedding(timesteps, embedding_dim):
33
+ """
34
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
35
+ From Fairseq.
36
+ Build sinusoidal embeddings.
37
+ This matches the implementation in tensor2tensor, but differs slightly
38
+ from the description in Section 3.5 of "Attention Is All You Need".
39
+ """
40
+ assert len(timesteps.shape) == 1
41
+
42
+ half_dim = embedding_dim // 2
43
+ emb = math.log(10000) / (half_dim - 1)
44
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
45
+ emb = emb.to(device=timesteps.device)
46
+ emb = timesteps.float()[:, None] * emb[None, :]
47
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
48
+ if embedding_dim % 2 == 1: # zero pad
49
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
50
+ return emb
51
+
52
+
53
+ def nonlinearity(x):
54
+ # swish
55
+ return x * torch.sigmoid(x)
56
+
57
+
58
+ def Normalize(in_channels):
59
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
60
+
61
+
62
+ class Upsample(nn.Module):
63
+ def __init__(self, in_channels, with_conv):
64
+ super().__init__()
65
+ self.with_conv = with_conv
66
+ if self.with_conv:
67
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
68
+
69
+ def forward(self, x):
70
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
71
+ if self.with_conv:
72
+ x = self.conv(x)
73
+ return x
74
+
75
+
76
+ class Downsample(nn.Module):
77
+ def __init__(self, in_channels, with_conv):
78
+ super().__init__()
79
+ self.with_conv = with_conv
80
+ if self.with_conv:
81
+ # no asymmetric padding in torch conv, must do it ourselves
82
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
83
+
84
+ def forward(self, x):
85
+ if self.with_conv:
86
+ pad = (0, 1, 0, 1)
87
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
88
+ x = self.conv(x)
89
+ else:
90
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
91
+ return x
92
+
93
+
94
+ class ResnetBlock(nn.Module):
95
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
96
+ super().__init__()
97
+ self.in_channels = in_channels
98
+ out_channels = in_channels if out_channels is None else out_channels
99
+ self.out_channels = out_channels
100
+ self.use_conv_shortcut = conv_shortcut
101
+
102
+ self.norm1 = Normalize(in_channels)
103
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
104
+ if temb_channels > 0:
105
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
106
+ self.norm2 = Normalize(out_channels)
107
+ self.dropout = torch.nn.Dropout(dropout)
108
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
109
+ if self.in_channels != self.out_channels:
110
+ if self.use_conv_shortcut:
111
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
112
+ else:
113
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
114
+
115
+ def forward(self, x, temb):
116
+ h = x
117
+ h = self.norm1(h)
118
+ h = nonlinearity(h)
119
+ h = self.conv1(h)
120
+
121
+ if temb is not None:
122
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
123
+
124
+ h = self.norm2(h)
125
+ h = nonlinearity(h)
126
+ h = self.dropout(h)
127
+ h = self.conv2(h)
128
+
129
+ if self.in_channels != self.out_channels:
130
+ if self.use_conv_shortcut:
131
+ x = self.conv_shortcut(x)
132
+ else:
133
+ x = self.nin_shortcut(x)
134
+
135
+ return x + h
136
+
137
+
138
+ class AttnBlock(nn.Module):
139
+ def __init__(self, in_channels):
140
+ super().__init__()
141
+ self.in_channels = in_channels
142
+
143
+ self.norm = Normalize(in_channels)
144
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
145
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
146
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
147
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
148
+
149
+ def forward(self, x):
150
+ h_ = x
151
+ h_ = self.norm(h_)
152
+ q = self.q(h_)
153
+ k = self.k(h_)
154
+ v = self.v(h_)
155
+
156
+ # compute attention
157
+ b, c, h, w = q.shape
158
+ q = q.reshape(b, c, h * w)
159
+ q = q.permute(0, 2, 1) # b,hw,c
160
+ k = k.reshape(b, c, h * w) # b,c,hw
161
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
162
+ w_ = w_ * (int(c) ** (-0.5))
163
+ w_ = torch.nn.functional.softmax(w_, dim=2)
164
+
165
+ # attend to values
166
+ v = v.reshape(b, c, h * w)
167
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
168
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
169
+ h_ = h_.reshape(b, c, h, w)
170
+
171
+ h_ = self.proj_out(h_)
172
+
173
+ return x + h_
174
+
175
+
176
+ class Model(nn.Module):
177
+ def __init__(
178
+ self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True
179
+ ):
180
+ super().__init__()
181
+ self.ch = ch
182
+ self.temb_ch = self.ch * 4
183
+ self.num_resolutions = len(ch_mult)
184
+ self.num_res_blocks = num_res_blocks
185
+ self.resolution = resolution
186
+ self.in_channels = in_channels
187
+
188
+ self.use_timestep = use_timestep
189
+ if self.use_timestep:
190
+ # timestep embedding
191
+ self.temb = nn.Module()
192
+ self.temb.dense = nn.ModuleList(
193
+ [
194
+ torch.nn.Linear(self.ch, self.temb_ch),
195
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
196
+ ]
197
+ )
198
+
199
+ # downsampling
200
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
201
+
202
+ curr_res = resolution
203
+ in_ch_mult = (1,) + tuple(ch_mult)
204
+ self.down = nn.ModuleList()
205
+ for i_level in range(self.num_resolutions):
206
+ block = nn.ModuleList()
207
+ attn = nn.ModuleList()
208
+ block_in = ch * in_ch_mult[i_level]
209
+ block_out = ch * ch_mult[i_level]
210
+ for i_block in range(self.num_res_blocks):
211
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
212
+ block_in = block_out
213
+ if curr_res in attn_resolutions:
214
+ attn.append(AttnBlock(block_in))
215
+ down = nn.Module()
216
+ down.block = block
217
+ down.attn = attn
218
+ if i_level != self.num_resolutions - 1:
219
+ down.downsample = Downsample(block_in, resamp_with_conv)
220
+ curr_res = curr_res // 2
221
+ self.down.append(down)
222
+
223
+ # middle
224
+ self.mid = nn.Module()
225
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
226
+ self.mid.attn_1 = AttnBlock(block_in)
227
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
228
+
229
+ # upsampling
230
+ self.up = nn.ModuleList()
231
+ for i_level in reversed(range(self.num_resolutions)):
232
+ block = nn.ModuleList()
233
+ attn = nn.ModuleList()
234
+ block_out = ch * ch_mult[i_level]
235
+ skip_in = ch * ch_mult[i_level]
236
+ for i_block in range(self.num_res_blocks + 1):
237
+ if i_block == self.num_res_blocks:
238
+ skip_in = ch * in_ch_mult[i_level]
239
+ block.append(ResnetBlock(in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
240
+ block_in = block_out
241
+ if curr_res in attn_resolutions:
242
+ attn.append(AttnBlock(block_in))
243
+ up = nn.Module()
244
+ up.block = block
245
+ up.attn = attn
246
+ if i_level != 0:
247
+ up.upsample = Upsample(block_in, resamp_with_conv)
248
+ curr_res = curr_res * 2
249
+ self.up.insert(0, up) # prepend to get consistent order
250
+
251
+ # end
252
+ self.norm_out = Normalize(block_in)
253
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
254
+
255
+ def forward(self, x, t=None):
256
+ # assert x.shape[2] == x.shape[3] == self.resolution
257
+
258
+ if self.use_timestep:
259
+ # timestep embedding
260
+ assert t is not None
261
+ temb = get_timestep_embedding(t, self.ch)
262
+ temb = self.temb.dense[0](temb)
263
+ temb = nonlinearity(temb)
264
+ temb = self.temb.dense[1](temb)
265
+ else:
266
+ temb = None
267
+
268
+ # downsampling
269
+ hs = [self.conv_in(x)]
270
+ for i_level in range(self.num_resolutions):
271
+ for i_block in range(self.num_res_blocks):
272
+ h = self.down[i_level].block[i_block](hs[-1], temb)
273
+ if len(self.down[i_level].attn) > 0:
274
+ h = self.down[i_level].attn[i_block](h)
275
+ hs.append(h)
276
+ if i_level != self.num_resolutions - 1:
277
+ hs.append(self.down[i_level].downsample(hs[-1]))
278
+
279
+ # middle
280
+ h = hs[-1]
281
+ h = self.mid.block_1(h, temb)
282
+ h = self.mid.attn_1(h)
283
+ h = self.mid.block_2(h, temb)
284
+
285
+ # upsampling
286
+ for i_level in reversed(range(self.num_resolutions)):
287
+ for i_block in range(self.num_res_blocks + 1):
288
+ h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
289
+ if len(self.up[i_level].attn) > 0:
290
+ h = self.up[i_level].attn[i_block](h)
291
+ if i_level != 0:
292
+ h = self.up[i_level].upsample(h)
293
+
294
+ # end
295
+ h = self.norm_out(h)
296
+ h = nonlinearity(h)
297
+ h = self.conv_out(h)
298
+ return h
299
+
300
+
301
+ class Encoder(nn.Module):
302
+ def __init__(
303
+ self,
304
+ *,
305
+ ch,
306
+ out_ch,
307
+ ch_mult=(1, 2, 4, 8),
308
+ num_res_blocks,
309
+ attn_resolutions,
310
+ dropout=0.0,
311
+ resamp_with_conv=True,
312
+ in_channels,
313
+ resolution,
314
+ z_channels,
315
+ double_z=True,
316
+ **ignore_kwargs
317
+ ):
318
+ super().__init__()
319
+ self.ch = ch
320
+ self.temb_ch = 0
321
+ self.num_resolutions = len(ch_mult)
322
+ self.num_res_blocks = num_res_blocks
323
+ self.resolution = resolution
324
+ self.in_channels = in_channels
325
+
326
+ # downsampling
327
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
328
+
329
+ curr_res = resolution
330
+ in_ch_mult = (1,) + tuple(ch_mult)
331
+ self.down = nn.ModuleList()
332
+ for i_level in range(self.num_resolutions):
333
+ block = nn.ModuleList()
334
+ attn = nn.ModuleList()
335
+ block_in = ch * in_ch_mult[i_level]
336
+ block_out = ch * ch_mult[i_level]
337
+ for i_block in range(self.num_res_blocks):
338
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
339
+ block_in = block_out
340
+ if curr_res in attn_resolutions:
341
+ attn.append(AttnBlock(block_in))
342
+ down = nn.Module()
343
+ down.block = block
344
+ down.attn = attn
345
+ if i_level != self.num_resolutions - 1:
346
+ down.downsample = Downsample(block_in, resamp_with_conv)
347
+ curr_res = curr_res // 2
348
+ self.down.append(down)
349
+
350
+ # middle
351
+ self.mid = nn.Module()
352
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
353
+ self.mid.attn_1 = AttnBlock(block_in)
354
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
355
+
356
+ # end
357
+ self.norm_out = Normalize(block_in)
358
+ self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1)
359
+
360
+ def forward(self, x):
361
+ # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
362
+
363
+ # timestep embedding
364
+ temb = None
365
+
366
+ # downsampling
367
+ hs = [self.conv_in(x)]
368
+ for i_level in range(self.num_resolutions):
369
+ for i_block in range(self.num_res_blocks):
370
+ h = self.down[i_level].block[i_block](hs[-1], temb)
371
+ if len(self.down[i_level].attn) > 0:
372
+ h = self.down[i_level].attn[i_block](h)
373
+ hs.append(h)
374
+ if i_level != self.num_resolutions - 1:
375
+ hs.append(self.down[i_level].downsample(hs[-1]))
376
+
377
+ # middle
378
+ h = hs[-1]
379
+ h = self.mid.block_1(h, temb)
380
+ h = self.mid.attn_1(h)
381
+ h = self.mid.block_2(h, temb)
382
+
383
+ # end
384
+ h = self.norm_out(h)
385
+ h = nonlinearity(h)
386
+ h = self.conv_out(h)
387
+ return h
388
+
389
+
390
+ class Decoder(nn.Module):
391
+ def __init__(
392
+ self,
393
+ *,
394
+ ch,
395
+ out_ch,
396
+ ch_mult=(1, 2, 4, 8),
397
+ num_res_blocks,
398
+ attn_resolutions,
399
+ dropout=0.0,
400
+ resamp_with_conv=True,
401
+ in_channels,
402
+ resolution,
403
+ z_channels,
404
+ give_pre_end=False,
405
+ **ignorekwargs
406
+ ):
407
+ super().__init__()
408
+ self.ch = ch
409
+ self.temb_ch = 0
410
+ self.num_resolutions = len(ch_mult)
411
+ self.num_res_blocks = num_res_blocks
412
+ self.resolution = resolution
413
+ self.in_channels = in_channels
414
+ self.give_pre_end = give_pre_end
415
+
416
+ # compute in_ch_mult, block_in and curr_res at lowest res
417
+ in_ch_mult = (1,) + tuple(ch_mult)
418
+ block_in = ch * ch_mult[self.num_resolutions - 1]
419
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
420
+ self.z_shape = (1, z_channels, curr_res, curr_res)
421
+
422
+ # z to block_in
423
+ self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
424
+
425
+ # middle
426
+ self.mid = nn.Module()
427
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
428
+ self.mid.attn_1 = AttnBlock(block_in)
429
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
430
+
431
+ # upsampling
432
+ self.up = nn.ModuleList()
433
+ for i_level in reversed(range(self.num_resolutions)):
434
+ block = nn.ModuleList()
435
+ attn = nn.ModuleList()
436
+ block_out = ch * ch_mult[i_level]
437
+ for i_block in range(self.num_res_blocks + 1):
438
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
439
+ block_in = block_out
440
+ if curr_res in attn_resolutions:
441
+ attn.append(AttnBlock(block_in))
442
+ up = nn.Module()
443
+ up.block = block
444
+ up.attn = attn
445
+ if i_level != 0:
446
+ up.upsample = Upsample(block_in, resamp_with_conv)
447
+ curr_res = curr_res * 2
448
+ self.up.insert(0, up) # prepend to get consistent order
449
+
450
+ # end
451
+ self.norm_out = Normalize(block_in)
452
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
453
+
454
+ def forward(self, z):
455
+ # assert z.shape[1:] == self.z_shape[1:]
456
+ self.last_z_shape = z.shape
457
+
458
+ # timestep embedding
459
+ temb = None
460
+
461
+ # z to block_in
462
+ h = self.conv_in(z)
463
+
464
+ # middle
465
+ h = self.mid.block_1(h, temb)
466
+ h = self.mid.attn_1(h)
467
+ h = self.mid.block_2(h, temb)
468
+
469
+ # upsampling
470
+ for i_level in reversed(range(self.num_resolutions)):
471
+ for i_block in range(self.num_res_blocks + 1):
472
+ h = self.up[i_level].block[i_block](h, temb)
473
+ if len(self.up[i_level].attn) > 0:
474
+ h = self.up[i_level].attn[i_block](h)
475
+ if i_level != 0:
476
+ h = self.up[i_level].upsample(h)
477
+
478
+ # end
479
+ if self.give_pre_end:
480
+ return h
481
+
482
+ h = self.norm_out(h)
483
+ h = nonlinearity(h)
484
+ h = self.conv_out(h)
485
+ return h
486
+
487
+
488
+ class VUNet(nn.Module):
489
+ def __init__(
490
+ self,
491
+ *,
492
+ ch,
493
+ out_ch,
494
+ ch_mult=(1, 2, 4, 8),
495
+ num_res_blocks,
496
+ attn_resolutions,
497
+ dropout=0.0,
498
+ resamp_with_conv=True,
499
+ in_channels,
500
+ c_channels,
501
+ resolution,
502
+ z_channels,
503
+ use_timestep=False,
504
+ **ignore_kwargs
505
+ ):
506
+ super().__init__()
507
+ self.ch = ch
508
+ self.temb_ch = self.ch * 4
509
+ self.num_resolutions = len(ch_mult)
510
+ self.num_res_blocks = num_res_blocks
511
+ self.resolution = resolution
512
+
513
+ self.use_timestep = use_timestep
514
+ if self.use_timestep:
515
+ # timestep embedding
516
+ self.temb = nn.Module()
517
+ self.temb.dense = nn.ModuleList(
518
+ [
519
+ torch.nn.Linear(self.ch, self.temb_ch),
520
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
521
+ ]
522
+ )
523
+
524
+ # downsampling
525
+ self.conv_in = torch.nn.Conv2d(c_channels, self.ch, kernel_size=3, stride=1, padding=1)
526
+
527
+ curr_res = resolution
528
+ in_ch_mult = (1,) + tuple(ch_mult)
529
+ self.down = nn.ModuleList()
530
+ for i_level in range(self.num_resolutions):
531
+ block = nn.ModuleList()
532
+ attn = nn.ModuleList()
533
+ block_in = ch * in_ch_mult[i_level]
534
+ block_out = ch * ch_mult[i_level]
535
+ for i_block in range(self.num_res_blocks):
536
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
537
+ block_in = block_out
538
+ if curr_res in attn_resolutions:
539
+ attn.append(AttnBlock(block_in))
540
+ down = nn.Module()
541
+ down.block = block
542
+ down.attn = attn
543
+ if i_level != self.num_resolutions - 1:
544
+ down.downsample = Downsample(block_in, resamp_with_conv)
545
+ curr_res = curr_res // 2
546
+ self.down.append(down)
547
+
548
+ self.z_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=1, stride=1, padding=0)
549
+ # middle
550
+ self.mid = nn.Module()
551
+ self.mid.block_1 = ResnetBlock(in_channels=2 * block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
552
+ self.mid.attn_1 = AttnBlock(block_in)
553
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
554
+
555
+ # upsampling
556
+ self.up = nn.ModuleList()
557
+ for i_level in reversed(range(self.num_resolutions)):
558
+ block = nn.ModuleList()
559
+ attn = nn.ModuleList()
560
+ block_out = ch * ch_mult[i_level]
561
+ skip_in = ch * ch_mult[i_level]
562
+ for i_block in range(self.num_res_blocks + 1):
563
+ if i_block == self.num_res_blocks:
564
+ skip_in = ch * in_ch_mult[i_level]
565
+ block.append(ResnetBlock(in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
566
+ block_in = block_out
567
+ if curr_res in attn_resolutions:
568
+ attn.append(AttnBlock(block_in))
569
+ up = nn.Module()
570
+ up.block = block
571
+ up.attn = attn
572
+ if i_level != 0:
573
+ up.upsample = Upsample(block_in, resamp_with_conv)
574
+ curr_res = curr_res * 2
575
+ self.up.insert(0, up) # prepend to get consistent order
576
+
577
+ # end
578
+ self.norm_out = Normalize(block_in)
579
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
580
+
581
+ def forward(self, x, z, t=None):  # t is only used (and then required) when use_timestep=True
582
+ # assert x.shape[2] == x.shape[3] == self.resolution
583
+
584
+ if self.use_timestep:
585
+ # timestep embedding
586
+ assert t is not None
587
+ temb = get_timestep_embedding(t, self.ch)
588
+ temb = self.temb.dense[0](temb)
589
+ temb = nonlinearity(temb)
590
+ temb = self.temb.dense[1](temb)
591
+ else:
592
+ temb = None
593
+
594
+ # downsampling
595
+ hs = [self.conv_in(x)]
596
+ for i_level in range(self.num_resolutions):
597
+ for i_block in range(self.num_res_blocks):
598
+ h = self.down[i_level].block[i_block](hs[-1], temb)
599
+ if len(self.down[i_level].attn) > 0:
600
+ h = self.down[i_level].attn[i_block](h)
601
+ hs.append(h)
602
+ if i_level != self.num_resolutions - 1:
603
+ hs.append(self.down[i_level].downsample(hs[-1]))
604
+
605
+ # middle
606
+ h = hs[-1]
607
+ z = self.z_in(z)
608
+ h = torch.cat((h, z), dim=1)
609
+ h = self.mid.block_1(h, temb)
610
+ h = self.mid.attn_1(h)
611
+ h = self.mid.block_2(h, temb)
612
+
613
+ # upsampling
614
+ for i_level in reversed(range(self.num_resolutions)):
615
+ for i_block in range(self.num_res_blocks + 1):
616
+ h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
617
+ if len(self.up[i_level].attn) > 0:
618
+ h = self.up[i_level].attn[i_block](h)
619
+ if i_level != 0:
620
+ h = self.up[i_level].upsample(h)
621
+
622
+ # end
623
+ h = self.norm_out(h)
624
+ h = nonlinearity(h)
625
+ h = self.conv_out(h)
626
+ return h
627
+
628
+
629
+ class SimpleDecoder(nn.Module):
630
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
631
+ super().__init__()
632
+ self.model = nn.ModuleList(
633
+ [
634
+ nn.Conv2d(in_channels, in_channels, 1),
635
+ ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
636
+ ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0),
637
+ ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
638
+ nn.Conv2d(2 * in_channels, in_channels, 1),
639
+ Upsample(in_channels, with_conv=True),
640
+ ]
641
+ )
642
+ # end
643
+ self.norm_out = Normalize(in_channels)
644
+ self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
645
+
646
+ def forward(self, x):
647
+ for i, layer in enumerate(self.model):
648
+ if i in [1, 2, 3]:
649
+ x = layer(x, None)
650
+ else:
651
+ x = layer(x)
652
+
653
+ h = self.norm_out(x)
654
+ h = nonlinearity(h)
655
+ x = self.conv_out(h)
656
+ return x
657
+
658
+
659
+ class UpsampleDecoder(nn.Module):
660
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
661
+ super().__init__()
662
+ # upsampling
663
+ self.temb_ch = 0
664
+ self.num_resolutions = len(ch_mult)
665
+ self.num_res_blocks = num_res_blocks
666
+ block_in = in_channels
667
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
668
+ self.res_blocks = nn.ModuleList()
669
+ self.upsample_blocks = nn.ModuleList()
670
+ for i_level in range(self.num_resolutions):
671
+ res_block = []
672
+ block_out = ch * ch_mult[i_level]
673
+ for i_block in range(self.num_res_blocks + 1):
674
+ res_block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
675
+ block_in = block_out
676
+ self.res_blocks.append(nn.ModuleList(res_block))
677
+ if i_level != self.num_resolutions - 1:
678
+ self.upsample_blocks.append(Upsample(block_in, True))
679
+ curr_res = curr_res * 2
680
+
681
+ # end
682
+ self.norm_out = Normalize(block_in)
683
+ self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
684
+
685
+ def forward(self, x):
686
+ # upsampling
687
+ h = x
688
+ for k, i_level in enumerate(range(self.num_resolutions)):
689
+ for i_block in range(self.num_res_blocks + 1):
690
+ h = self.res_blocks[i_level][i_block](h, None)
691
+ if i_level != self.num_resolutions - 1:
692
+ h = self.upsample_blocks[k](h)
693
+ h = self.norm_out(h)
694
+ h = nonlinearity(h)
695
+ h = self.conv_out(h)
696
+ return h
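
The decoders above (Decoder in particular) turn a spatial latent back into an RGB image by mirroring the encoder's downsampling. As a minimal sketch of the shapes involved (the config values below are illustrative, loosely scaled down from the ddconfig used by TamingVQModel later in this diff; weights are random, so only shapes are meaningful):

import torch

# Assumes Decoder and its helpers (ResnetBlock, AttnBlock, Upsample, Normalize, nonlinearity)
# from this file are in scope.
ddconfig = dict(
    z_channels=16, resolution=128, in_channels=3, out_ch=3,
    ch=32, ch_mult=[1, 1, 1, 1, 1], num_res_blocks=1, attn_resolutions=[16], dropout=0.0,
)
decoder = Decoder(**ddconfig)
# len(ch_mult) == 5 gives four 2x upsampling stages, so an 8x8 latent becomes a 128x128 image.
z = torch.randn(2, ddconfig["z_channels"], 8, 8)
with torch.no_grad():
    img = decoder(z)
print(img.shape)  # torch.Size([2, 3, 128, 128])
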
wham/models/vqgan/taming/quantize.py ADDED
@@ -0,0 +1,146 @@
1
+ # All files under this directory are originally from the taming-transformers repository:
2
+ # https://github.com/CompVis/taming-transformers
3
+
4
+ # MIT License
5
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
6
+ # 2023 Microsoft Research
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
19
+ # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20
+ # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
21
+ # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
22
+ # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
23
+ # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
24
+ # OR OTHER DEALINGS IN THE SOFTWARE.
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import numpy as np
29
+ from einops import rearrange
30
+
31
+
32
+ class VectorQuantizer2(nn.Module):
33
+ """
34
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
35
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
36
+ """
37
+
38
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
39
+ # backwards compatibility we use the buggy version by default, but you can
40
+ # specify legacy=False to fix it.
41
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
42
+ super().__init__()
43
+ self.n_e = n_e
44
+ self.e_dim = e_dim
45
+ self.beta = beta
46
+ self.legacy = legacy
47
+
48
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
49
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
50
+
51
+ self.remap = remap
52
+ if self.remap is not None:
53
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
54
+ self.re_embed = self.used.shape[0]
55
+ self.unknown_index = unknown_index # "random" or "extra" or integer
56
+ if self.unknown_index == "extra":
57
+ self.unknown_index = self.re_embed
58
+ self.re_embed = self.re_embed + 1
59
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " f"Using {self.unknown_index} for unknown indices.")
60
+ else:
61
+ self.re_embed = n_e
62
+
63
+ self.sane_index_shape = sane_index_shape
64
+
65
+ def remap_to_used(self, inds):
66
+ ishape = inds.shape
67
+ assert len(ishape) > 1
68
+ inds = inds.reshape(ishape[0], -1)
69
+ used = self.used.to(inds)
70
+ match = (inds[:, :, None] == used[None, None, ...]).long()
71
+ new = match.argmax(-1)
72
+ unknown = match.sum(2) < 1
73
+ if self.unknown_index == "random":
74
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
75
+ else:
76
+ new[unknown] = self.unknown_index
77
+ return new.reshape(ishape)
78
+
79
+ def unmap_to_all(self, inds):
80
+ ishape = inds.shape
81
+ assert len(ishape) > 1
82
+ inds = inds.reshape(ishape[0], -1)
83
+ used = self.used.to(inds)
84
+ if self.re_embed > self.used.shape[0]: # extra token
85
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
86
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
87
+ return back.reshape(ishape)
88
+
89
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
90
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
91
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
92
+ assert return_logits == False, "Only for interface compatible with Gumbel"
93
+ # reshape z -> (batch, height, width, channel) and flatten
94
+ z = rearrange(z, "b c h w -> b h w c").contiguous()
95
+ z_flattened = z.view(-1, self.e_dim)
96
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
97
+
98
+ d = (
99
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
100
+ + torch.sum(self.embedding.weight**2, dim=1)
101
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
102
+ )
103
+
104
+ min_encoding_indices = torch.argmin(d, dim=1)
105
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
106
+ perplexity = None
107
+ min_encodings = None
108
+
109
+ # compute loss for embedding
110
+ if not self.legacy:
111
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
112
+ else:
113
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
114
+
115
+ # preserve gradients
116
+ z_q = z + (z_q - z).detach()
117
+
118
+ # reshape back to match original input shape
119
+ z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
120
+
121
+ if self.remap is not None:
122
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
123
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
124
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
125
+
126
+ if self.sane_index_shape:
127
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
128
+
129
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
130
+
131
+ def get_codebook_entry(self, indices, shape):
132
+ # shape specifying (batch, height, width, channel)
133
+ if self.remap is not None:
134
+ indices = indices.reshape(shape[0], -1) # add batch axis
135
+ indices = self.unmap_to_all(indices)
136
+ indices = indices.reshape(-1) # flatten again
137
+
138
+ # get quantized latent vectors
139
+ z_q = self.embedding(indices)
140
+
141
+ if shape is not None:
142
+ z_q = z_q.view(shape)
143
+ # reshape back to match original input shape
144
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
145
+
146
+ return z_q
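
For orientation, a minimal sketch of the quantizer's interface with default settings (no index remapping; the codebook is randomly initialised, so the indices are arbitrary):

import torch

vq = VectorQuantizer2(n_e=1024, e_dim=8, beta=0.25)
z = torch.randn(2, 8, 16, 16)      # (batch, e_dim, height, width), e.g. encoder features
z_q, loss, (_, _, indices) = vq(z)
print(z_q.shape)                   # torch.Size([2, 8, 16, 16]); straight-through gradients w.r.t. z
print(indices.shape)               # torch.Size([512]) = batch*height*width flat codebook ids
                                   # (pass sane_index_shape=True to get them as (batch, height, width))
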
wham/models/vqgan/taming_vq_model.py ADDED
@@ -0,0 +1,264 @@
1
+ # Wrapper for the VQ models from the taming-transformers repo
2
+ # https://github.com/CompVis/taming-transformers
3
+
4
+ from typing import Any, Mapping
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from wham.models.vqgan.taming.model import Encoder, Decoder
10
+ from wham.models.vqgan.taming.quantize import VectorQuantizer2 as VectorQuantizer
11
+
12
+ from wham.models.wham_base.tensor_spaces import TensorSpace
13
+ from wham.models.wham_base.encoder_decoder import EncoderDecoderBase
14
+
15
+
16
+ HARDCODED_IMAGE_SIZE = 128
17
+
18
+
19
+ def taming_vq_preprocess_images(imgs):
20
+ """Normalize images (as pytorch tensor uint8s) as in taming-transformers"""
21
+ return imgs.float() / 127.5 - 1.0
22
+
23
+
24
+ def taming_vq_revert_preprocess_images(imgs):
25
+ """Revert preprocessing of images from taming to uint8 as in taming-transformers"""
26
+ # Clamp first
27
+ imgs = torch.clamp(imgs, -1.0, 1.0)
28
+ return ((imgs + 1) * 127.5).byte()
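
A quick numeric check of the round trip defined by these two helpers (uint8 [0, 255] maps to float [-1, 1] and back):

import torch

pixels = torch.tensor([0, 255], dtype=torch.uint8)
normed = taming_vq_preprocess_images(pixels)           # tensor([-1., 1.])
restored = taming_vq_revert_preprocess_images(normed)  # tensor([0, 255], dtype=torch.uint8)
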
29
+
30
+
31
+ class _VQModelFromTamingRepository(pl.LightningModule):
32
+ """
33
+ This aims to be the original VQ model from the taming-transformers repo with as few modifications as possible. This should not be used directly.
34
+ Source: https://github.com/CompVis/taming-transformers/blob/master/taming/models/vqgan.py
35
+
36
+ MIT License
37
+ Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
38
+ 2023 Microsoft Research
39
+
40
+ Permission is hereby granted, free of charge, to any person obtaining a copy
41
+ of this software and associated documentation files (the "Software"), to deal
42
+ in the Software without restriction, including without limitation the rights
43
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
44
+ copies of the Software, and to permit persons to whom the Software is
45
+ furnished to do so, subject to the following conditions:
46
+
47
+ The above copyright notice and this permission notice shall be included in all
48
+ copies or substantial portions of the Software.
49
+
50
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
51
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
52
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
53
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
54
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
55
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
56
+ OR OTHER DEALINGS IN THE SOFTWARE.
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ ddconfig,
62
+ n_embed,
63
+ embed_dim,
64
+ ckpt_path=None,
65
+ ignore_keys=[],
66
+ image_key="image",
67
+ colorize_nlabels=None,
68
+ monitor=None,
69
+ remap=None,
70
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
71
+ ):
72
+ super().__init__()
73
+ self.image_key = image_key
74
+ self.encoder = Encoder(**ddconfig)
75
+ self.decoder = Decoder(**ddconfig)
76
+ # NOTE: Loss is disabled for this repo (we only want inference)
77
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape)
78
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
79
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
80
+ # Note: the '!= "None"' check is for checkpoints that mistakenly stored the None as a string
81
+ if ckpt_path is not None and ckpt_path != "None":
82
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
83
+ self.image_key = image_key
84
+ if colorize_nlabels is not None:
85
+ assert type(colorize_nlabels) == int
86
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
87
+ if monitor is not None:
88
+ self.monitor = monitor
89
+
90
+ def init_from_ckpt(self, path, ignore_keys=list()):
91
+ sd = torch.load(path, map_location="cpu")["state_dict"]
92
+ keys = list(sd.keys())
93
+ for k in keys:
94
+ for ik in ignore_keys:
95
+ if k.startswith(ik):
96
+ print("Deleting key {} from state_dict.".format(k))
97
+ del sd[k]
98
+ self.load_state_dict(sd, strict=False)
99
+ print(f"Restored from {path}")
100
+
101
+ def encode(self, x):
102
+ h = self.encoder(x)
103
+ h = self.quant_conv(h)
104
+ quant, emb_loss, info = self.quantize(h)
105
+ return quant, emb_loss, info
106
+
107
+ def decode(self, quant):
108
+ quant = self.post_quant_conv(quant)
109
+ dec = self.decoder(quant)
110
+ return dec
111
+
112
+ def forward(self, input):
113
+ quant, diff, _ = self.encode(input)
114
+ dec = self.decode(quant)
115
+ return dec, diff
116
+
117
+ def get_input(self, batch, k):
118
+ x = batch[k]
119
+ if len(x.shape) == 3:
120
+ x = x[..., None]
121
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
122
+ return x.float()
123
+
124
+ def training_step(self, batch, batch_idx, optimizer_idx):
125
+ raise NotImplementedError("This copy of the model code does not support training")
126
+
127
+ def validation_step(self, batch, batch_idx):
128
+ raise NotImplementedError("This copy of the model code does not support training")
129
+
130
+ def configure_optimizers(self):
131
+ raise NotImplementedError("This copy of the model code does not support training")
132
+
133
+ def get_last_layer(self):
134
+ return self.decoder.conv_out.weight
135
+
136
+ def log_images(self, batch, **kwargs):
137
+ log = dict()
138
+ x = self.get_input(batch, self.image_key)
139
+ x = x.to(self.device)
140
+ xrec, _ = self(x)
141
+ if x.shape[1] > 3:
142
+ # colorize with random projection
143
+ assert xrec.shape[1] > 3
144
+ x = self.to_rgb(x)
145
+ xrec = self.to_rgb(xrec)
146
+ log["inputs"] = x
147
+ log["reconstructions"] = xrec
148
+ return log
149
+
150
+ def to_rgb(self, x):
151
+ assert self.image_key == "segmentation"
152
+ if not hasattr(self, "colorize"):
153
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
154
+ x = F.conv2d(x, weight=self.colorize)
155
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
156
+ return x
157
+
158
+
159
+ class TamingVQModel(EncoderDecoderBase):
160
+
161
+ __DEBUG_CREATION_KWARGS__ = {
162
+ "ckpt_path": None,
163
+ "model_spec": {
164
+ "taming_n_embed": 16,
165
+ "taming_embed_dim": 8,
166
+ "taming_num_indices_per_axis": 8,
167
+ "taming_ddconfig": {
168
+ "double_z": False,
169
+ "z_channels": 16,
170
+ "resolution": 128,
171
+ "in_channels": 3,
172
+ "out_ch": 3,
173
+ "ch": 128,
174
+ "ch_mult": [1, 1, 1, 1, 1],
175
+ "num_res_blocks": 1,
176
+ "attn_resolutions": [16],
177
+ "dropout": 0.0,
178
+ },
179
+ },
180
+ }
181
+
182
+ def __init__(self, model_spec, ckpt_path, **kwargs):
183
+ super().__init__()
184
+ self._vocab_size = model_spec["taming_n_embed"]
185
+ self.num_indices_per_axis = model_spec["taming_num_indices_per_axis"]
186
+ self.num_indices_total = self.num_indices_per_axis**2
187
+ self.taming_embed_dim = model_spec["taming_embed_dim"]
188
+ taming_ddconfig = model_spec.get("taming_ddconfig", None)
189
+ if taming_ddconfig is None:
190
+ raise ValueError("To run TamingVQModel, specify model_spec.taming_ddconfig, which should match the ddconfig used when training the model")
191
+
192
+ self.vq_model = _VQModelFromTamingRepository(taming_ddconfig, self._vocab_size, self.taming_embed_dim, ckpt_path=ckpt_path)
193
+
194
+ resolution = taming_ddconfig["resolution"]
195
+ in_channels = taming_ddconfig["in_channels"]
196
+ self.world_space = TensorSpace((in_channels, resolution, resolution), dtype=torch.uint8, low=0, high=255)
197
+ self.encoder_space = TensorSpace((self.num_indices_total,), dtype=torch.long, low=0, high=self.vocab_size - 1)
198
+
199
+ @property
200
+ def vocab_size(self):
201
+ """Return the number of entries in the codebook."""
202
+ return self._vocab_size
203
+
204
+ @property
205
+ def encoded_bottleneck_dim(self):
206
+ """Return the dimensionality of the latent vector encoded into codebook indices."""
207
+ return self.num_indices_total
208
+
209
+ def _preprocess_images(self, images):
210
+ """Preprocess images (B, C, H, W)"""
211
+ return taming_vq_preprocess_images(images)
212
+
213
+ def _revert_image_preprocess(self, x_batch):
214
+ """Revert the preprocessing done in _preprocess_images"""
215
+ return taming_vq_revert_preprocess_images(x_batch)
216
+
217
+ def decode_from_encoding_indices(self, encoding_indices, return_vq_embeddings=False):
218
+ """Return decoded images (B, C, H, W) for a batch of encoding indices (B, self.encoded_bottleneck_dim)"""
219
+ batch_size = encoding_indices.shape[0]
220
+ z = self.vq_model.quantize.get_codebook_entry(encoding_indices, shape=(batch_size, self.num_indices_per_axis, self.num_indices_per_axis, self.taming_embed_dim))
221
+ data_recon = self.vq_model.decode(z)
222
+ # Denormalize and cast to uint8
223
+ data_recon = self._revert_image_preprocess(data_recon)
224
+ if return_vq_embeddings:
225
+ return data_recon, z
226
+ return data_recon
227
+
228
+ def get_encoding_indices_for_images(self, images):
229
+ """
230
+ Return encoding indices (B, self.encoded_bottleneck_dim) for a batch of images (B, C, H, W).
231
+ Useful auxiliary method for testing.
232
+ """
233
+ x_batch = self._preprocess_images(images)
234
+ _, _, (_, _, encoding_indices) = self.vq_model.encode(x_batch)
235
+ # Split back into (B, self.encoded_bottleneck_dim)
236
+ encoding_indices = encoding_indices.view(images.shape[0], -1)
237
+ return encoding_indices
238
+
239
+ def forward_returning_action_and_embedding(self, states, actions_input, timesteps, attention_mask, images):
240
+ seq_len_dim = 1
241
+ assert images.shape[seq_len_dim] == 1, f"We require seq_len==1, but provided {images.shape[seq_len_dim]}."
242
+ images = images.squeeze(dim=seq_len_dim) # get rid of timestep dimension
243
+ x_batch = self._preprocess_images(images)
244
+ quant, _, (_, _, encoding_indices) = self.vq_model.encode(x_batch)
245
+ # Split back into (B, self.encoded_bottleneck_dim)
246
+ encoding_indices = encoding_indices.reshape(quant.shape[0], 1, quant.shape[2], quant.shape[3])
247
+ quant = quant.unsqueeze(seq_len_dim)
248
+ return None, {"quantized": quant, "encoding_indices": encoding_indices}
249
+
250
+ def _encode(self, world_space_tensor: torch.tensor) -> torch.tensor:
251
+ batch, time = world_space_tensor.shape[:2]
252
+ world_space_tensor = world_space_tensor.view(batch * time, *world_space_tensor.shape[2:])
253
+ encodings = self.get_encoding_indices_for_images(world_space_tensor)
254
+ # Reshape back to (batch, time, ...)
255
+ encodings = encodings.view(batch, time, -1)
256
+ return encodings
257
+
258
+ def _decode(self, encoder_space_tensor: torch.tensor) -> torch.tensor:
259
+ batch, time = encoder_space_tensor.shape[:2]
260
+ encoder_space_tensor = encoder_space_tensor.view(batch * time, *encoder_space_tensor.shape[2:])
261
+ decoded = self.decode_from_encoding_indices(encoder_space_tensor)
262
+ # Reshape back to (batch, time, ...)
263
+ decoded = decoded.view(batch, time, *decoded.shape[1:])
264
+ return decoded
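
To make the (batch, time, ...) tensor shapes concrete, a minimal encode/decode round trip built from the class's own __DEBUG_CREATION_KWARGS__ (random weights, so the reconstruction is meaningless; the internal _encode/_decode helpers are called directly for brevity):

import torch

model = TamingVQModel(**TamingVQModel.__DEBUG_CREATION_KWARGS__).eval()
frames = torch.randint(0, 256, (1, 2, 3, 128, 128), dtype=torch.uint8)  # (batch, time, C, H, W)
with torch.no_grad():
    tokens = model._encode(frames)  # (1, 2, 64): an 8x8 grid of codebook indices per frame
    recon = model._decode(tokens)   # (1, 2, 3, 128, 128), uint8
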
wham/models/vqgan/vqgan.py ADDED
@@ -0,0 +1,236 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from wham.models.wham_base.tensor_spaces import TensorSpace
7
+ from wham.models.wham_base.encoder_decoder import EncoderDecoderBase
8
+
9
+ from wham.models.vqgan import vqgan_models as vqgan
10
+ from wham.models.vqvae.vqvae_utils import make_grid, normalise_rgb, rev_normalise_rgb
11
+
12
+ from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
13
+ from pytorch_lightning.loggers.wandb import WandbLogger
14
+
15
+ TARGET_GAN_UPDATE = 5
16
+ GAN_DWEIGHT_MAX = 250
17
+ GAN_LOGIT_CAP = 5.0
18
+ MAX_PIXEL_WEIGHTING = 0.1
19
+
20
+ # The GAN parts are from Taming Transformers (https://github.com/CompVis/taming-transformers)
21
+ """
22
+ ViT-VQGAN is based on:
23
+ Yu, Jiahui, et al. "Vector-quantized image modeling with improved vqgan."
24
+ ICLR 2022
25
+ """
26
+
27
+
28
+ def create_vqgan_model_for_training(variant):
29
+ return VQGANModel(variant=variant)
30
+
31
+
32
+ class VQGANModel(EncoderDecoderBase):
33
+ @classmethod
34
+ def create_from_variant(cls, variant):
35
+ return VQGANModel(variant=variant)
36
+
37
+ def __init__(self, variant=None, ckpt_path=None, model_spec=None):
38
+ super().__init__()
39
+ self.save_hyperparameters()
40
+ self.variant = variant
41
+ if model_spec is not None:
42
+ self.model_spec = model_spec
43
+ else:
44
+ self.model_spec = variant["model_spec"]
45
+
46
+ # Batches of images we will use for logging
47
+ self.reference_x_batch = None # Same images used throughout training to see progress of the model
48
+ self.random_batch = None # Different images every iteration
49
+
50
+ if variant is None and "image_size_per_y_axis" in self.model_spec:
51
+ self.image_size_x = self.model_spec["image_size_per_x_axis"]
52
+ self.image_size_y = self.model_spec["image_size_per_y_axis"]
53
+ else:
54
+ assert "image_size_per_x_axis" in variant and "image_size_per_y_axis" in variant, "Please provide the image size as separate x and y for the VQGAN model"
55
+ self.image_size_x = variant["image_size_per_x_axis"]
56
+ self.image_size_y = variant["image_size_per_y_axis"]
57
+
58
+ self._embedding_dim = self.model_spec["embedding_dim"]
59
+ self.encoder = vqgan.ViTEncoder(
60
+ patch_size=self.model_spec["patch_size"],
61
+ transf_dim=self.model_spec["transf_dim"],
62
+ embedding_dim=self.model_spec["embedding_dim"],
63
+ image_size_x=self.image_size_x,
64
+ image_size_y=self.image_size_y,
65
+ num_layers=self.model_spec["num_layers"],
66
+ head_size=self.model_spec["head_size"],
67
+ )
68
+ self._bottleneck_size = self.encoder.bottleneck
69
+
70
+ self.vq_vae = vqgan.ViTVectorQuantizer(
71
+ self.model_spec["vocab_size"],
72
+ self.model_spec["embedding_dim"],
73
+ self.model_spec["commitment_cost"],
74
+ )
75
+
76
+ self.decoder = vqgan.ViTDecoder(
77
+ patch_size=self.model_spec["patch_size"],
78
+ transf_dim=self.model_spec["transf_dim"],
79
+ embedding_dim=self.model_spec["embedding_dim"],
80
+ image_size_x=self.image_size_x,
81
+ image_size_y=self.image_size_y,
82
+ num_layers=self.model_spec["num_layers"],
83
+ head_size=self.model_spec["head_size"],
84
+ expected_bottleneck=self._bottleneck_size,
85
+ )
86
+
87
+ self.is_perceptual = self.model_spec["is_perceptual"]
88
+ assert self.is_perceptual # This should be on
89
+
90
+ # Keep track of the usage of the codebook indices
91
+ self.codebook_index_usage = np.zeros(self.model_spec["vocab_size"], dtype=np.int64)
92
+
93
+ self.gan = self.model_spec.get("use_gan", False)
94
+ if self.gan:
95
+ # Only make the patchgan if we are using it. This makes it easier to experiment with GAN settings after pretraining the VQ-VAE for instance
96
+ self.patch_gan = vqgan.PatchGan(channel_start=self.model_spec["gan_channel_start"])
97
+ # Make a copy of the patchgan since we are only using a single optimizer
98
+ self.target_patchgan = vqgan.PatchGan(channel_start=self.model_spec["gan_channel_start"])
99
+ self.target_patchgan.requires_grad_(False)
100
+ self.target_patchgan.load_state_dict(self.patch_gan.state_dict())
101
+ self.target_update = TARGET_GAN_UPDATE
102
+
103
+ # At which iteration to start using the GAN loss
104
+ self.gan_start = self.model_spec["gan_start"]
105
+ # How much weight to give to the GAN loss gradients compared to the vq autoencoder loss
106
+ self.gan_weight = self.model_spec["gan_weight"]
107
+ # How many steps to train the discriminator before applying the gan loss.
108
+ self.gan_discrim_pretrain = self.model_spec["gan_discrim_pretrain"]
109
+ # How many steps to warmup the gan loss
110
+ self.gan_discrim_warmup = self.model_spec["gan_discrim_warmup"]
111
+ # Keeping track of the number of updates
112
+ self.updates = 0
113
+ print(f"Using GAN with weight {self.gan_weight} and target update {self.target_update} and gan start {self.gan_start} over {self.gan_discrim_warmup} steps")
114
+
115
+ self.lpips_model = None
116
+ # We don't need this for using the encoder/decoder
117
+ # self.lpips_model = lpips.LPIPS(net=self.model_spec["lpips_model"]).eval()
118
+ # for param in self.lpips_model.parameters():
119
+ # param.requires_grad = False
120
+
121
+ if ckpt_path is not None and ckpt_path != "None":
122
+ print(f"Initializing VQGAN model from {ckpt_path}")
123
+ loaded_ckpt = torch.load(ckpt_path, map_location="cpu")
124
+ # Can ignore stuff here
125
+ self.load_state_dict(loaded_ckpt["state_dict"], strict=False)
126
+
127
+ self.world_space = TensorSpace((3, self.image_size_y, self.image_size_x), dtype=torch.uint8, low=0, high=255)
128
+ self.encoder_space = TensorSpace((self._bottleneck_size,), dtype=torch.long, low=0, high=self.vocab_size - 1)
129
+
130
+ @property
131
+ def vocab_size(self):
132
+ """Return the number of entries in the codebook."""
133
+ return self.vq_vae._vocab_size
134
+
135
+ @property
136
+ def encoded_bottleneck_dim(self):
137
+ """Return the dimensionality of the latent vector encoded into codebook indices."""
138
+ return self._bottleneck_size
139
+
140
+ @property
141
+ def embedding_dim(self):
142
+ """The dimensionality of quantized vectors (the dimension of codebook vectors)."""
143
+ return self.vq_vae._embedding_dim
144
+
145
+ def _get_last_layer(self):
146
+ """
147
+ The last layer used for generating the image.
148
+ Used for balancing the gradients of the reconstruction and the GAN loss.
149
+ """
150
+ return self.decoder.get_last_layer()
151
+
152
+ def _preprocess_images(self, images):
153
+ """Preprocess images (B, C, H, W)"""
154
+ x_batch = images.float() / 255
155
+ x_batch = normalise_rgb(x_batch)
156
+ return x_batch
157
+
158
+ def _revert_image_preprocess(self, x_batch):
159
+ """Revert the preprocessing done in _preprocess_images"""
160
+ normalized_imgs = rev_normalise_rgb(x_batch.clone())
161
+ x_batch = torch.clip(normalized_imgs, 0, 1)
162
+ images = (x_batch * 255).byte()
163
+ return images
164
+
165
+ def _get_latent_continuous(self, batch):
166
+ z = self.encoder(batch)
167
+ return z
168
+
169
+ def _get_latent_discretized(self, z):
170
+ z_quantized, vq_loss, perplexity, indices = self.vq_vae(z)
171
+ return z_quantized, vq_loss, perplexity, indices
172
+
173
+ def _encode_decode(self, x_batch):
174
+ z = self._get_latent_continuous(x_batch)
175
+ z_quantized, vq_loss, perplexity, indices = self._get_latent_discretized(z)
176
+ data_recon = self.decoder(z_quantized)
177
+ return vq_loss, perplexity, data_recon, indices
178
+
179
+ def _log_vars(self, log_vars):
180
+ prefix = "train" if self.training else "val"
181
+ for key, val in log_vars.items():
182
+ self.log(f"{prefix}/{key}", val, on_step=True, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
183
+
184
+ def decode_from_encoding_indices(self, encoding_indices):
185
+ """Return decoded images (B, C, H, W) for a batch of encoding indices (B, self.encoded_bottleneck_dim)"""
186
+ z = self.vq_vae.convert_encoding_indices_to_quantized_embeddings(encoding_indices)
187
+ data_recon = self.decoder(z)
188
+ # Denormalize and cast to uint8
189
+ data_recon = self._revert_image_preprocess(data_recon)
190
+ return data_recon
191
+
192
+ def get_encoding_indices_for_images(self, images):
193
+ """
194
+ Return encoding indices (B, self.encoded_bottleneck_dim) for a batch of images (B, C, H, W).
195
+ Useful auxiliary method for testing.
196
+ """
197
+ x_batch = self._preprocess_images(images)
198
+ z = self._get_latent_continuous(x_batch)
199
+ encoding_indices = self.vq_vae(z, only_return_encoding_indices=True)
200
+ return encoding_indices
201
+
202
+ def forward_returning_action_and_embedding(self, states, actions_input, timesteps, attention_mask, images):
203
+ raise NotImplementedError
204
+
205
+ def get_encoding_output(self, images):
206
+ """
207
+ Return outputs from the encoder for a batch of images (B, C, H, W).
208
+ Returns:
209
+ quantized_z: (B, self.encoded_bottleneck_dim, self.embedding_dim), quantized latent vectors with straight-through gradient estimator
210
+ vq_loss: (B, ), VQ loss for each image
211
+ perplexity: (B, ), perplexity for each image
212
+ encoding_indices: (B, self.encoded_bottleneck_dim), encoding indices for each image
213
+ """
214
+ x_batch = self._preprocess_images(images)
215
+ z = self._get_latent_continuous(x_batch)
216
+ quantized_z, vq_loss, perplexity, encoding_indices = self.vq_vae(z)
217
+ quantized_z = quantized_z.view(quantized_z.shape[0], self.encoded_bottleneck_dim, self.embedding_dim)
218
+ return quantized_z, vq_loss, perplexity, encoding_indices
219
+
220
+ def _encode(self, world_space_tensor: torch.tensor) -> torch.tensor:
221
+ # Flatten time and batch dim into one
222
+ batch, time = world_space_tensor.shape[:2]
223
+ world_space_tensor = world_space_tensor.view(batch * time, *world_space_tensor.shape[2:])
224
+ encodings = self.get_encoding_indices_for_images(world_space_tensor)
225
+ # Reshape back to (batch, time, ...)
226
+ encodings = encodings.view(batch, time, -1)
227
+ return encodings
228
+
229
+ def _decode(self, encoder_space_tensor: torch.tensor) -> torch.tensor:
230
+ # Flatten time and batch dim into one
231
+ batch, time = encoder_space_tensor.shape[:2]
232
+ encoder_space_tensor = encoder_space_tensor.view(batch * time, *encoder_space_tensor.shape[2:])
233
+ decoded = self.decode_from_encoding_indices(encoder_space_tensor)
234
+ # Reshape back to (batch, time, ...)
235
+ decoded = decoded.view(batch, time, *decoded.shape[1:])
236
+ return decoded
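
For shape orientation, the same kind of (batch, time, ...) round trip with this ViT-based VQGAN. The model_spec values below are made up purely for illustration; in practice a trained checkpoint supplies matching hyperparameters and weights via ckpt_path:

import torch

spec = {
    "image_size_per_x_axis": 128, "image_size_per_y_axis": 128,
    "patch_size": 16, "transf_dim": 256, "head_size": 64, "num_layers": 2,
    "embedding_dim": 32, "vocab_size": 256, "commitment_cost": 0.25,
    "is_perceptual": True,
}
model = VQGANModel(model_spec=spec).eval()
frames = torch.randint(0, 256, (1, 2, 3, 128, 128), dtype=torch.uint8)  # (batch, time, C, H, W)
with torch.no_grad():
    tokens = model._encode(frames)  # (1, 2, 64): one codebook index per 16x16 patch
    recon = model._decode(tokens)   # (1, 2, 3, 128, 128), uint8
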
wham/models/vqgan/vqgan_models.py ADDED
@@ -0,0 +1,311 @@
1
+ # MIT License
2
+ # Copyright (c) 2018 Zalando Research
3
+ # 2023 Microsoft Research
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
16
+ # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
17
+ # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
18
+ # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
19
+ # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20
+ # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
21
+ # OR OTHER DEALINGS IN THE SOFTWARE.
22
+
23
+ from math import sqrt
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+
29
+ from wham.models.nn.nanoGPT import GPTConfig, SelfAttentionBlock
30
+ from wham.models.nn.model_blocks import ConvNextBlock, ConvNextDownsample, ConvNextDownsampleBig
31
+
32
+ # Mainly following https://github.com/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb
33
+ """
34
+ ViT-VQGAN is based on:
35
+ Yu, Jiahui, et al. "Vector-quantized image modeling with improved vqgan."
36
+ ICLR 2022
37
+ """
38
+
39
+
40
+ def _convert_encoding_indices_to_quantized_embeddings(encoding_indices, embedding_layer, vocab_size, embedding_dim):
41
+ """
42
+ Args:
43
+ encoding_indices: tensor of integers (batch_size, bottleneck_size)
44
+ Each batch item represents a single image as a sequence of integers (indices of codebook vectors)
45
+ Output:
46
+ quantized: tensor of floats (batch_size, bottleneck_size, embedding_dim)
47
+ """
48
+ batch_dim, bottleneck_size = encoding_indices.shape[:2]
49
+
50
+ encoding_indices = encoding_indices.view(-1).unsqueeze(1)
51
+ one_hot_encoding_indices = torch.zeros(encoding_indices.shape[0], vocab_size, device=encoding_indices.device)
52
+ one_hot_encoding_indices.scatter_(1, encoding_indices, 1)
53
+
54
+ quantized = torch.matmul(one_hot_encoding_indices, embedding_layer)
55
+ quantized = quantized.view(batch_dim, bottleneck_size, embedding_dim).contiguous()
56
+ return quantized
57
+
58
+
59
+ class ViTVectorQuantizer(nn.Module):
60
+ """
61
+ Vector Quantizer for a Vision Transformer based VQ model using normalised codebook embeddings as in https://arxiv.org/abs/2110.04627.
62
+ """
63
+
64
+ def __init__(self, vocab_size, embedding_dim, commitment_cost, epsilon=1e-5):
65
+ super().__init__()
66
+
67
+ self._embedding_dim = embedding_dim
68
+ self._vocab_size = vocab_size
69
+ self._epsilon = epsilon
70
+
71
+ self._embedding = nn.Embedding(self._vocab_size, self._embedding_dim)
72
+ self._embedding.weight.data.uniform_(-1 / self._vocab_size, 1 / self._vocab_size)
73
+ self._commitment_cost = commitment_cost
74
+
75
+ @property
76
+ def vocab_size(self):
77
+ """Return the number of entries in the codebook."""
78
+ return self._vocab_size
79
+
80
+ def convert_encoding_indices_to_quantized_embeddings(self, encoding_indices):
81
+ """
82
+ Args:
83
+ encoding_indices: tensor of integers (batch_size, bottleneck_size)
84
+ Each batch item represents a single image as a sequence of integers (indices of codebook vectors)
85
+ Output:
86
+ quantized: tensor of floats (batch_size, bottleneck_size, self._embedding_dim)
87
+ """
88
+ return _convert_encoding_indices_to_quantized_embeddings(encoding_indices, F.normalize(self._embedding.weight), self._vocab_size, self._embedding_dim)
89
+
90
+ def forward(self, inputs, only_return_encoding_indices=False):
91
+ """
92
+ If only_return_encoding_indices is True, then only return the indices of codebook vectors
93
+ """
94
+ input_shape = inputs.shape
95
+
96
+ # Flatten input from Batch Tokens Embedding to B*T E
97
+ flat_input = inputs.view(-1, self._embedding_dim)
98
+ # Normalize inputs
99
+ flat_input = F.normalize(flat_input)
100
+
101
+ # Embeddings are always normalized
102
+ embeddings_to_use = F.normalize(self._embedding.weight)
103
+
104
+ # Calculate distances
105
+ distances = torch.sum(flat_input**2, dim=1, keepdim=True) + torch.sum(embeddings_to_use**2, dim=1) - 2 * torch.matmul(flat_input, embeddings_to_use.t())
106
+
107
+ # Encoding
108
+ encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
109
+ if only_return_encoding_indices:
110
+ # Add back batch dimension
111
+ return encoding_indices.view(input_shape[0], -1)
112
+ one_hot_encoding_indices = torch.zeros(encoding_indices.shape[0], self._vocab_size, device=inputs.device)
113
+ one_hot_encoding_indices.scatter_(1, encoding_indices, 1)
114
+
115
+ # Quantize and unflatten
116
+ quantized = torch.matmul(one_hot_encoding_indices, embeddings_to_use).view(input_shape)
117
+
118
+ # Loss
119
+ e_latent_loss = F.mse_loss(quantized.detach(), inputs)
120
+ q_latent_loss = F.mse_loss(quantized, inputs.detach())
121
+ loss = q_latent_loss + self._commitment_cost * e_latent_loss
122
+
123
+ quantized = inputs + (quantized - inputs).detach()
124
+ avg_probs = torch.mean(one_hot_encoding_indices, dim=0)
125
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + self._epsilon)))
126
+
127
+ return quantized, loss, perplexity, encoding_indices.view(input_shape[0], -1)
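
Because both the flattened inputs and the codebook rows are L2-normalised above, the squared-distance search is equivalent to picking the entry with the highest cosine similarity; a small standalone check of that identity:

import torch
import torch.nn.functional as F

z = F.normalize(torch.randn(5, 32), dim=1)            # unit-norm query vectors
codebook = F.normalize(torch.randn(256, 32), dim=1)   # unit-norm codebook entries
sim = z @ codebook.t()
d = (z ** 2).sum(1, keepdim=True) + (codebook ** 2).sum(1) - 2 * sim
assert torch.allclose(d, 2 - 2 * sim, atol=1e-5)      # hence argmin(d, 1) == argmax(sim, 1)
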
128
+
129
+
130
+ class ViTEncoder(nn.Module):
131
+ def __init__(self, patch_size, transf_dim, embedding_dim, image_size_x, image_size_y, num_layers, head_size):
132
+ super().__init__()
133
+
134
+ self.image_size_x = image_size_x
135
+ self.image_size_y = image_size_y
136
+ # We will pad the image to make it divisible by patch_size
137
+ self.x_pad = (patch_size - (self.image_size_x % patch_size)) % patch_size
138
+ self.y_pad = (patch_size - (self.image_size_y % patch_size)) % patch_size
139
+ assert (self.image_size_x + self.x_pad) % patch_size == 0 and (
140
+ self.image_size_y + self.y_pad
141
+ ) % patch_size == 0, "image_size_x and image_size_y must be divisible by patch_size"
142
+
143
+ self.vit_tokens = ((image_size_x + self.x_pad) // patch_size) * ((image_size_y + self.y_pad) // patch_size)
144
+ self._bottleneck = self.vit_tokens
145
+ print(f"Bottleneck is {self.bottleneck} for image size {image_size_x}x{image_size_y} with ViT Encoder and patch size {patch_size}")
146
+
147
+ self.patch_size = patch_size
148
+ self.transf_dim = transf_dim
149
+ self.embedding_dim = embedding_dim
150
+
151
+ self.proj1 = nn.Linear(3 * patch_size * patch_size, transf_dim)
152
+ self.pos_embeds = nn.Embedding(self.vit_tokens, transf_dim)
153
+
154
+ assert self.transf_dim % head_size == 0, "transf_dim must be divisible by head_size"
155
+ n_heads = self.transf_dim // head_size
156
+ transformer_config = GPTConfig(block_size=self.vit_tokens, n_layer=num_layers, n_head=n_heads, n_embd=transf_dim, bias=False, dropout=0)
157
+ self.vit = nn.Sequential(*[SelfAttentionBlock(transformer_config) for _ in range(num_layers)])
158
+
159
+ self.output_ln = nn.LayerNorm(transf_dim)
160
+ self.output_proj = nn.Linear(transf_dim, embedding_dim)
161
+
162
+ # init all weights
163
+ self.apply(self._init_weights)
164
+ # apply special scaled init to the residual projections, per GPT-2 paper
165
+ for pn, p in self.named_parameters():
166
+ if pn.endswith("c_proj.weight"):
167
+ torch.nn.init.normal_(p, mean=0.0, std=0.02 / sqrt(2 * transformer_config.n_layer))
168
+
169
+ @property
170
+ def bottleneck(self):
171
+ return self._bottleneck
172
+
173
+ def _init_weights(self, module):
174
+ if isinstance(module, nn.Linear):
175
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
176
+ if module.bias is not None:
177
+ torch.nn.init.zeros_(module.bias)
178
+ elif isinstance(module, nn.Embedding):
179
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
180
+
181
+ def forward(self, inputs):
182
+ # inputs: (batch_size, 3, image_size_x, image_size_y)
183
+
184
+ # Patch input images
185
+ batch_size = inputs.shape[0]
186
+ padded_inputs = F.pad(inputs, (0, self.x_pad, 0, self.y_pad), mode="constant", value=0)
187
+ x = padded_inputs.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
188
+ num_x_patches = (self.image_size_x + self.x_pad) // self.patch_size
189
+ num_y_patches = (self.image_size_y + self.y_pad) // self.patch_size
190
+
191
+ # inputs is of shape (batch_size, 3, num_x_patches, num_y_patches, patch_size, patch_size)
192
+ # Turn it into (batch_size, patches, input_dim)
193
+ patches = x.permute(0, 2, 3, 1, 4, 5).contiguous().view(batch_size, num_x_patches * num_y_patches, 3 * self.patch_size * self.patch_size)
194
+
195
+ proj_patches = self.proj1(patches)
196
+
197
+ pos_embeds = self.pos_embeds.weight.unsqueeze(0).repeat(batch_size, 1, 1)
198
+ vit_input = proj_patches + pos_embeds
199
+ vit_output = self.vit(vit_input)
200
+
201
+ vit_output = self.output_ln(vit_output)
202
+ embeddings = self.output_proj(vit_output)
203
+ normalised_embeddings = F.normalize(embeddings, dim=-1)
204
+
205
+ return normalised_embeddings
206
+
207
+
208
+ class ViTDecoder(nn.Module):
209
+ def __init__(self, patch_size, transf_dim, embedding_dim, image_size_x, image_size_y, num_layers, head_size, expected_bottleneck=None):
210
+ super().__init__()
211
+
212
+ self.image_size_x = image_size_x
213
+ self.image_size_y = image_size_y
214
+ self.x_pad = (patch_size - (self.image_size_x % patch_size)) % patch_size
215
+ self.y_pad = (patch_size - (self.image_size_y % patch_size)) % patch_size
216
+
217
+ assert (self.image_size_x + self.x_pad) % patch_size == 0 and (
218
+ self.image_size_y + self.y_pad
219
+ ) % patch_size == 0, "image_size_x and image_size_y must be divisible by patch_size"
220
+
221
+ self.vit_tokens = ((image_size_x + self.x_pad) // patch_size) * ((image_size_y + self.y_pad) // patch_size)
222
+ if expected_bottleneck is not None:
223
+ assert (
224
+ self.vit_tokens == expected_bottleneck
225
+ ), f"Expected bottleneck of {expected_bottleneck} but got {self.vit_tokens} for image size {image_size_x}x{image_size_y} with ViT Decoder and patch size {patch_size}"
226
+
227
+ self.patch_size = patch_size
228
+ self.transf_dim = transf_dim
229
+ self.embedding_dim = embedding_dim
230
+
231
+ self.proj1 = nn.Linear(embedding_dim, transf_dim)
232
+ self.pos_embeds = nn.Embedding(self.vit_tokens, transf_dim)
233
+
234
+ assert self.transf_dim % head_size == 0, "transf_dim must be divisible by head_size"
235
+ n_heads = self.transf_dim // head_size
236
+ transformer_config = GPTConfig(block_size=self.vit_tokens, n_layer=num_layers, n_head=n_heads, n_embd=transf_dim, bias=False, dropout=0)
237
+ self.vit = nn.Sequential(*[SelfAttentionBlock(transformer_config) for _ in range(num_layers)])
238
+
239
+ self.output_ln = nn.LayerNorm(transf_dim)
240
+ self.output_proj = nn.Linear(transf_dim, 3 * patch_size * patch_size)
241
+
242
+ # Couldn't resist the name
243
+ self.folder = nn.Fold(
244
+ output_size=(self.image_size_y + self.y_pad, self.image_size_x + self.x_pad),
245
+ kernel_size=(self.patch_size, self.patch_size),
246
+ stride=(self.patch_size, self.patch_size),
247
+ )
248
+
249
+ # init all weights
250
+ self.apply(self._init_weights)
251
+ # apply special scaled init to the residual projections, per GPT-2 paper
252
+ for pn, p in self.named_parameters():
253
+ if pn.endswith("c_proj.weight"):
254
+ torch.nn.init.normal_(p, mean=0.0, std=0.02 / sqrt(2 * transformer_config.n_layer))
255
+
256
+ def _init_weights(self, module):
257
+ if isinstance(module, nn.Linear):
258
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
259
+ if module.bias is not None:
260
+ torch.nn.init.zeros_(module.bias)
261
+ elif isinstance(module, nn.Embedding):
262
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
263
+
264
+ def forward(self, inputs):
265
+ # Patch input images
266
+ batch_size = inputs.shape[0]
267
+
268
+ # Unproject the embeddings from the VQ embedding space to the transformer space
269
+ proj_patches = self.proj1(inputs).reshape(batch_size, self.vit_tokens, self.transf_dim)
270
+
271
+ pos_embeds = self.pos_embeds.weight.unsqueeze(0).repeat(batch_size, 1, 1)
272
+ vit_input = proj_patches + pos_embeds
273
+ vit_output = self.vit(vit_input)
274
+
275
+ vit_output = self.output_ln(vit_output)
276
+
277
+ predictions = self.output_proj(vit_output) # (batch, patches, 3 * patch_size * patch_size)
278
+
279
+ # Reassemble the image into (batch, 3, image_size_x, image_size_y)
280
+ fold_inputs = predictions.permute(0, 2, 1).contiguous()
281
+ image_pred = self.folder(fold_inputs)
282
+
283
+ unpadded_image_pred = image_pred[:, :, : self.image_size_y, : self.image_size_x] # Remove padding in the same way it was applied in the encoder
284
+
285
+ # Note: no output activation is applied; the unpadded prediction is returned as-is.
286
+ return unpadded_image_pred
287
+
288
+ def get_last_layer(self):
289
+ """
290
+ Return the last layer weights of the model, to use for loss balancing.
291
+ """
292
+ return self.output_proj.weight
293
+
294
+
295
+ class PatchGan(nn.Module):
296
+ def __init__(self, channel_start):
297
+ super().__init__()
298
+ x = channel_start
299
+ self.downsample1 = ConvNextDownsampleBig(3, x)
300
+ self.block1 = ConvNextBlock(x)
301
+ self.downsample2 = ConvNextDownsampleBig(x, x)
302
+ self.block2 = ConvNextBlock(x)
303
+ self.last = nn.Conv2d(x, 1, kernel_size=1, stride=1, padding=0)
304
+
305
+ def forward(self, x):
306
+ batch_size = x.shape[0]
307
+ y = torch.nn.functional.gelu(self.downsample1(x))
308
+ y = self.block1(y)
309
+ z = torch.nn.functional.gelu(self.downsample2(y))
310
+ z = self.block2(z)
311
+ return self.last(z).reshape(batch_size, -1)
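
The number of VQ tokens per image (the encoder's bottleneck, and hence encoded_bottleneck_dim of the VQGAN wrapper) follows directly from the padding arithmetic shared by ViTEncoder and ViTDecoder above; restated as a standalone helper:

def vit_token_count(image_size_x: int, image_size_y: int, patch_size: int) -> int:
    # Pad each axis up to a multiple of patch_size, then count patches along each axis.
    x_pad = (patch_size - image_size_x % patch_size) % patch_size
    y_pad = (patch_size - image_size_y % patch_size) % patch_size
    return ((image_size_x + x_pad) // patch_size) * ((image_size_y + y_pad) // patch_size)

print(vit_token_count(128, 128, 16))  # 64  (8 x 8 patches, no padding needed)
print(vit_token_count(300, 180, 16))  # 228 (300 pads to 304 -> 19 patches, 180 pads to 192 -> 12 patches)
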
wham/models/vqvae/vqvae_utils.py ADDED
@@ -0,0 +1,154 @@
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+
6
+
7
+ def normalise_rgb(X, channels_first=True):
8
+ """
9
+ Take in an image tensor with three RGB channels (channel dim 1 when channels_first, else the
10
+ last dim), assumed to have already been divided by 255 so X is in [0, 1]. This function does
11
+ additional normalisation, roughly ending up with zero mean and unit variance per channel.
12
+ The constants appear in most vision repos and are the standard normalisation constants based
13
+ on ImageNet statistics.
14
+ """
15
+ channel_dim = 1 if channels_first else -1
16
+ assert X.shape[channel_dim] == 3
17
+ if channels_first:
18
+ X[:, 0, ...] -= 0.485
19
+ X[:, 0, ...] /= 0.229
20
+ X[:, 1, ...] -= 0.456
21
+ X[:, 1, ...] /= 0.224
22
+ X[:, 2, ...] -= 0.406
23
+ X[:, 2, ...] /= 0.225
24
+ else:
25
+ X[..., 0] -= 0.485
26
+ X[..., 0] /= 0.229
27
+ X[..., 1] -= 0.456
28
+ X[..., 1] /= 0.224
29
+ X[..., 2] -= 0.406
30
+ X[..., 2] /= 0.225
31
+ return X
32
+
33
+
34
+ def rev_normalise_rgb(X, channels_first=True):
35
+ """
36
+ Reverse `normalise_rgb`, so the output lives in [0,1]. This function is needed for
37
+ reconstruction visualisation, etc.
38
+ """
39
+ channel_dim = 1 if channels_first else -1
40
+ assert X.shape[channel_dim] == 3
41
+ if channels_first:
42
+ X[:, 0, ...] *= 0.229
43
+ X[:, 0, ...] += 0.485
44
+ X[:, 1, ...] *= 0.224
45
+ X[:, 1, ...] += 0.456
46
+ X[:, 2, ...] *= 0.225
47
+ X[:, 2, ...] += 0.406
48
+ else:
49
+ X[..., 0] *= 0.229
50
+ X[..., 0] += 0.485
51
+ X[..., 1] *= 0.224
52
+ X[..., 1] += 0.456
53
+ X[..., 2] *= 0.225
54
+ X[..., 2] += 0.406
55
+ return X
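
The two helpers are exact inverses (note that both modify their input in place); a quick check on channels-first data already scaled to [0, 1]:

import torch

x = torch.rand(4, 3, 8, 8)
y = rev_normalise_rgb(normalise_rgb(x.clone()))  # clone, since both helpers work in place
assert torch.allclose(x, y, atol=1e-5)
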
56
+
57
+
58
+ @torch.no_grad()
59
+ def make_grid(
60
+ tensor: Union[torch.Tensor, List[torch.Tensor]],
61
+ nrow: int = 8,
62
+ padding: int = 2,
63
+ normalize: bool = False,
64
+ value_range: Optional[Tuple[int, int]] = None,
65
+ scale_each: bool = False,
66
+ pad_value: float = 0.0,
67
+ **kwargs,
68
+ ) -> torch.Tensor:
69
+ """
70
+ Make a grid of images.
71
+
72
+ Args:
73
+ tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
74
+ or a list of images all of the same size.
75
+ nrow (int, optional): Number of images displayed in each row of the grid.
76
+ The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
77
+ padding (int, optional): amount of padding. Default: ``2``.
78
+ normalize (bool, optional): If True, shift the image to the range (0, 1),
79
+ by the min and max values specified by ``value_range``. Default: ``False``.
80
+ value_range (tuple, optional): tuple (min, max) where min and max are numbers,
81
+ then these numbers are used to normalize the image. By default, min and max
82
+ are computed from the tensor.
83
+ scale_each (bool, optional): If ``True``, scale each image in the batch of
84
+ images separately rather than the (min, max) over all images. Default: ``False``.
85
+ pad_value (float, optional): Value for the padded pixels. Default: ``0``.
86
+
87
+ Returns:
88
+ grid (Tensor): the tensor containing grid of images.
89
+ """
90
+ if not torch.is_tensor(tensor):
91
+ if isinstance(tensor, list):
92
+ for t in tensor:
93
+ if not torch.is_tensor(t):
94
+ raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}")
95
+ else:
96
+ raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
97
+
98
+ # if list of tensors, convert to a 4D mini-batch Tensor
99
+ if isinstance(tensor, list):
100
+ tensor = torch.stack(tensor, dim=0)
101
+
102
+ if tensor.dim() == 2: # single image H x W
103
+ tensor = tensor.unsqueeze(0)
104
+ if tensor.dim() == 3: # single image
105
+ if tensor.size(0) == 1: # if single-channel, convert to 3-channel
106
+ tensor = torch.cat((tensor, tensor, tensor), 0)
107
+ tensor = tensor.unsqueeze(0)
108
+
109
+ if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
110
+ tensor = torch.cat((tensor, tensor, tensor), 1)
111
+
112
+ if normalize is True:
113
+ tensor = tensor.clone() # avoid modifying tensor in-place
114
+ if value_range is not None and not isinstance(value_range, tuple):
115
+ raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")
116
+
117
+ def norm_ip(img, low, high):
118
+ img.clamp_(min=low, max=high)
119
+ img.sub_(low).div_(max(high - low, 1e-5))
120
+
121
+ def norm_range(t, value_range):
122
+ if value_range is not None:
123
+ norm_ip(t, value_range[0], value_range[1])
124
+ else:
125
+ norm_ip(t, float(t.min()), float(t.max()))
126
+
127
+ if scale_each is True:
128
+ for t in tensor: # loop over mini-batch dimension
129
+ norm_range(t, value_range)
130
+ else:
131
+ norm_range(tensor, value_range)
132
+
133
+ if not isinstance(tensor, torch.Tensor):
134
+ raise TypeError("tensor should be of type torch.Tensor")
135
+ if tensor.size(0) == 1:
136
+ return tensor.squeeze(0)
137
+
138
+ # make the mini-batch of images into a grid
139
+ nmaps = tensor.size(0)
140
+ xmaps = min(nrow, nmaps)
141
+ ymaps = int(math.ceil(float(nmaps) / xmaps))
142
+ height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
143
+ num_channels = tensor.size(1)
144
+ grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
145
+ k = 0
146
+ for y in range(ymaps):
147
+ for x in range(xmaps):
148
+ if k >= nmaps:
149
+ break
150
+ # Tensor.copy_() is a valid method but seems to be missing from the stubs
151
+ # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
152
+ grid.narrow(1, y * height + padding, height - padding).narrow(2, x * width + padding, width - padding).copy_(tensor[k]) # type: ignore[attr-defined]
153
+ k = k + 1
154
+ return grid
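
Typical use of make_grid is tiling a mini-batch of reconstructions into a single image for logging, for example:

import torch

batch = torch.rand(6, 3, 32, 32)           # six RGB images in [0, 1]
grid = make_grid(batch, nrow=3, padding=2)
print(grid.shape)                           # torch.Size([3, 70, 104]): 2 rows x 3 columns plus padding
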
wham/models/wham_base/__init__.py ADDED
File without changes
wham/models/wham_base/encode_predict_decode_base.py ADDED
@@ -0,0 +1,256 @@
1
+ from typing import Any, Union, Type, Callable, Tuple, Mapping, Optional
+
+ import torch as th
+ import pytorch_lightning as pl
+ from tensordict import TensorDict  # type: ignore # requires installing stubs for tensordict
+
+ from .tensor_spaces import TensorDictSpace
+ from .encoder_decoder import EncoderDecoderBase
+ from .pl_creation_args import LightningModuleCreationArgs
+
+
+ def create_encoder_args_from_config_dict(
+     config_dict: dict[str, Union[dict[str, Any], tuple]], class_name_to_model: Callable[[str], Type[pl.LightningModule]]
+ ) -> Mapping[str, Union[LightningModuleCreationArgs, Tuple[LightningModuleCreationArgs, LightningModuleCreationArgs]]]:
+     """
+     Given a dictionary mapping modality names to their encoder-decoder arguments, create the corresponding
+     creation args (LightningModuleCreationArgs) for each modality.
+
+     See LightningModuleCreationArgs.from_dict for more details.
+
+     Args:
+         config_dict: A dictionary mapping modality names to their encoder-decoder arguments.
+             The root level of this dictionary should be the modality names we expect.
+         class_name_to_model: A function mapping class names to their corresponding model classes.
+
+     Returns:
+         A dictionary mapping modality names to their encoder-decoder creation args.
+         Each value may be a LightningModuleCreationArgs, or a tuple of two LightningModuleCreationArgs.
+         If the value is a LightningModuleCreationArgs, the same model is used for encoding and decoding.
+         If the value is a tuple of two LightningModuleCreationArgs, the first is used for encoding and the second for decoding.
+     """
+     # Giving explicit type hint here to make mypy happy
+     modalities: dict[str, Any] = {}
+     for modality_name, modality_config in config_dict.items():
+         if isinstance(modality_config, (list, tuple)):
+             assert len(modality_config) == 2, f"Expected two entries for modality {modality_name}, got {len(modality_config)}"
+             modalities[modality_name] = (
+                 LightningModuleCreationArgs.from_dict(modality_config[0], class_name_to_model),
+                 LightningModuleCreationArgs.from_dict(modality_config[1], class_name_to_model),
+             )
+         else:
+             modalities[modality_name] = LightningModuleCreationArgs.from_dict(modality_config, class_name_to_model)
+     return modalities
+
+
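For orientation, a minimal sketch of the outer config structure this function expects. Only the "modality name → single config or (encoder config, decoder config) pair" shape comes from the code above; the modality names, the inner `class_name`/`kwargs` keys, and the registry lookup are illustrative assumptions, not the repo's actual schema (which is defined by LightningModuleCreationArgs.from_dict in wham/models/wham_base/pl_creation_args.py).

```python
# Sketch only: inner dict keys are assumed; LightningModuleCreationArgs.from_dict
# defines the real schema.
example_config = {
    # Single config: the same module both encodes and decodes this modality.
    "images": {"class_name": "HypotheticalImageTokenizer", "kwargs": {"embedding_dim": 256}},
    # Pair of configs: the first entry is the encoder, the second the decoder.
    "actions": (
        {"class_name": "HypotheticalActionEncoder", "kwargs": {}},
        {"class_name": "HypotheticalActionDecoder", "kwargs": {}},
    ),
}

# Hypothetical registry standing in for the project's class resolution.
MODEL_REGISTRY = {}  # name -> pl.LightningModule subclass

def class_name_to_model(name):
    return MODEL_REGISTRY[name]

# encoder_args = create_encoder_args_from_config_dict(example_config, class_name_to_model)
```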
+ def create_encoder_modules_from_args(
+     encoders: Mapping[str, Union[LightningModuleCreationArgs, Tuple[LightningModuleCreationArgs, LightningModuleCreationArgs]]], remove_checkpoint_path: bool = True
+ ) -> th.nn.ModuleDict:
+     """
+     Create the encoder modules from the given creation args (LightningModuleCreationArgs).
+
+     Args:
+         encoders: A dictionary mapping modality names to their encoder-decoder creation args.
+             If the value is a LightningModuleCreationArgs, the same model is used for encoding and decoding.
+             If the value is a tuple of two LightningModuleCreationArgs, the first is used for encoding and the second for decoding.
+         remove_checkpoint_path: If True, remove the checkpoint_path from the creation args. This prepares the
+             created modules to be properly saved and loaded as part of the bigger model.
+
+     Returns:
+         A ModuleDict mapping modality names to their encoder-decoder modules.
+     """
+     modalities = {}
+     for modality_name, modality_args in encoders.items():
+         if isinstance(modality_args, (list, tuple)):
+             modalities[modality_name] = th.nn.ModuleList(
+                 [
+                     modality_args[0].create_module(remove_checkpoint_path=remove_checkpoint_path),
+                     modality_args[1].create_module(remove_checkpoint_path=remove_checkpoint_path),
+                 ]
+             )
+         else:
+             modalities[modality_name] = modality_args.create_module(remove_checkpoint_path=remove_checkpoint_path)
+     return th.nn.ModuleDict(modalities)
+
+
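Continuing the hypothetical sketch above, the creation args would then be materialized into modules; `encoder_args` is the (assumed) result of the earlier `create_encoder_args_from_config_dict` call, and only the function call and its ModuleDict return type come from the code above.

```python
# Continuation of the sketch (requires a populated encoder_args from the earlier example).
# remove_checkpoint_path=True strips checkpoint paths from the creation args so the
# resulting modules can be saved and reloaded as part of the larger model checkpoint.
context_encoders = create_encoder_modules_from_args(encoder_args, remove_checkpoint_path=True)
assert isinstance(context_encoders, th.nn.ModuleDict)
```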
+ class EncodePredictDecodeModule(pl.LightningModule):
+     """
+     Base class for models that encode, predict and decode.
+
+     Args:
+         predictor_args: Creation args (LightningModuleCreationArgs) for the predictor model.
+         context_encoders: A ModuleDict mapping modality names to their encoder-decoders.
+             If a value is a single EncoderDecoderBase module, the same model is used for encoding and decoding.
+             If a value is a ModuleList of two modules, the first is used for encoding and the second for decoding.
+         condition_encoders: Same as `context_encoders`, but for conditions.
+     """
+
+     def __init__(
+         self,
+         predictor_args: LightningModuleCreationArgs,
+         context_encoders: th.nn.ModuleDict,
+         condition_encoders: Optional[th.nn.ModuleDict] = None,
+     ):
+         if condition_encoders is None:
+             condition_encoders = th.nn.ModuleDict(dict())
+         self._assert_encoders(context_encoders)
+         self._assert_encoders(condition_encoders)
+         super().__init__()
+
+         self.context_encoders = context_encoders
+         self.condition_encoders = condition_encoders
+
+         self.context_world_space, self.context_encoder_space = self._get_spaces_from_encoders(context_encoders)
+         self.condition_world_space, self.condition_encoder_space = self._get_spaces_from_encoders(condition_encoders)
+
+         self.predictor = predictor_args.create_module(context_space=self.context_encoder_space, condition_space=self.condition_encoder_space)
+
+     def _assert_encoders(self, encoders: th.nn.ModuleDict) -> None:
+         """Check that the encoder dictionary is valid."""
+         assert isinstance(encoders, th.nn.ModuleDict), f"Invalid type for encoders: {type(encoders)}. Expected th.nn.ModuleDict"
+         for modality_name, encoder in encoders.items():
+             assert isinstance(encoder, EncoderDecoderBase) or isinstance(
+                 encoder, th.nn.ModuleList
+             ), f"Invalid type for modality {modality_name}: {type(encoder)}. Expected EncoderDecoderBase or Tuple[EncoderDecoderBase]"
+             if isinstance(encoder, th.nn.ModuleList):
+                 assert len(encoder) == 2, f"Invalid number of arguments for modality {modality_name}: {len(encoder)}. Expected two (encoder, decoder)"
+                 assert isinstance(
+                     encoder[0], EncoderDecoderBase
+                 ), f"Invalid type for encoder of modality {modality_name}: {type(encoder[0])}. Expected EncoderDecoderBase"
+                 assert isinstance(
+                     encoder[1], EncoderDecoderBase
+                 ), f"Invalid type for decoder of modality {modality_name}: {type(encoder[1])}. Expected EncoderDecoderBase"
+
+     def _get_spaces_from_encoders(self, encoders: th.nn.ModuleDict) -> Tuple[TensorDictSpace, TensorDictSpace]:
+         """
+         Given a modality dictionary mapping modality names to their encoders and decoders,
+         extract the world space and encoder space.
+         """
+         world_spaces = {}
+         encoder_spaces = {}
+         for modality_name, modality in encoders.items():
+             if isinstance(modality, EncoderDecoderBase):
+                 encoder_spaces[modality_name] = modality.encoder_space
+                 world_spaces[modality_name] = modality.world_space
+             elif isinstance(modality, th.nn.ModuleList):
+                 assert len(modality) == 2, f"Invalid number of modules for modality {modality_name}: {len(modality)}. Expected 2."
+                 # Make sure that both encoder and decoder spaces match the expected space
+                 encoder_encoder_space = modality[0].encoder_space
+                 decoder_encoder_space = modality[1].encoder_space
+                 assert (
+                     encoder_encoder_space == decoder_encoder_space
+                 ), f"Encoder and decoder spaces for modality {modality_name} do not match: {encoder_encoder_space} != {decoder_encoder_space}"
+                 encoder_world_space = modality[0].world_space
+                 decoder_world_space = modality[1].world_space
+                 assert (
+                     encoder_world_space == decoder_world_space
+                 ), f"Encoder and decoder world spaces for modality {modality_name} do not match: {encoder_world_space} != {decoder_world_space}"
+                 encoder_spaces[modality_name] = encoder_encoder_space
+                 world_spaces[modality_name] = encoder_world_space
+             else:
+                 raise TypeError(f"Invalid type for modality {modality_name}: {type(modality)}. Expected EncoderDecoderBase or th.nn.ModuleList")
+         return TensorDictSpace(world_spaces), TensorDictSpace(encoder_spaces)
+
+     def _encode(self, input_td: TensorDict, encoders: th.nn.ModuleDict, space: TensorDictSpace) -> TensorDict:
+         """
+         Encode input_td into the encoder space using the given encoders.
+
+         Args:
+             input_td: A tensordict mapping modality names to their inputs.
+             encoders: A dictionary mapping modality names to their encoders.
+             space: The expected space, used to determine the batch dimensions of the returned tensordict.
+
+         Returns:
+             An encoded tensordict.
+         """
+         encoded_context = {}
+         preceding_dims = space.get_preceding_dimensions(input_td, allow_key_subset=True)
+         for modality_name in input_td.keys():
+             encoder = encoders[modality_name]
+             if isinstance(encoder, EncoderDecoderBase):
+                 encoded_context[modality_name] = encoder.encode(input_td[modality_name])
+             elif isinstance(encoder, th.nn.ModuleList):
+                 encoded_context[modality_name] = encoder[0].encode(input_td[modality_name])
+             else:
+                 raise TypeError(f"Invalid type for modality {modality_name}: {type(encoder)}. Expected EncoderDecoderBase or th.nn.ModuleList")
+         return TensorDict(encoded_context, batch_size=preceding_dims)
+
+     def _decode(self, input_td: TensorDict, encoders: th.nn.ModuleDict, space: TensorDictSpace) -> TensorDict:
+         """
+         Decode input_td back into the original space using the given encoders.
+
+         Args:
+             input_td: A tensordict mapping modality names to their encoded inputs.
+             encoders: A dictionary mapping modality names to their encoders.
+             space: The expected space, used to determine the batch dimensions of the returned tensordict.
+
+         Returns:
+             A decoded tensordict.
+         """
+         decoded_context = {}
+         preceding_dims = space.get_preceding_dimensions(input_td, allow_key_subset=True)
+         for modality_name in input_td.keys():
+             encoder = encoders[modality_name]
+             if isinstance(encoder, EncoderDecoderBase):
+                 decoded_context[modality_name] = encoder.decode(input_td[modality_name])
+             elif isinstance(encoder, th.nn.ModuleList):
+                 decoded_context[modality_name] = encoder[1].decode(input_td[modality_name])
+             else:
+                 raise TypeError(f"Invalid type for modality {modality_name}: {type(encoder)}. Expected EncoderDecoderBase or th.nn.ModuleList")
+         return TensorDict(decoded_context, batch_size=preceding_dims)
+
+     def encode_context(self, context: TensorDict) -> TensorDict:
+         """
+         Encode the given context into the encoder space.
+
+         Args:
+             context: A tensordict mapping modality names to their inputs.
+
+         Returns:
+             An encoded tensordict.
+         """
+         assert self.context_world_space.contains(context, allow_key_subset=True), f"Context {context} is not contained in context world space {self.context_world_space}"
+         return self._encode(context, self.context_encoders, self.context_world_space)
+
+     def decode_context(self, encoded_context: TensorDict) -> TensorDict:
+         """
+         Decode the given encoded context into the original space.
+
+         Args:
+             encoded_context: A tensordict mapping modality names to their encoded inputs.
+
+         Returns:
+             A decoded tensordict.
+         """
+         assert self.context_encoder_space.contains(
+             encoded_context,
+             allow_key_subset=True,
+         ), f"Encoded context {encoded_context} is not contained in context encoder space {self.context_encoder_space}"
+         return self._decode(encoded_context, self.context_encoders, self.context_encoder_space)
+
+     def encode_condition(self, condition: TensorDict) -> TensorDict:
+         """
+         Encode the given condition into the encoder space.
+
+         Args:
+             condition: A tensordict mapping modality names to their inputs.
+
+         Returns:
+             An encoded tensordict.
+         """
+         assert self.condition_world_space.contains(
+             condition, allow_key_subset=True
+         ), f"Condition {condition} is not contained in condition world space {self.condition_world_space}"
+         return self._encode(condition, self.condition_encoders, self.condition_world_space)
+
+     def decode_condition(self, encoded_condition: TensorDict) -> TensorDict:
+         """
+         Decode the given encoded condition into the original space.
+
+         Args:
+             encoded_condition: A tensordict mapping modality names to their encoded inputs.
+
+         Returns:
+             A decoded tensordict.
+         """
+         assert self.condition_encoder_space.contains(
+             encoded_condition, allow_key_subset=True
+         ), f"Encoded condition {encoded_condition} is not contained in condition encoder space {self.condition_encoder_space}"
+         return self._decode(encoded_condition, self.condition_encoders, self.condition_encoder_space)
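To close the file, a hedged sketch of how the encode/decode helpers round-trip data once a module has been built. `model` stands for an already-constructed EncodePredictDecodeModule (or subclass); the modality names, tensor shapes, and (batch, time) leading dimensions are assumptions for illustration, and the prediction step is model-specific, so it is only indicated in a comment.

```python
import torch as th
from tensordict import TensorDict

# `model` is assumed to be an EncodePredictDecodeModule built with image/action
# encoders matching the hypothetical config sketched earlier in this section.
context = TensorDict(
    {
        "images": th.rand(2, 10, 3, 128, 128),  # (batch, time, C, H, W) -- assumed shape
        "actions": th.rand(2, 10, 16),          # (batch, time, action_dim) -- assumed shape
    },
    batch_size=(2, 10),
)

encoded = model.encode_context(context)   # world space -> encoder space
# ... run model.predictor on the encoded context (model-specific, not part of this base class) ...
decoded = model.decode_context(encoded)   # encoder space -> world space
```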