Commit
·
f1c8ee5
1
Parent(s):
e199e7c
Initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- CODE_OF_CONDUCT.md +10 -0
- CONTRIBUTING.md +14 -0
- LICENSE.md +96 -0
- SECURITY.md +37 -0
- assets/Demonstrator/Fig_01.png +0 -0
- assets/Demonstrator/Fig_02.png +0 -0
- assets/Demonstrator/Fig_03.png +0 -0
- assets/Demonstrator/Fig_04.png +0 -0
- assets/Demonstrator/Fig_05.png +0 -0
- assets/Demonstrator/Fig_06.png +0 -0
- assets/Demonstrator/Fig_07.png +0 -0
- assets/Demonstrator/Fig_08.png +0 -0
- assets/Demonstrator/Fig_09.png +0 -0
- assets/Demonstrator/Fig_10.png +0 -0
- assets/Demonstrator/Fig_11.png +0 -0
- assets/Demonstrator/Fig_12.png +0 -0
- assets/Demonstrator/Fig_13.png +0 -0
- assets/Demonstrator/Fig_14.png +0 -0
- assets/Demonstrator/Fig_15.png +0 -0
- assets/Readme/model_capabilities.gif +3 -0
- assets/Readme/wham_gen_1.gif +3 -0
- assets/Readme/wham_gen_2.gif +3 -0
- assets/Readme/wham_gen_3.gif +3 -0
- assets/Readme/wham_gen_4.gif +3 -0
- assets/Readme/wham_gen_5.gif +3 -0
- assets/Readme/wham_gen_6.gif +3 -0
- assets/Readme/wham_gen_7.gif +3 -0
- assets/Readme/wham_gen_8.gif +3 -0
- assets/Readme/wham_gen_9.gif +3 -0
- configs/metadata_custom_tag.config +5 -0
- models/WHAM_1.6B_v1.ckpt +3 -0
- models/WHAM_200M.ckpt +3 -0
- requirements.txt +48 -0
- run_dreaming.py +264 -0
- run_server.py +519 -0
- setup_local.sh +21 -0
- wham/models/nn/model_blocks.py +49 -0
- wham/models/nn/nanoGPT.py +665 -0
- wham/models/pl/__init__.py +0 -0
- wham/models/pl/pl_base_model.py +5 -0
- wham/models/vqgan/taming/LICENSE +24 -0
- wham/models/vqgan/taming/model.py +696 -0
- wham/models/vqgan/taming/quantize.py +146 -0
- wham/models/vqgan/taming_vq_model.py +264 -0
- wham/models/vqgan/vqgan.py +236 -0
- wham/models/vqgan/vqgan_models.py +311 -0
- wham/models/vqvae/vqvae_utils.py +154 -0
- wham/models/wham_base/__init__.py +0 -0
- wham/models/wham_base/encode_predict_decode_base.py +256 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
fonts/arial.ttf filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Microsoft Open Source Code of Conduct
|
2 |
+
|
3 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
4 |
+
|
5 |
+
Resources:
|
6 |
+
|
7 |
+
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
8 |
+
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
9 |
+
- Contact [[email protected]](mailto:[email protected]) with questions or concerns
|
10 |
+
- Employees can reach out at [aka.ms/opensource/moderation-support](https://aka.ms/opensource/moderation-support)
|
CONTRIBUTING.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributing
|
2 |
+
|
3 |
+
This project welcomes contributions and suggestions. Most contributions require you to
|
4 |
+
agree to a Contributor License Agreement (CLA) declaring that you have the right to,
|
5 |
+
and actually do, grant us the rights to use your contribution. For details, visit
|
6 |
+
https://cla.microsoft.com.
|
7 |
+
|
8 |
+
When you submit a pull request, a CLA-bot will automatically determine whether you need
|
9 |
+
to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
|
10 |
+
instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
|
11 |
+
|
12 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
13 |
+
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
14 |
+
or contact [[email protected]](mailto:[email protected]) with any additional questions or comments.
|
LICENSE.md
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MICROSOFT RESEARCH LICENSE TERMS
|
2 |
+
|
3 |
+
**IF YOU LIVE IN THE UNITED STATES, PLEASE READ THE “BINDING ARBITRATION AND CLASS ACTION WAIVER” SECTION BELOW. IT AFFECTS HOW DISPUTES ARE RESOLVED.**
|
4 |
+
|
5 |
+
These license terms are an agreement between you and Microsoft Corporation (or one of its affiliates). They apply to the source code, object code, machine learning models, or data (collectively “Materials”) that accompany this license. IF YOU COMPLY WITH THESE LICENSE TERMS, YOU HAVE THE RIGHTS BELOW. BY USING THE MATERIALS, YOU ACCEPT THESE TERMS.
|
6 |
+
|
7 |
+
## 1) INSTALLATION AND USE RIGHTS TO THE MATERIALS.
|
8 |
+
|
9 |
+
Subject to the terms of this agreement, you have the below rights, if applicable, to use the Materials solely for non-commercial, non-revenue generating, research purposes:
|
10 |
+
|
11 |
+
a) **Source Code.** If source code is included, you may use and modify the source code, but you may not distribute the source code.
|
12 |
+
|
13 |
+
b) **Object Code.** If object code is included, you may use the object code, but you may not distribute the object code.
|
14 |
+
|
15 |
+
c) **Models.** If machine learning model(s) are included, you may use the model(s), but you may not distribute the models.
|
16 |
+
|
17 |
+
d) **Data.** If data is included, you may use the data, but your use must be consistent with the consent under which the data was provided and/or gathered and you may not modify or distribute the data.
|
18 |
+
|
19 |
+
## 2) SCOPE OF LICENSE.
|
20 |
+
|
21 |
+
The Materials are licensed, not sold. Microsoft reserves all other rights. Unless applicable law gives you more rights despite this limitation, you will not (and have no right to):
|
22 |
+
|
23 |
+
a) Work around any technical limitations in the Materials that only allow you to use it in certain ways;
|
24 |
+
|
25 |
+
b) Reverse engineer, decompile or disassemble the Materials;
|
26 |
+
|
27 |
+
c) Remove, minimize, block, or modify any notices of Microsoft or its suppliers in the Materials;
|
28 |
+
|
29 |
+
d) Use the Materials in any way that is against the law or to create or propagate malware; or
|
30 |
+
|
31 |
+
e) Share, publish, distribute or lend the Materials, provide the Materials as a stand-alone hosted solution for others to use, or transfer the Materials or this agreement to any third party.
|
32 |
+
|
33 |
+
## 3) PERSONAL DATA.
|
34 |
+
|
35 |
+
If the data (set forth in Section 1(d) above) includes or is found to include any data that enables any ability to identify an individual ("Personal Data"), you will not use such Personal Data for any purpose other than was authorized and consented to by the data subject/research participant. You will not use Personal Data to contact any person. You will keep Personal Data in strict confidence. You will not share any Personal Data that is collected or in your possession with any third party for any reason and as required under the original consent agreement. Further, you will destroy the Personal Data and any backup or copies, **immediately upon the completion of your research.**
|
36 |
+
|
37 |
+
## 4) LICENSE TO MICROSOFT.
|
38 |
+
|
39 |
+
Notwithstanding the limitations in Section 1, you may distribute your modifications back to Microsoft, and if you do provide Microsoft with modifications of the Materials, you hereby grant Microsoft, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, post, distribute, make and have made, sell and transfer such modifications and derivatives for any purpose.
|
40 |
+
|
41 |
+
## 5) PUBLICATION.
|
42 |
+
|
43 |
+
You may publish (or present papers or articles) on your results from using the Materials provided that no material or substantial portion of the Materials is included in any such publication or presentation.
|
44 |
+
|
45 |
+
## 6) FEEDBACK.
|
46 |
+
|
47 |
+
Any feedback about the Materials provided by you to us is voluntarily given, and Microsoft shall be free to use the feedback as it sees fit without obligation or restriction of any kind, even if the feedback is designated by you as confidential. **Additional** Such feedback shall be considered a contribution and licensed to Microsoft under the terms of Section 4 above.
|
48 |
+
|
49 |
+
## 7) COMPLIANCE WITH TRADE LAWS.
|
50 |
+
|
51 |
+
You acknowledge that the Materials may be subject to applicable trade laws in one or more countries. You will comply with all relevant laws and regulations applicable to the import or export of the Materials, including but not limited to, trade laws such as the U.S. Export Administration Regulations or other end-user, end use, and destination restrictions by the U.S. and other governments, as well as sanctions regulations administered by the U.S. Office of Foreign Assets Control. Microsoft may suspend or terminate the agreement immediately to the extent that Microsoft reasonably concludes that continued performance would violate trade laws or put it at risk of becoming subject to sanctions or penalties under trade laws. For additional information, see www.microsoft.com/exporting.
|
52 |
+
|
53 |
+
## 8) SUPPORT SERVICES.
|
54 |
+
|
55 |
+
Microsoft is not obligated under this agreement to provide any support services for the Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
|
56 |
+
|
57 |
+
## 9) BINDING ARBITRATION AND CLASS ACTION WAIVER.
|
58 |
+
|
59 |
+
**This Section applies if you live in (or, if a business, your principal place of business is in) the United States.** If you and Microsoft have a dispute, you and Microsoft agree to try for 60 days to resolve it informally. If you and Microsoft can’t, you and Microsoft agree to **binding individual arbitration before the American Arbitration Association** under the Federal Arbitration Act ("FAA"), and not to **sue in court in front of a judge or jury.** Instead, a neutral arbitrator will decide. **Class action lawsuits, class-wide arbitrations, private attorney-general actions,** and any other proceeding where someone acts in a representative capacity **are not allowed;** nor is combining individual proceedings without the consent of all parties. The complete Arbitration Agreement contains more terms and is at aka.ms/arb-agreement-1. You and Microsoft agree to these terms.
|
60 |
+
|
61 |
+
## 10) ENTIRE AGREEMENT.
|
62 |
+
|
63 |
+
This agreement, and any other terms Microsoft may provide for supplements, updates, or third-party applications, is the entire agreement for the Materials.
|
64 |
+
|
65 |
+
## 11) APPLICABLE LAW AND PLACE TO RESOLVE DISPUTES.
|
66 |
+
|
67 |
+
If you acquired the Materials in the United States or Canada, the laws of the state or province where you live (or, if a business, where your principal place of business is located) govern the interpretation of this agreement, claims for its breach, and all other claims (including consumer protection, unfair competition, and tort claims), regardless of conflict of laws principles, except that the FAA governs everything related to arbitration. If you acquired the Materials in any other country, its laws apply, except that the FAA governs everything related to arbitration. If U.S. federal jurisdiction exists, you and Microsoft consent to exclusive jurisdiction and venue in the federal court in King County, Washington for all disputes heard in court (excluding arbitration). If not, you and Microsoft consent to exclusive jurisdiction and venue in the Superior Court of King County, Washington for all disputes heard in court (excluding arbitration).
|
68 |
+
|
69 |
+
## 12) CONSUMER RIGHTS; REGIONAL VARIATIONS.
|
70 |
+
|
71 |
+
This agreement describes certain legal rights. You may have other rights, including consumer rights, under the laws of your state, province, or country. Separate and apart from your relationship with Microsoft, you may also have rights with respect to the party from which you acquired the Materials. This agreement does not change those other rights if the laws of your state, province, or country do not permit it to do so. For example, if you acquired the Materials in one of the below regions, or mandatory country law applies, then the following provisions apply to you:
|
72 |
+
|
73 |
+
a) **Australia.** You have statutory guarantees under the Australian Consumer Law and nothing in this agreement is intended to affect those rights.
|
74 |
+
|
75 |
+
b) **Canada.** If you acquired this software in Canada, you may stop receiving updates by turning off the automatic update feature, disconnecting your device from the Internet (if and when you re-connect to the Internet, however, the Materials will resume checking for and installing updates), or uninstalling the Materials. The product documentation, if any, may also specify how to turn off updates for your specific device or software.
|
76 |
+
|
77 |
+
c) **Germany and Austria.**
|
78 |
+
i. **Warranty.** The properly licensed software will perform substantially as described in any Microsoft materials that accompany the Materials. However, Microsoft gives no contractual guarantee in relation to the licensed software.
|
79 |
+
ii. **Limitation of Liability.** In case of intentional conduct, gross negligence, claims based on the Product Liability Act, as well as, in case of death or personal or physical injury, Microsoft is liable according to the statutory law.
|
80 |
+
|
81 |
+
Subject to the foregoing clause (ii), Microsoft will only be liable for slight negligence if Microsoft is in breach of such material contractual obligations, the fulfillment of which facilitate the due performance of this agreement, the breach of which would endanger the purpose of this agreement and the compliance with which a party may constantly trust in (so-called "cardinal obligations"). In other cases of slight negligence, Microsoft will not be liable for slight negligence.
|
82 |
+
|
83 |
+
## 13) DISCLAIMER OF WARRANTY.
|
84 |
+
|
85 |
+
THE MATERIALS ARE LICENSED "AS IS." YOU BEAR THE RISK OF USING THEM. MICROSOFT GIVES NO EXPRESS WARRANTIES, GUARANTEES, OR CONDITIONS. TO THE EXTENT PERMITTED UNDER APPLICABLE LAWS, MICROSOFT EXCLUDES ALL IMPLIED WARRANTIES, INCLUDING MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT.
|
86 |
+
|
87 |
+
## 14) LIMITATION ON AND EXCLUSION OF DAMAGES.
|
88 |
+
|
89 |
+
IF YOU HAVE ANY BASIS FOR RECOVERING DAMAGES DESPITE THE PRECEDING DISCLAIMER OF WARRANTY, YOU CAN RECOVER FROM MICROSOFT AND ITS SUPPLIERS ONLY DIRECT DAMAGES UP TO U.S. $5.00. YOU CANNOT RECOVER ANY OTHER DAMAGES, INCLUDING CONSEQUENTIAL, LOST PROFITS, SPECIAL, INDIRECT OR INCIDENTAL DAMAGES.
|
90 |
+
|
91 |
+
This limitation applies to:
|
92 |
+
- (a) anything related to the Materials, services, content (including code) on third party Internet sites, or third party applications; and
|
93 |
+
- (b) claims for breach of contract, warranty, guarantee, or condition; strict liability, negligence, or other tort; or any other claim; in each case to the extent permitted by applicable law.
|
94 |
+
|
95 |
+
It also applies even if Microsoft knew or should have known about the possibility of the damages. The above limitation or exclusion may not apply to you because your state, province, or country may not allow the exclusion or limitation of incidental, consequential, or other damages.
|
96 |
+
|
SECURITY.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Security
|
2 |
+
|
3 |
+
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
|
4 |
+
|
5 |
+
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
|
6 |
+
|
7 |
+
## Reporting Security Issues
|
8 |
+
|
9 |
+
**Please do not report security vulnerabilities through public GitHub issues.**
|
10 |
+
|
11 |
+
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
|
12 |
+
|
13 |
+
If you prefer to submit without logging in, send email to [[email protected]](mailto:[email protected]). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
|
14 |
+
|
15 |
+
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
16 |
+
|
17 |
+
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
18 |
+
|
19 |
+
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
20 |
+
* Full paths of source file(s) related to the manifestation of the issue
|
21 |
+
* The location of the affected source code (tag/branch/commit or direct URL)
|
22 |
+
* Any special configuration required to reproduce the issue
|
23 |
+
* Step-by-step instructions to reproduce the issue
|
24 |
+
* Proof-of-concept or exploit code (if possible)
|
25 |
+
* Impact of the issue, including how an attacker might exploit the issue
|
26 |
+
|
27 |
+
This information will help us triage your report more quickly.
|
28 |
+
|
29 |
+
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
|
30 |
+
|
31 |
+
## Preferred Languages
|
32 |
+
|
33 |
+
We prefer all communications to be in English.
|
34 |
+
|
35 |
+
## Policy
|
36 |
+
|
37 |
+
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
|
assets/Demonstrator/Fig_01.png
ADDED
![]() |
assets/Demonstrator/Fig_02.png
ADDED
![]() |
assets/Demonstrator/Fig_03.png
ADDED
![]() |
assets/Demonstrator/Fig_04.png
ADDED
![]() |
assets/Demonstrator/Fig_05.png
ADDED
![]() |
assets/Demonstrator/Fig_06.png
ADDED
![]() |
assets/Demonstrator/Fig_07.png
ADDED
![]() |
assets/Demonstrator/Fig_08.png
ADDED
![]() |
assets/Demonstrator/Fig_09.png
ADDED
![]() |
assets/Demonstrator/Fig_10.png
ADDED
![]() |
assets/Demonstrator/Fig_11.png
ADDED
![]() |
assets/Demonstrator/Fig_12.png
ADDED
![]() |
assets/Demonstrator/Fig_13.png
ADDED
![]() |
assets/Demonstrator/Fig_14.png
ADDED
![]() |
assets/Demonstrator/Fig_15.png
ADDED
![]() |
assets/Readme/model_capabilities.gif
ADDED
![]() |
Git LFS Details
|
assets/Readme/wham_gen_1.gif
ADDED
![]() |
Git LFS Details
|
assets/Readme/wham_gen_2.gif
ADDED
![]() |
Git LFS Details
|
assets/Readme/wham_gen_3.gif
ADDED
![]() |
Git LFS Details
|
assets/Readme/wham_gen_4.gif
ADDED
![]() |
Git LFS Details
|
assets/Readme/wham_gen_5.gif
ADDED
![]() |
Git LFS Details
|
assets/Readme/wham_gen_6.gif
ADDED
![]() |
Git LFS Details
|
assets/Readme/wham_gen_7.gif
ADDED
![]() |
Git LFS Details
|
assets/Readme/wham_gen_8.gif
ADDED
![]() |
Git LFS Details
|
assets/Readme/wham_gen_9.gif
ADDED
![]() |
Git LFS Details
|
configs/metadata_custom_tag.config
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
%Image::ExifTool::UserDefined = (
|
2 |
+
'Image::ExifTool::XMP::xmp' => {
|
3 |
+
'ProgramName' => { Name => 'ProgramName', Writable => 'string' }
|
4 |
+
}
|
5 |
+
);
|
models/WHAM_1.6B_v1.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9c4997074883aa1a39a5994a7dea91fb62b2382fc039523458827adb777af8e9
|
3 |
+
size 20339650059
|
models/WHAM_200M.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5ddb8e03a33f0849a63da030fea3de4994d95e16888993b8ab92faa904f3b31f
|
3 |
+
size 3980245067
|
requirements.txt
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--find-links https://download.pytorch.org/whl/torch_stable.html
|
2 |
+
aiohttp==3.9.3
|
3 |
+
aiosignal==1.3.1
|
4 |
+
async-timeout==4.0.3
|
5 |
+
attrs==23.2.0
|
6 |
+
blinker==1.7.0
|
7 |
+
certifi==2024.2.2
|
8 |
+
charset-normalizer==3.3.2
|
9 |
+
click==8.1.7
|
10 |
+
cloudpickle==3.0.0
|
11 |
+
cmake==3.28.3
|
12 |
+
einops==0.6.0
|
13 |
+
ffmpegcv==0.3.10
|
14 |
+
filelock==3.13.1
|
15 |
+
Flask==3.0.2
|
16 |
+
frozenlist==1.4.1
|
17 |
+
fsspec==2024.2.0
|
18 |
+
idna==3.6
|
19 |
+
importlib_metadata==7.0.2
|
20 |
+
itsdangerous==2.1.2
|
21 |
+
Jinja2==3.1.3
|
22 |
+
lightning-utilities==0.10.1
|
23 |
+
lit==17.0.6
|
24 |
+
MarkupSafe==2.1.5
|
25 |
+
mpmath==1.3.0
|
26 |
+
multidict==6.0.5
|
27 |
+
networkx==3.2.1
|
28 |
+
numpy==1.25.2
|
29 |
+
opencv-python==4.6.0.66
|
30 |
+
opencv-python-headless==4.9.0.80
|
31 |
+
packaging==23.2
|
32 |
+
pillow==10.2.0
|
33 |
+
pytorch-lightning==1.9.4
|
34 |
+
PyYAML==6.0.1
|
35 |
+
requests==2.31.0
|
36 |
+
sympy==1.12
|
37 |
+
tensordict==0.1.2
|
38 |
+
torch==2.0.1+cu118
|
39 |
+
torchinfo==1.7.1
|
40 |
+
torchmetrics==0.11.4
|
41 |
+
torchvision==0.15.2+cu118
|
42 |
+
tqdm==4.66.2
|
43 |
+
triton==2.0.0
|
44 |
+
typing_extensions==4.10.0
|
45 |
+
urllib3==2.2.1
|
46 |
+
Werkzeug==3.0.1
|
47 |
+
yarl==1.9.4
|
48 |
+
zipp==3.17.0
|
run_dreaming.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Example script for running dreaming on a dataset.
|
3 |
+
The idea is that there are ground_truth ("reference") video clips, and we dream the same clips given some initial context.
|
4 |
+
|
5 |
+
After dreaming, we have two sets of videos which, barring the intrinsic noise of the game environment (e.g., randomness of other players),
|
6 |
+
should be identical if model was ideal.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import argparse
|
10 |
+
from pathlib import Path
|
11 |
+
import os
|
12 |
+
import subprocess
|
13 |
+
|
14 |
+
import cv2
|
15 |
+
from tensordict import TensorDict
|
16 |
+
import torch as th
|
17 |
+
from tqdm import tqdm
|
18 |
+
import numpy as np
|
19 |
+
import ffmpegcv
|
20 |
+
from PIL import Image
|
21 |
+
|
22 |
+
import wham.utils as utils
|
23 |
+
|
24 |
+
|
25 |
+
parser = argparse.ArgumentParser(description="Run dreaming.")
|
26 |
+
parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint.")
|
27 |
+
parser.add_argument("--data_path", type=str, required=True, help="Path to the directory that contains the ground truth data to dream for.")
|
28 |
+
parser.add_argument("--output", type=str, default="dreaming_output", help="Path to the directory where output should be put.")
|
29 |
+
parser.add_argument("--max_files", type=int, default=None, help="Maximum number of files to process.")
|
30 |
+
parser.add_argument("--metadata_config", type=str, default="configs/metadata_custom_tag.config", help="Path to metadata tag config for origin field.")
|
31 |
+
|
32 |
+
|
33 |
+
parser.add_argument(
|
34 |
+
"--protocol",
|
35 |
+
type=str,
|
36 |
+
default="base",
|
37 |
+
choices=["base", "comprehensive"],
|
38 |
+
help="What protocol to use for the dreaming. base = action conditioned, comprehensive = dream actions as well.",
|
39 |
+
)
|
40 |
+
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for dreaming. Higher batch_size uses more VRAM but overall is faster.")
|
41 |
+
parser.add_argument("--context_length", type=int, default=10, help="Number of frames to use an initial context.")
|
42 |
+
parser.add_argument("--steps_to_dream", type=int, default=10, help="Batch size for dreaming.")
|
43 |
+
|
44 |
+
parser.add_argument("--sampling_temperature", type=float, default=0.9, help="Temperature for sampling from the model.")
|
45 |
+
parser.add_argument("--sampling_top_k", type=int, default=None, help="Top-k for sampling from the model.")
|
46 |
+
parser.add_argument("--sampling_top_p", type=float, default=None, help="Top-p for sampling from the model.")
|
47 |
+
|
48 |
+
|
49 |
+
def get_context_data(image_context, action_context, action_sequences):
|
50 |
+
# Make sure we have CHW images:
|
51 |
+
assert image_context.shape[-3] == 3, "Image context should be CHW"
|
52 |
+
|
53 |
+
image_context = th.from_numpy(image_context).cuda()
|
54 |
+
action_data = th.from_numpy(action_context).float().cuda()
|
55 |
+
action_sequences = th.from_numpy(action_sequences).float().cuda() if action_sequences is not None else None
|
56 |
+
|
57 |
+
return TensorDict({"images": image_context, "actions_output": action_data}, batch_size=image_context.shape[:2])
|
58 |
+
|
59 |
+
|
60 |
+
def add_video_metadata(file_path, metadata_config):
|
61 |
+
# Construct the exiftool command
|
62 |
+
cmd = [
|
63 |
+
'exiftool',
|
64 |
+
'-config', metadata_config,
|
65 |
+
f'-ProgramName=\"{utils.PROGRAM_NAME}\"',
|
66 |
+
'-overwrite_original',
|
67 |
+
file_path
|
68 |
+
]
|
69 |
+
|
70 |
+
try:
|
71 |
+
# Execute the exiftool command
|
72 |
+
subprocess.run(cmd, check=True)
|
73 |
+
print(f"Metadata modified successfully.")
|
74 |
+
# Print the new file metadata
|
75 |
+
cmd_output = [
|
76 |
+
'exiftool',
|
77 |
+
file_path
|
78 |
+
]
|
79 |
+
subprocess.run(cmd_output, check=True)
|
80 |
+
except subprocess.CalledProcessError as e:
|
81 |
+
print(f"Error modifying metadata: {e}")
|
82 |
+
|
83 |
+
|
84 |
+
@th.no_grad()
|
85 |
+
def do_dreaming(model, image_context, action_context, args, action_sequences=None):
|
86 |
+
"""
|
87 |
+
image_contect and action_context provide the initial context for the model to dream from.
|
88 |
+
|
89 |
+
If action_sequences (batch_size, args.steps_to_dream, action_dim) is provided, then model will be prompted with these actions.
|
90 |
+
"""
|
91 |
+
context_data = get_context_data(image_context, action_context, action_sequences)
|
92 |
+
encoded_context_data = model.encode_context(context_data)
|
93 |
+
|
94 |
+
encoded_action_sequences = None
|
95 |
+
if action_sequences is not None:
|
96 |
+
assert action_sequences.shape[1] == args.steps_to_dream, "action_sequences should have shape (batch_size, args.steps_to_dream, action_dim)"
|
97 |
+
action_sequences = TensorDict({"actions_output": action_sequences}, batch_size=action_sequences.shape[:2]).cuda()
|
98 |
+
encoded_action_sequences = model.encode_context(action_sequences)
|
99 |
+
|
100 |
+
encoded_dreamt_steps = []
|
101 |
+
|
102 |
+
for dream_step in range(args.steps_to_dream):
|
103 |
+
encoded_predicted_step, _ = model.predictor.predict_next_step(
|
104 |
+
encoded_context_data, temperature=args.sampling_temperature, top_k=args.sampling_top_k, top_p=args.sampling_top_p, min_tokens_to_keep=1
|
105 |
+
)
|
106 |
+
|
107 |
+
# Remove first step from context if we are at the max context length:
|
108 |
+
if encoded_context_data.shape[1] == args.context_length:
|
109 |
+
encoded_context_data = encoded_context_data[:, 1:]
|
110 |
+
|
111 |
+
# Add predicted image + action to the context
|
112 |
+
append_step = encoded_predicted_step
|
113 |
+
if encoded_action_sequences is not None:
|
114 |
+
# Replace predicted action with real action
|
115 |
+
append_step["actions_output"] = encoded_action_sequences["actions_output"][:, [dream_step], :]
|
116 |
+
encoded_context_data = th.cat((encoded_context_data, append_step), dim=1)
|
117 |
+
|
118 |
+
encoded_dreamt_steps.append(encoded_predicted_step)
|
119 |
+
|
120 |
+
# Decode everything
|
121 |
+
dreamed_images = []
|
122 |
+
actions_during_dream = []
|
123 |
+
for seq_i in range(args.steps_to_dream):
|
124 |
+
decoded_step = model.decode_context(encoded_dreamt_steps[seq_i])
|
125 |
+
dreamed_images.append(decoded_step["images"][:, [0]].cpu().numpy())
|
126 |
+
actions_during_dream.append(decoded_step["actions_output"][:, [0]].cpu().numpy())
|
127 |
+
|
128 |
+
dreamed_images = np.concatenate(dreamed_images, axis=1)
|
129 |
+
actions_during_dream = np.concatenate(actions_during_dream, axis=1)
|
130 |
+
|
131 |
+
return dreamed_images, actions_during_dream
|
132 |
+
|
133 |
+
|
134 |
+
@th.no_grad()
|
135 |
+
def encode_decode_images(model, images):
|
136 |
+
"""
|
137 |
+
Pass ground_truth images through the encoding/decoding process of the model.
|
138 |
+
"""
|
139 |
+
context = TensorDict({"images": th.from_numpy(images).cuda()}, batch_size=images.shape[:2])
|
140 |
+
output_images = []
|
141 |
+
for seq_i in range(images.shape[1]):
|
142 |
+
encoded_images = model.encode_context(context[:, [seq_i]])
|
143 |
+
decoded_images = model.decode_context(encoded_images)
|
144 |
+
output_images.append(decoded_images["images"].cpu().numpy())
|
145 |
+
return np.concatenate(output_images, axis=1)
|
146 |
+
|
147 |
+
|
148 |
+
def main(args):
|
149 |
+
total_video_length = args.context_length + args.steps_to_dream
|
150 |
+
|
151 |
+
# Now, load the model:
|
152 |
+
model_path = Path(args.model_path)
|
153 |
+
assert model_path.is_file(), "Could not find the model!"
|
154 |
+
model = utils.load_model_from_checkpoint(model_path).cuda()
|
155 |
+
|
156 |
+
# Glob the dataset to find all the ground truth segments we want to construct a dream for:
|
157 |
+
data_path = Path(args.data_path)
|
158 |
+
ground_truth_files = list(data_path.rglob("*.npz"))
|
159 |
+
num_dreams = len(ground_truth_files)
|
160 |
+
|
161 |
+
if args.max_files is not None:
|
162 |
+
# Sort to make sure we always get the same files
|
163 |
+
ground_truth_files = sorted(ground_truth_files)
|
164 |
+
ground_truth_files = ground_truth_files[: args.max_files]
|
165 |
+
num_dreams = len(ground_truth_files)
|
166 |
+
|
167 |
+
output_path = Path(args.output)
|
168 |
+
os.makedirs(output_path, exist_ok=True)
|
169 |
+
|
170 |
+
print("=" * 100)
|
171 |
+
print(f"GENERATING DREAMS OF {num_dreams} SEGMENTS")
|
172 |
+
print(f"WRITING TO {args.output}")
|
173 |
+
print("=" * 100)
|
174 |
+
|
175 |
+
dreams_created = 0
|
176 |
+
with tqdm(total=num_dreams, desc="Dreams") as pbar:
|
177 |
+
while ground_truth_files:
|
178 |
+
# Load batch_size headers:
|
179 |
+
batches = min(args.batch_size, len(ground_truth_files))
|
180 |
+
batched_image_context = []
|
181 |
+
batched_image_sequence = []
|
182 |
+
batched_action_context = []
|
183 |
+
batched_action_sequence = []
|
184 |
+
episode_names = []
|
185 |
+
for i in range(batches):
|
186 |
+
episode = ground_truth_files.pop()
|
187 |
+
episode_names.append(episode)
|
188 |
+
try:
|
189 |
+
data = np.load(episode)
|
190 |
+
images = data["images"]
|
191 |
+
actions = data["actions"]
|
192 |
+
except Exception:
|
193 |
+
print(f"Failed to load episode {episode} - skipping.")
|
194 |
+
continue
|
195 |
+
|
196 |
+
if actions.shape[0] < total_video_length:
|
197 |
+
# We want to make sure we have ground_truth comparisons for the entire dream, so we ensure the episode is long enough
|
198 |
+
raise ValueError(f"Episode {episode} is too short to dream from. It has {actions.shape[0]} steps, but we need at least {total_video_length}.")
|
199 |
+
batched_image_context.append(images[: args.context_length])
|
200 |
+
batched_image_sequence.append(images[args.context_length: total_video_length])
|
201 |
+
batched_action_context.append(actions[: args.context_length])
|
202 |
+
batched_action_sequence.append(actions[args.context_length: total_video_length])
|
203 |
+
|
204 |
+
image_context = np.array(batched_image_context)
|
205 |
+
image_sequences = np.array(batched_image_sequence)
|
206 |
+
action_context = np.array(batched_action_context)
|
207 |
+
action_sequences = np.array(batched_action_sequence)
|
208 |
+
|
209 |
+
if args.protocol == "comprehensive":
|
210 |
+
# We do not need to pass in the action sequences for comprehensive protocol
|
211 |
+
action_sequences = None
|
212 |
+
|
213 |
+
full_image_sequence = np.concatenate((image_context, image_sequences), axis=1)
|
214 |
+
|
215 |
+
dreamt_images, actions_during_dream = do_dreaming(model, image_context, action_context, args, action_sequences=action_sequences)
|
216 |
+
encoded_decoded_images_batch = encode_decode_images(model, full_image_sequence)
|
217 |
+
|
218 |
+
pbar.update(batches)
|
219 |
+
dreams_created += batches
|
220 |
+
|
221 |
+
# Save the dreams:
|
222 |
+
# We are aiming to mimic the folder structure of the ground truth dataset, so use the episode names
|
223 |
+
# but make them relative to our output folder:
|
224 |
+
for i, dream in enumerate(dreamt_images):
|
225 |
+
episode = episode_names[i]
|
226 |
+
output_file = output_path / episode.relative_to(data_path)
|
227 |
+
output_file.parent.mkdir(parents=True, exist_ok=True)
|
228 |
+
np.savez(
|
229 |
+
output_file,
|
230 |
+
context_length=args.context_length,
|
231 |
+
steps_to_dream=args.steps_to_dream,
|
232 |
+
raw_context=image_context[i],
|
233 |
+
dreamt_images=dream,
|
234 |
+
all_actions=np.concatenate((action_context[i], actions_during_dream[i])),
|
235 |
+
encoded_decoded_ground_truth_images=encoded_decoded_images_batch[i],
|
236 |
+
)
|
237 |
+
|
238 |
+
video_file = str(output_file.with_suffix(".mp4"))
|
239 |
+
writer = ffmpegcv.VideoWriter(video_file, None, utils.DREAMING_FPS)
|
240 |
+
full_sequence = np.concatenate((image_context[i], dream), axis=0)
|
241 |
+
for frame in full_sequence:
|
242 |
+
img = frame.transpose(1, 2, 0).astype(np.uint8).copy()
|
243 |
+
# Please DO NOT remove this watermark. This will infringe upon the repo's license agreement
|
244 |
+
(text_width, _), _ = cv2.getTextSize(utils.WATERMARK_TEXT, utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_THICKNESS)
|
245 |
+
x = img.shape[1] - text_width - 10 # 10 pixels from the right edge
|
246 |
+
y = img.shape[0] - 10 # 10 pixels from the bottom edge
|
247 |
+
cv2.putText(img, utils.WATERMARK_TEXT, (x, y), utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_COLOR, utils.WATERMARK_FONT_THICKNESS)
|
248 |
+
|
249 |
+
# Add image metadata
|
250 |
+
pil_image = Image.fromarray(img)
|
251 |
+
pil_image.info['Id'] = 0x0131
|
252 |
+
pil_image.info['Type'] = 2
|
253 |
+
pil_image.info['Value'] = utils.PROGRAM_NAME.encode("utf-8")
|
254 |
+
pil_image.info['Len'] = len(utils.PROGRAM_NAME) + 1
|
255 |
+
|
256 |
+
# Convert pil_image to a CV2 format for the video writer
|
257 |
+
cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
258 |
+
writer.write(cv_image)
|
259 |
+
writer.release()
|
260 |
+
add_video_metadata(video_file, args.metadata_config)
|
261 |
+
|
262 |
+
if __name__ == "__main__":
|
263 |
+
args = parser.parse_args()
|
264 |
+
main(args)
|
run_server.py
ADDED
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
import json
|
4 |
+
import copy
|
5 |
+
import multiprocessing as mp
|
6 |
+
import uuid
|
7 |
+
from datetime import datetime, timedelta
|
8 |
+
from collections import defaultdict, deque
|
9 |
+
import io
|
10 |
+
import zipfile
|
11 |
+
import queue
|
12 |
+
import time
|
13 |
+
import random
|
14 |
+
import logging
|
15 |
+
|
16 |
+
from tensordict import TensorDict
|
17 |
+
import cv2
|
18 |
+
from flask import Flask, request, make_response, send_file
|
19 |
+
from PIL import Image
|
20 |
+
import torchvision.transforms as T
|
21 |
+
import numpy as np
|
22 |
+
import torch as th
|
23 |
+
|
24 |
+
from wham.utils import load_model_from_checkpoint, POS_BINS_BOUNDARIES, POS_BINS_MIDDLE
|
25 |
+
|
26 |
+
logging.basicConfig(level=logging.INFO)
|
27 |
+
|
28 |
+
parser = argparse.ArgumentParser(description="Simple Dreamer")
|
29 |
+
parser.add_argument("--model", type=str, required=True, help="Path to the model file for the local runs")
|
30 |
+
parser.add_argument("--debug", action="store_true", help="Enable flask debug mode.")
|
31 |
+
parser.add_argument("--random_model", action="store_true", help="Use randomly initialized model instead of the provided one")
|
32 |
+
parser.add_argument("--port", type=int, default=5000)
|
33 |
+
|
34 |
+
parser.add_argument("--max_concurrent_jobs", type=int, default=30, help="Maximum number of jobs that can be run concurrently on this server.")
|
35 |
+
parser.add_argument("--max_dream_steps_per_job", type=int, default=10, help="Maximum number of dream steps each job can request.")
|
36 |
+
parser.add_argument("--max_job_lifespan", type=int, default=60 * 10, help="Maximum number of seconds we keep run around if not polled.")
|
37 |
+
|
38 |
+
parser.add_argument("--image_width", type=int, default=300, help="Width of the image")
|
39 |
+
parser.add_argument("--image_height", type=int, default=180, help="Height of the image")
|
40 |
+
|
41 |
+
parser.add_argument("--max_batch_size", type=int, default=3, help="Maximum batch size for the dreamer workers")
|
42 |
+
|
43 |
+
PREDICTION_JSON_FILENAME = "predictions.json"
|
44 |
+
# Minimum time between times we check when to delete jobs. We do this when adding new jobs.
|
45 |
+
JOB_CLEANUP_CHECK_RATE = timedelta(seconds=10)
|
46 |
+
|
47 |
+
MAX_CANCELLED_ID_QUEUE_SIZE = 100
|
48 |
+
|
49 |
+
DEFAULT_SAMPLING_SETTINGS = {
|
50 |
+
"temperature": 0.9,
|
51 |
+
"top_k": None,
|
52 |
+
"top_p": 1.0,
|
53 |
+
"max_context_length": 10,
|
54 |
+
}
|
55 |
+
|
56 |
+
|
57 |
+
def float_or_none(string):
|
58 |
+
if string.lower() == "none":
|
59 |
+
return None
|
60 |
+
return float(string)
|
61 |
+
|
62 |
+
|
63 |
+
def be_image_preprocess(image, target_width, target_height):
|
64 |
+
# If target_width and target_height are specified, resize the image.
|
65 |
+
if target_width is not None and target_height is not None:
|
66 |
+
# Make sure we do not try to resize if the image is already the correct size.
|
67 |
+
if image.shape[1] != target_width or image.shape[0] != target_height:
|
68 |
+
image = cv2.resize(image, (target_width, target_height))
|
69 |
+
return np.transpose(image, (2, 0, 1))
|
70 |
+
|
71 |
+
|
72 |
+
def action_vector_to_be_action_vector(action):
|
73 |
+
# Preprocess a BE action vector from 16 numbers with:
|
74 |
+
# 12 buttons [0, 1] and 4 stick directions [-1, 1]
|
75 |
+
# to discrete actions valid for the token model
|
76 |
+
# 12 buttons [0, 1] and 4 stick directions {discrete bin}
|
77 |
+
action[-4:] = np.digitize(action[-4:], bins=POS_BINS_BOUNDARIES) - 1
|
78 |
+
return action
|
79 |
+
|
80 |
+
|
81 |
+
def be_action_vector_to_action_vector(action):
|
82 |
+
# Preprocess a BE action vector into unified space
|
83 |
+
for stick_index in range(-4, 0):
|
84 |
+
action[stick_index] = POS_BINS_MIDDLE[int(action[stick_index])]
|
85 |
+
return action
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
@dataclass
|
90 |
+
class DreamJob:
|
91 |
+
job_id: str
|
92 |
+
sampling_settings: dict
|
93 |
+
num_predictions_remaining: int
|
94 |
+
num_predictions_done: int
|
95 |
+
# (B, T, C, H, W)
|
96 |
+
context_images: th.Tensor
|
97 |
+
context_actions: th.Tensor
|
98 |
+
# Tokens that will replace the context_images if they are provided
|
99 |
+
context_tokens: list
|
100 |
+
# This will replace the dreamed action if provided.
|
101 |
+
# For every step, we remove the first action until exhausted
|
102 |
+
actions_to_take: th.Tensor = None
|
103 |
+
|
104 |
+
|
105 |
+
@dataclass
|
106 |
+
class DreamJobResult:
|
107 |
+
job_id: str
|
108 |
+
dream_step_index: int
|
109 |
+
# (B, 1, C, H, W)
|
110 |
+
dreamt_image: th.Tensor
|
111 |
+
dreamt_action: th.Tensor
|
112 |
+
dreamt_tokens: th.Tensor
|
113 |
+
result_creation_time: datetime = field(default_factory=datetime.now)
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
def setup_and_load_model_be_model(args):
|
118 |
+
model = load_model_from_checkpoint(args.model)
|
119 |
+
th.set_float32_matmul_precision("high")
|
120 |
+
th.backends.cuda.matmul.allow_tf32 = True
|
121 |
+
return model
|
122 |
+
|
123 |
+
|
124 |
+
def get_job_batchable_information(job):
|
125 |
+
"""Return comparable object of job information. Used for batching"""
|
126 |
+
context_length = job.context_images.shape[1]
|
127 |
+
return (context_length, job.sampling_settings)
|
128 |
+
|
129 |
+
|
130 |
+
def fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size, timeout=1):
|
131 |
+
"""Return a list of jobs (or empty list) that can be batched together"""
|
132 |
+
batchable_jobs = []
|
133 |
+
required_job_info = None
|
134 |
+
while len(batchable_jobs) < max_batch_size:
|
135 |
+
try:
|
136 |
+
job = job_queue.get(timeout=timeout)
|
137 |
+
except queue.Empty:
|
138 |
+
break
|
139 |
+
# If pipe breaks, also gracefully return
|
140 |
+
except OSError:
|
141 |
+
break
|
142 |
+
if job.job_id in cancelled_ids_set:
|
143 |
+
# This job was cancelled, so discard it completely
|
144 |
+
continue
|
145 |
+
job_info = get_job_batchable_information(job)
|
146 |
+
if required_job_info is None:
|
147 |
+
required_job_info = job_info
|
148 |
+
elif required_job_info != job_info:
|
149 |
+
# This job is not batchable, put it back
|
150 |
+
job_queue.put(job)
|
151 |
+
# we assume here that, generally, the others jobs would also be
|
152 |
+
# invalid. So we just return the batchable jobs we have instead
|
153 |
+
# of going through more.
|
154 |
+
break
|
155 |
+
batchable_jobs.append(job)
|
156 |
+
return batchable_jobs
|
157 |
+
|
158 |
+
|
159 |
+
def update_cancelled_jobs(cancelled_ids_queue, cancelled_ids_deque, cancelled_ids_set):
|
160 |
+
"""IN-PLACE Update cancelled_ids_set with new ids from the queue"""
|
161 |
+
has_changed = False
|
162 |
+
while not cancelled_ids_queue.empty():
|
163 |
+
try:
|
164 |
+
cancelled_id = cancelled_ids_queue.get_nowait()
|
165 |
+
except queue.Empty:
|
166 |
+
break
|
167 |
+
cancelled_ids_deque.append(cancelled_id)
|
168 |
+
has_changed = True
|
169 |
+
|
170 |
+
if has_changed:
|
171 |
+
cancelled_ids_set.clear()
|
172 |
+
cancelled_ids_set.update(cancelled_ids_deque)
|
173 |
+
|
174 |
+
|
175 |
+
def predict_step(context_data, sampling_settings, model, tokens=None):
|
176 |
+
with th.no_grad():
|
177 |
+
predicted_step = model.predict_next_step(context_data, min_tokens_to_keep=1, tokens=tokens, **sampling_settings)
|
178 |
+
return predicted_step
|
179 |
+
|
180 |
+
|
181 |
+
def dreamer_worker(job_queue, result_queue, cancelled_jobs_queue, quit_flag, device_to_use, args):
|
182 |
+
logger = logging.getLogger(f"dreamer_worker {device_to_use}")
|
183 |
+
logger.info("Loading up model...")
|
184 |
+
model = setup_and_load_model_be_model(args)
|
185 |
+
model = model.to(device_to_use)
|
186 |
+
logger.info("Model loaded. Fetching results")
|
187 |
+
|
188 |
+
cancelled_ids_deque = deque(maxlen=MAX_CANCELLED_ID_QUEUE_SIZE)
|
189 |
+
cancelled_ids_set = set()
|
190 |
+
|
191 |
+
while not quit_flag.is_set():
|
192 |
+
update_cancelled_jobs(cancelled_jobs_queue, cancelled_ids_deque, cancelled_ids_set)
|
193 |
+
batchable_jobs = fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size=args.max_batch_size)
|
194 |
+
if len(batchable_jobs) == 0:
|
195 |
+
continue
|
196 |
+
sampling_settings = batchable_jobs[0].sampling_settings
|
197 |
+
# make better way for passing these arguments around. sampling_settings
|
198 |
+
# is passed as kwargs to predicting step, but max_context_length is not part of valid
|
199 |
+
# keys there, so we need to pop it out.
|
200 |
+
max_context_length = sampling_settings.pop("max_context_length")
|
201 |
+
|
202 |
+
images = [job.context_images[:, :max_context_length] for job in batchable_jobs]
|
203 |
+
actions = [job.context_actions[:, :max_context_length] for job in batchable_jobs]
|
204 |
+
tokens = [job.context_tokens for job in batchable_jobs]
|
205 |
+
|
206 |
+
images = th.concat(images, dim=0).to(device_to_use)
|
207 |
+
actions = th.concat(actions, dim=0).to(device_to_use)
|
208 |
+
|
209 |
+
context_data = TensorDict({
|
210 |
+
"images": images,
|
211 |
+
"actions_output": actions
|
212 |
+
}, batch_size=images.shape[:2])
|
213 |
+
|
214 |
+
predicted_step, predicted_image_tokens = predict_step(context_data, sampling_settings, model, tokens)
|
215 |
+
|
216 |
+
predicted_step = predicted_step.cpu()
|
217 |
+
predicted_images = predicted_step["images"]
|
218 |
+
predicted_actions = predicted_step["actions_output"]
|
219 |
+
predicted_image_tokens = predicted_image_tokens.cpu()
|
220 |
+
|
221 |
+
for job_i, job in enumerate(batchable_jobs):
|
222 |
+
image_context = job.context_images
|
223 |
+
action_context = job.context_actions
|
224 |
+
token_context = job.context_tokens
|
225 |
+
# Keep batch dimension
|
226 |
+
dreamt_image = predicted_images[job_i].unsqueeze(0)
|
227 |
+
dreamt_action = predicted_actions[job_i].unsqueeze(0)
|
228 |
+
dreamt_tokens = predicted_image_tokens[job_i].unsqueeze(0)
|
229 |
+
|
230 |
+
# Replace the dreamed action if provided
|
231 |
+
actions_to_take = job.actions_to_take
|
232 |
+
if actions_to_take is not None and actions_to_take.shape[1] > 0:
|
233 |
+
dreamt_action = actions_to_take[:, 0:1]
|
234 |
+
# Remove the action we took
|
235 |
+
actions_to_take = actions_to_take[:, 1:]
|
236 |
+
if actions_to_take.shape[1] == 0:
|
237 |
+
actions_to_take = None
|
238 |
+
|
239 |
+
result_queue.put(DreamJobResult(
|
240 |
+
job_id=job.job_id,
|
241 |
+
dream_step_index=job.num_predictions_done,
|
242 |
+
dreamt_image=dreamt_image,
|
243 |
+
dreamt_action=dreamt_action,
|
244 |
+
dreamt_tokens=dreamt_tokens
|
245 |
+
))
|
246 |
+
|
247 |
+
# Add job back in the queue if we have more steps to do
|
248 |
+
if job.num_predictions_remaining > 0:
|
249 |
+
# Stack the dreamt image and action to the context
|
250 |
+
if image_context.shape[1] >= max_context_length:
|
251 |
+
image_context = image_context[:, 1:]
|
252 |
+
action_context = action_context[:, 1:]
|
253 |
+
token_context = token_context[1:]
|
254 |
+
image_context = th.cat([image_context, dreamt_image], dim=1)
|
255 |
+
action_context = th.cat([action_context, dreamt_action], dim=1)
|
256 |
+
token_context.append(dreamt_tokens[0, 0].tolist())
|
257 |
+
# We need to add context length back to sampling settings...
|
258 |
+
# add some better way of passing these settings around
|
259 |
+
job.sampling_settings["max_context_length"] = max_context_length
|
260 |
+
job_queue.put(DreamJob(
|
261 |
+
job_id=job.job_id,
|
262 |
+
sampling_settings=job.sampling_settings,
|
263 |
+
num_predictions_remaining=job.num_predictions_remaining - 1,
|
264 |
+
num_predictions_done=job.num_predictions_done + 1,
|
265 |
+
context_images=image_context,
|
266 |
+
context_actions=action_context,
|
267 |
+
context_tokens=token_context,
|
268 |
+
actions_to_take=actions_to_take
|
269 |
+
))
|
270 |
+
|
271 |
+
|
272 |
+
class DreamerServer:
|
273 |
+
def __init__(self, num_workers, args):
|
274 |
+
self.num_workers = num_workers
|
275 |
+
self.args = args
|
276 |
+
self.model = None
|
277 |
+
self.jobs = mp.Queue(maxsize=args.max_concurrent_jobs)
|
278 |
+
self.results_queue = mp.Queue()
|
279 |
+
self.cancelled_jobs = set()
|
280 |
+
self.cancelled_jobs_queues = [mp.Queue() for _ in range(num_workers)]
|
281 |
+
# job_id -> results
|
282 |
+
self._last_result_cleanup = datetime.now()
|
283 |
+
self._max_job_lifespan_datetime = timedelta(seconds=args.max_job_lifespan)
|
284 |
+
self.local_results = defaultdict(list)
|
285 |
+
self.logger = logging.getLogger("DreamerServer")
|
286 |
+
|
287 |
+
def get_details(self):
|
288 |
+
details = {
|
289 |
+
"model_file": self.args.model,
|
290 |
+
"max_concurrent_jobs": self.args.max_concurrent_jobs,
|
291 |
+
"max_dream_steps_per_job": self.args.max_dream_steps_per_job,
|
292 |
+
"max_job_lifespan": self.args.max_job_lifespan,
|
293 |
+
}
|
294 |
+
return json.dumps(details)
|
295 |
+
|
296 |
+
def _check_if_should_remove_old_jobs(self):
|
297 |
+
time_now = datetime.now()
|
298 |
+
# Only cleanup every JOB_CLEANUP_CHECK_RATE seconds at most
|
299 |
+
if time_now - self._last_result_cleanup < JOB_CLEANUP_CHECK_RATE:
|
300 |
+
return
|
301 |
+
|
302 |
+
self._last_result_cleanup = time_now
|
303 |
+
# First add existing results to the local results
|
304 |
+
self._gather_new_results()
|
305 |
+
# Check if we should remove old jobs
|
306 |
+
job_ids = list(self.local_results.keys())
|
307 |
+
for job_id in job_ids:
|
308 |
+
results = self.local_results[job_id]
|
309 |
+
# If newest result is older than max_job_lifespan, remove the job
|
310 |
+
if time_now - results[-1].result_creation_time > self._max_job_lifespan_datetime:
|
311 |
+
self.logger.info(f"Deleted job {job_id} because it was too old. Last result was {results[-1].result_creation_time}")
|
312 |
+
del self.local_results[job_id]
|
313 |
+
|
314 |
+
def add_new_job(self, request, request_json):
|
315 |
+
"""
|
316 |
+
Add new dreaming job to the queues.
|
317 |
+
Request should have:
|
318 |
+
|
319 |
+
|
320 |
+
Returns: json object with new job id
|
321 |
+
"""
|
322 |
+
self._check_if_should_remove_old_jobs()
|
323 |
+
|
324 |
+
sampling_settings = copy.deepcopy(DEFAULT_SAMPLING_SETTINGS)
|
325 |
+
if "num_steps_to_predict" not in request_json:
|
326 |
+
return make_response("num_steps_to_predict not in request", 400)
|
327 |
+
num_steps_to_predict = request_json['num_steps_to_predict']
|
328 |
+
if num_steps_to_predict > self.args.max_dream_steps_per_job:
|
329 |
+
return make_response(f"num_steps_to_predict too large. Max {self.args.max_dream_steps_per_job}", 400)
|
330 |
+
|
331 |
+
num_parallel_predictions = int(request_json['num_parallel_predictions']) if 'num_parallel_predictions' in request_json else 1
|
332 |
+
|
333 |
+
if (self.jobs.qsize() + num_parallel_predictions) >= self.args.max_concurrent_jobs:
|
334 |
+
return make_response(f"Too many jobs already running. Max {self.args.max_concurrent_jobs}", 400)
|
335 |
+
|
336 |
+
for key in sampling_settings:
|
337 |
+
sampling_settings[key] = float_or_none(request_json[key]) if key in request_json else sampling_settings[key]
|
338 |
+
|
339 |
+
context_images = []
|
340 |
+
context_actions = []
|
341 |
+
context_tokens = []
|
342 |
+
future_actions = []
|
343 |
+
|
344 |
+
for step in request_json["steps"]:
|
345 |
+
image_path = step["image_name"]
|
346 |
+
image = np.array(Image.open(request.files[image_path].stream))
|
347 |
+
image = be_image_preprocess(image, target_width=self.args.image_width, target_height=self.args.image_height)
|
348 |
+
context_images.append(th.from_numpy(image))
|
349 |
+
|
350 |
+
action = step["action"]
|
351 |
+
action = action_vector_to_be_action_vector(action)
|
352 |
+
context_actions.append(th.tensor(action))
|
353 |
+
|
354 |
+
tokens = step["tokens"]
|
355 |
+
context_tokens.append(tokens)
|
356 |
+
|
357 |
+
future_actions = None
|
358 |
+
if "future_actions" in request_json:
|
359 |
+
future_actions = []
|
360 |
+
for step in request_json["future_actions"]:
|
361 |
+
# The rest is the action vector
|
362 |
+
action = step["action"]
|
363 |
+
action = action_vector_to_be_action_vector(action)
|
364 |
+
# Add sequence and batch dimensions
|
365 |
+
future_actions.append(th.tensor(action))
|
366 |
+
|
367 |
+
# Add batch dimensions
|
368 |
+
context_images = th.stack(context_images).unsqueeze(0)
|
369 |
+
context_actions = th.stack(context_actions).unsqueeze(0)
|
370 |
+
future_actions = th.stack(future_actions).unsqueeze(0) if future_actions is not None else None
|
371 |
+
|
372 |
+
list_of_job_ids = []
|
373 |
+
for _ in range(num_parallel_predictions):
|
374 |
+
job_id = uuid.uuid4().hex
|
375 |
+
self.jobs.put(DreamJob(
|
376 |
+
job_id=job_id,
|
377 |
+
sampling_settings=sampling_settings,
|
378 |
+
num_predictions_remaining=num_steps_to_predict,
|
379 |
+
num_predictions_done=0,
|
380 |
+
context_images=context_images,
|
381 |
+
context_actions=context_actions,
|
382 |
+
context_tokens=context_tokens,
|
383 |
+
actions_to_take=future_actions
|
384 |
+
))
|
385 |
+
list_of_job_ids.append(job_id)
|
386 |
+
|
387 |
+
job_queue_size = self.jobs.qsize()
|
388 |
+
return json.dumps({"job_ids": list_of_job_ids, "current_jobs_in_queue": job_queue_size})
|
389 |
+
|
390 |
+
def _gather_new_results(self):
|
391 |
+
if not self.results_queue.empty():
|
392 |
+
for _ in range(self.results_queue.qsize()):
|
393 |
+
result = self.results_queue.get()
|
394 |
+
if result.job_id in self.cancelled_jobs:
|
395 |
+
# Discard result if job was cancelled
|
396 |
+
continue
|
397 |
+
self.local_results[result.job_id].append(result)
|
398 |
+
|
399 |
+
def get_new_results(self, request, request_json):
|
400 |
+
if "job_ids" not in request_json:
|
401 |
+
return make_response("job_ids not in request", 400)
|
402 |
+
self._gather_new_results()
|
403 |
+
job_ids = request_json["job_ids"]
|
404 |
+
if not isinstance(job_ids, list):
|
405 |
+
job_ids = [job_ids]
|
406 |
+
return_results = []
|
407 |
+
for job_id in job_ids:
|
408 |
+
if job_id in self.local_results:
|
409 |
+
return_results.append(self.local_results[job_id])
|
410 |
+
del self.local_results[job_id]
|
411 |
+
|
412 |
+
if len(return_results) == 0:
|
413 |
+
return make_response("No new responses", 204)
|
414 |
+
|
415 |
+
output_json = []
|
416 |
+
output_image_bytes = {}
|
417 |
+
for job_results in return_results:
|
418 |
+
for result in job_results:
|
419 |
+
action = result.dreamt_action.numpy()
|
420 |
+
# Remember to remove batch and sequence dimensions
|
421 |
+
action = be_action_vector_to_action_vector(action[0, 0].tolist())
|
422 |
+
dreamt_tokens = result.dreamt_tokens[0, 0].tolist()
|
423 |
+
image_filename = f"{result.job_id}_{result.dream_step_index}.png"
|
424 |
+
output_json.append({
|
425 |
+
"job_id": result.job_id,
|
426 |
+
"dream_step_index": result.dream_step_index,
|
427 |
+
"action": action,
|
428 |
+
"tokens": dreamt_tokens,
|
429 |
+
"image_filename": image_filename
|
430 |
+
})
|
431 |
+
|
432 |
+
image_bytes = io.BytesIO()
|
433 |
+
# this probably is not as smooth as it could be
|
434 |
+
T.ToPILImage()(result.dreamt_image[0, 0]).save(image_bytes, format="PNG")
|
435 |
+
output_image_bytes[image_filename] = image_bytes.getvalue()
|
436 |
+
|
437 |
+
# Write a zip file with all the images
|
438 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
|
439 |
+
zip_bytes = io.BytesIO()
|
440 |
+
with zipfile.ZipFile(zip_bytes, "w") as z:
|
441 |
+
for filename, bytes in output_image_bytes.items():
|
442 |
+
z.writestr(filename, bytes)
|
443 |
+
# Write the json
|
444 |
+
z.writestr(PREDICTION_JSON_FILENAME, json.dumps(output_json))
|
445 |
+
|
446 |
+
zip_bytes.seek(0)
|
447 |
+
|
448 |
+
return send_file(
|
449 |
+
zip_bytes,
|
450 |
+
mimetype="zip",
|
451 |
+
as_attachment=True,
|
452 |
+
download_name=f"dreaming_results_{timestamp}.zip"
|
453 |
+
)
|
454 |
+
|
455 |
+
def cancel_job(self, request, request_json):
|
456 |
+
if "job_id" not in request_json:
|
457 |
+
return make_response("job_id not in request", 400)
|
458 |
+
job_id = request_json["job_id"]
|
459 |
+
self.cancelled_jobs.add(job_id)
|
460 |
+
# Cancel all jobs in the queue with this id
|
461 |
+
for job_queue in self.cancelled_jobs_queues:
|
462 |
+
job_queue.put(job_id)
|
463 |
+
return make_response("OK", 200)
|
464 |
+
|
465 |
+
|
466 |
+
def main_run(args):
|
467 |
+
app = Flask(__name__)
|
468 |
+
|
469 |
+
num_workers = th.cuda.device_count()
|
470 |
+
if num_workers == 0:
|
471 |
+
raise RuntimeError("No CUDA devices found. Cannot run Dreamer.")
|
472 |
+
|
473 |
+
server = DreamerServer(num_workers, args)
|
474 |
+
quit_flag = mp.Event()
|
475 |
+
|
476 |
+
# Start the dreamer worker(s)
|
477 |
+
dreamer_worker_processes = []
|
478 |
+
for device_i in range(num_workers):
|
479 |
+
device = f"cuda:{device_i}"
|
480 |
+
dreamer_worker_process = mp.Process(
|
481 |
+
target=dreamer_worker,
|
482 |
+
args=(server.jobs, server.results_queue, server.cancelled_jobs_queues[device_i], quit_flag, device, args)
|
483 |
+
)
|
484 |
+
dreamer_worker_process.daemon = True
|
485 |
+
dreamer_worker_process.start()
|
486 |
+
dreamer_worker_processes.append(dreamer_worker_process)
|
487 |
+
|
488 |
+
# Add the API endpoints
|
489 |
+
@app.route('/')
|
490 |
+
def details():
|
491 |
+
return server.get_details()
|
492 |
+
|
493 |
+
@app.route('/new_job', methods=['POST'])
|
494 |
+
def new_job():
|
495 |
+
request_json = json.loads(request.form["json"])
|
496 |
+
return server.add_new_job(request, request_json)
|
497 |
+
|
498 |
+
@app.route('/get_job_results', methods=['GET'])
|
499 |
+
def get_results():
|
500 |
+
# the "Json" is now in regular GET payload/parameters
|
501 |
+
request_json = {"job_ids": request.args.getlist("job_ids")}
|
502 |
+
return server.get_new_results(request, request_json)
|
503 |
+
|
504 |
+
@app.route('/cancel_job', methods=['GET'])
|
505 |
+
def cancel_job():
|
506 |
+
request_json = request.args.to_dict()
|
507 |
+
return server.cancel_job(request, request_json)
|
508 |
+
|
509 |
+
app.run(host="0.0.0.0", port=args.port, debug=args.debug)
|
510 |
+
|
511 |
+
# Cleanup
|
512 |
+
quit_flag.set()
|
513 |
+
for dreamer_worker_process in dreamer_worker_processes:
|
514 |
+
dreamer_worker_process.join()
|
515 |
+
|
516 |
+
|
517 |
+
if __name__ == '__main__':
|
518 |
+
args = parser.parse_args()
|
519 |
+
main_run(args)
|
setup_local.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Tested using Python 3.9
|
2 |
+
|
3 |
+
echo "Making and activating a new virtual environment..."
|
4 |
+
python3.9 -m venv venv
|
5 |
+
|
6 |
+
echo "Activating the virtual environment..."
|
7 |
+
source venv/bin/activate
|
8 |
+
|
9 |
+
echo "Upgrading pip..."
|
10 |
+
pip install --upgrade pip
|
11 |
+
|
12 |
+
echo "Instaling the required packages..."
|
13 |
+
pip install -r requirements.txt
|
14 |
+
|
15 |
+
echo "Instaling the exiftool package for adding file metadata on Linux..."
|
16 |
+
sudo apt install -y exiftool
|
17 |
+
|
18 |
+
echo "Installing ffmpeg..."
|
19 |
+
sudo apt install ffmpeg
|
20 |
+
|
21 |
+
echo "All packages installed successfully!"
|
wham/models/nn/model_blocks.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
"""
|
4 |
+
Some Utility blocks for ViT-VQGAN.
|
5 |
+
|
6 |
+
ConvNeXt blocks are based on:
|
7 |
+
Liu, Zhuang, et al. "A convnet for the 2020s."
|
8 |
+
Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.
|
9 |
+
"""
|
10 |
+
|
11 |
+
|
12 |
+
class ConvNextDownsampleBig(nn.Module):
|
13 |
+
def __init__(self, c_in, c_out):
|
14 |
+
super().__init__()
|
15 |
+
self.group_norm = nn.GroupNorm(c_in, c_in)
|
16 |
+
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=8, stride=4, padding=0)
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
return self.conv1(self.group_norm(x))
|
20 |
+
|
21 |
+
|
22 |
+
class ConvNextBlock(nn.Module):
|
23 |
+
def __init__(self, channels):
|
24 |
+
super().__init__()
|
25 |
+
self.conv1 = nn.Conv2d(channels, channels, kernel_size=7, stride=1, padding=7 // 2, groups=channels) # 'Depthwise' conv
|
26 |
+
self.group_norm = nn.GroupNorm(channels, channels) # Should be equivalent to layernorm
|
27 |
+
|
28 |
+
# Transformer-style non-linearity
|
29 |
+
self.conv2 = nn.Conv2d(channels, channels * 4, kernel_size=1, stride=1, padding=0)
|
30 |
+
self.activation = nn.GELU()
|
31 |
+
self.conv3 = nn.Conv2d(channels * 4, channels, kernel_size=1, stride=1, padding=0)
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
y = self.conv1(x)
|
35 |
+
y = self.group_norm(y)
|
36 |
+
y = self.conv2(y)
|
37 |
+
y = self.activation(y)
|
38 |
+
y = self.conv3(y)
|
39 |
+
return x + y
|
40 |
+
|
41 |
+
|
42 |
+
class ConvNextDownsample(nn.Module):
|
43 |
+
def __init__(self, c_in, c_out):
|
44 |
+
super().__init__()
|
45 |
+
self.group_norm = nn.GroupNorm(c_in, c_in)
|
46 |
+
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
return self.conv1(self.group_norm(x))
|
wham/models/nn/nanoGPT.py
ADDED
@@ -0,0 +1,665 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# From https://github.com/karpathy/nanoGPT/blob/master/model.py - Thanks Andrej Karpathy
|
2 |
+
|
3 |
+
# MIT License
|
4 |
+
# Copyright (c) 2022 Andrej Karpathy
|
5 |
+
# 2023 Microsoft Research
|
6 |
+
|
7 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
8 |
+
# of this software and associated documentation files (the "Software"), to deal
|
9 |
+
# in the Software without restriction, including without limitation the rights
|
10 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
11 |
+
# copies of the Software, and to permit persons to whom the Software is
|
12 |
+
# furnished to do so, subject to the following conditions:
|
13 |
+
|
14 |
+
# The above copyright notice and this permission notice shall be included in all
|
15 |
+
# copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
18 |
+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
19 |
+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
20 |
+
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
21 |
+
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
22 |
+
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
23 |
+
# OR OTHER DEALINGS IN THE SOFTWARE.
|
24 |
+
|
25 |
+
|
26 |
+
"""
|
27 |
+
Full definition of a GPT Language Model, all of it in this single file.
|
28 |
+
References:
|
29 |
+
1) the official GPT-2 TensorFlow implementation released by OpenAI:
|
30 |
+
https://github.com/openai/gpt-2/blob/master/src/model.py
|
31 |
+
2) huggingface/transformers PyTorch implementation:
|
32 |
+
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
|
33 |
+
"""
|
34 |
+
|
35 |
+
from dataclasses import dataclass
|
36 |
+
import inspect
|
37 |
+
import math
|
38 |
+
|
39 |
+
import torch
|
40 |
+
import torch.nn as nn
|
41 |
+
from torch.nn import functional as F
|
42 |
+
|
43 |
+
NEGATIVE_INFINITE_FLOAT = -float("inf")
|
44 |
+
CROSS_ENTROPY_INVALID_CLASS_TARGET = -1
|
45 |
+
|
46 |
+
# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
|
47 |
+
def new_gelu(x):
|
48 |
+
"""
|
49 |
+
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
|
50 |
+
Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
|
51 |
+
"""
|
52 |
+
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
53 |
+
|
54 |
+
|
55 |
+
def limit_logits_to_valid_range(logits, valid_token_range):
|
56 |
+
"""
|
57 |
+
MODIFIES logits INPLACE.
|
58 |
+
Mask out invalid positions in the logits tensor with -inf so they are not considered by the softmax.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
logits: logits tensor of shape (batch_size, vocab_size)
|
62 |
+
valid_token_range: tuple of (start, end) indices of valid positions in the logits tensor (inclusive).
|
63 |
+
Everything outside is masked out with -inf.
|
64 |
+
"""
|
65 |
+
logits[:, : valid_token_range[0]] = NEGATIVE_INFINITE_FLOAT
|
66 |
+
logits[:, valid_token_range[1] + 1 :] = NEGATIVE_INFINITE_FLOAT
|
67 |
+
|
68 |
+
|
69 |
+
def default_sample_token(logits, valid_token_range=None, temperature=1.0, deterministic=False, top_k=None, top_p=None, min_tokens_to_keep=1):
|
70 |
+
"""
|
71 |
+
Given a vector of logits, sample and return an index according to settings.
|
72 |
+
|
73 |
+
logits: tensor of shape (batch_size, vocab_size)
|
74 |
+
|
75 |
+
valid_token_range should be a tuple, specifying start and end indices we'd like to sample from (inclusive).
|
76 |
+
If None, we'll sample from the full vocab.
|
77 |
+
|
78 |
+
If deterministic is True, we'll take the argmax of the logits which implies top-k sampling with top_k = 1, therefore user inputted values of top_p and top_k will be ignored.
|
79 |
+
|
80 |
+
Otherwise, either top-p (float) value can be specified or top-k (int) value can be specified.
|
81 |
+
Top-p (float top_p) : only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
|
82 |
+
Top-k (int top_k) : selects top_k tokens for generation.
|
83 |
+
min_tokens_to_keep: Used with both top_p and top_k sampling.
|
84 |
+
"""
|
85 |
+
assert top_k is None or top_p is None, "Can only specify one of top-k or top-p sampling."
|
86 |
+
if temperature < 0.1:
|
87 |
+
# Avoid too low a temp, especially 0
|
88 |
+
temperature = 0.1
|
89 |
+
logits = logits / temperature
|
90 |
+
if valid_token_range is not None:
|
91 |
+
limit_logits_to_valid_range(logits, valid_token_range)
|
92 |
+
if deterministic:
|
93 |
+
selected_logits = select_logits(logits, top_k=1)
|
94 |
+
else:
|
95 |
+
selected_logits = select_logits(logits, top_p=top_p, top_k=top_k, min_tokens_to_keep=min_tokens_to_keep)
|
96 |
+
probs = F.softmax(selected_logits, dim=-1)
|
97 |
+
# More robustly handle errors in the sampling here
|
98 |
+
sampled_idx = torch.multinomial(probs, num_samples=1).squeeze(-1)
|
99 |
+
return sampled_idx
|
100 |
+
|
101 |
+
|
102 |
+
def select_logits(logits, top_k=None, top_p=None, min_tokens_to_keep=1):
|
103 |
+
"""
|
104 |
+
Select from original logits using top-k or top-p sampling.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
logits (torch.Tensor): Logits to sample from.
|
108 |
+
k (int, optional): Number of top elements to consider in top-k sampling.
|
109 |
+
p (float, optional): Threshold probability for top-p sampling.
|
110 |
+
min_tokens_to_keep (int, optional): Minimum number of tokens to keep in the output.
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
logits: Selected logits after top-k or top-p sampling. Sets all logits outside the selected ones to NEGATIVE_INFINITE_FLOAT.
|
114 |
+
"""
|
115 |
+
assert top_k is None or top_p is None, "Can only specify one of top-k or top-p sampling."
|
116 |
+
min_tokens_to_keep = min(min_tokens_to_keep, logits.size(-1))
|
117 |
+
if top_k is not None:
|
118 |
+
if not isinstance(top_k, int) or top_k <= 0:
|
119 |
+
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
|
120 |
+
|
121 |
+
# Top-k sampling
|
122 |
+
top_k = max(top_k, min_tokens_to_keep)
|
123 |
+
top_k = min(top_k, logits.size(-1))
|
124 |
+
top_k_logits, _ = torch.topk(logits, top_k)
|
125 |
+
indices_to_remove = logits < top_k_logits[..., -1:]
|
126 |
+
logits = torch.where(indices_to_remove, NEGATIVE_INFINITE_FLOAT, logits)
|
127 |
+
|
128 |
+
elif top_p is not None:
|
129 |
+
top_p = float(top_p)
|
130 |
+
if top_p < 0 or top_p > 1.0:
|
131 |
+
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
|
132 |
+
|
133 |
+
# Top-p sampling
|
134 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
135 |
+
sorted_probs = torch.softmax(sorted_logits, dim=-1)
|
136 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
|
137 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
138 |
+
|
139 |
+
# Remove tokens with cumulative probability above the threshold
|
140 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = False
|
141 |
+
|
142 |
+
# scatter sorted tensors to original indexing
|
143 |
+
indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
|
144 |
+
logits = torch.where(indices_to_remove, NEGATIVE_INFINITE_FLOAT, logits)
|
145 |
+
|
146 |
+
else:
|
147 |
+
# Return logits as is
|
148 |
+
pass
|
149 |
+
|
150 |
+
return logits
|
151 |
+
|
152 |
+
|
153 |
+
class LayerNorm(nn.Module):
|
154 |
+
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
|
155 |
+
|
156 |
+
def __init__(self, ndim, bias):
|
157 |
+
super().__init__()
|
158 |
+
self.weight = nn.Parameter(torch.ones(ndim))
|
159 |
+
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
|
160 |
+
|
161 |
+
def forward(self, input):
|
162 |
+
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
|
163 |
+
|
164 |
+
class LayerNormMinimal(nn.Module):
|
165 |
+
"""LayerNorm like above, but without learnable parameters"""
|
166 |
+
|
167 |
+
def __init__(self, ndim, bias):
|
168 |
+
super().__init__()
|
169 |
+
self.ndim = (ndim,)
|
170 |
+
|
171 |
+
def forward(self, input):
|
172 |
+
return F.layer_norm(input, self.ndim, eps=1e-5)
|
173 |
+
|
174 |
+
|
175 |
+
class CausalSelfAttention(nn.Module):
|
176 |
+
def __init__(self, config):
|
177 |
+
super().__init__()
|
178 |
+
assert config.n_embd % config.n_head == 0
|
179 |
+
# key, query, value projections for all heads, but in a batch
|
180 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
181 |
+
# output projection
|
182 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
183 |
+
# regularization
|
184 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
185 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
186 |
+
self.n_head = config.n_head
|
187 |
+
self.n_embd = config.n_embd
|
188 |
+
self.dropout = config.dropout
|
189 |
+
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
|
190 |
+
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
|
191 |
+
# causal mask to ensure that attention is only applied to the left in the input sequence
|
192 |
+
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size), persistent=False)
|
193 |
+
|
194 |
+
self.cached_k = None
|
195 |
+
self.cached_v = None
|
196 |
+
self.current_cache_size = 0
|
197 |
+
|
198 |
+
def _manual_causal_attention(self, q, k, v, mask):
|
199 |
+
# q, k and v should be of shape (B, nh, T, hs)
|
200 |
+
token_len = q.size(-2)
|
201 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
202 |
+
att = att.masked_fill(mask[:, :, :token_len, :token_len] == 0, float("-inf"))
|
203 |
+
att = F.softmax(att, dim=-1)
|
204 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
205 |
+
return y
|
206 |
+
|
207 |
+
def forward(self, x, cache=False):
|
208 |
+
batch_size, token_len, n_embd = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
209 |
+
|
210 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
211 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
212 |
+
k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
213 |
+
q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
214 |
+
v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
215 |
+
|
216 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
217 |
+
if self.flash and not cache:
|
218 |
+
# efficient attention using Flash Attention CUDA kernels
|
219 |
+
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
|
220 |
+
elif cache:
|
221 |
+
# manual implemention of attention (as below), but cache arrays we can reuse
|
222 |
+
assert token_len == 1, "Cache only works for single step"
|
223 |
+
assert self.cached_k is not None, "Must call reset_cache() before using cache"
|
224 |
+
assert self.current_cache_size < self.cached_k.size(2), "Trying to generate more steps than provided in reset_cache() `num_steps_to_come`"
|
225 |
+
assert self.dropout == 0.0, "Dropout not supported with caching"
|
226 |
+
this_step_q = q
|
227 |
+
self.cached_k[:, :, self.current_cache_size, :] = k[:, :, 0, :]
|
228 |
+
self.cached_v[:, :, self.current_cache_size, :] = v[:, :, 0, :]
|
229 |
+
# Remove the zero parts
|
230 |
+
k = self.cached_k[:, :, : self.current_cache_size + 1, :]
|
231 |
+
# compute last row of the attention mask
|
232 |
+
this_step_att_row = (this_step_q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
233 |
+
this_step_att_row = F.softmax(this_step_att_row, dim=-1)
|
234 |
+
# We only need output for the current step
|
235 |
+
y = this_step_att_row @ self.cached_v[:, :, : self.current_cache_size + 1, :]
|
236 |
+
# Update cache
|
237 |
+
self.current_cache_size += 1
|
238 |
+
else:
|
239 |
+
y = self._manual_causal_attention(q, k, v, self.bias)
|
240 |
+
y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd) # re-assemble all head outputs side by side
|
241 |
+
|
242 |
+
# output projection
|
243 |
+
y = self.resid_dropout(self.c_proj(y))
|
244 |
+
return y
|
245 |
+
|
246 |
+
def reset_cache(self, x, num_steps_to_come):
|
247 |
+
"""
|
248 |
+
Reset caches by doing initial pass with x data (returning same output as forward).
|
249 |
+
Also set the number of steps to come, which is used to initialize the buffers
|
250 |
+
"""
|
251 |
+
batch_size, token_len, n_embd = x.size()
|
252 |
+
|
253 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
254 |
+
k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
255 |
+
q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
256 |
+
v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
257 |
+
|
258 |
+
# Use SDPA instead of a manual implementation
|
259 |
+
# y = self._manual_causal_attention(q, k, v, self.bias)
|
260 |
+
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
|
261 |
+
|
262 |
+
y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd)
|
263 |
+
# output projection
|
264 |
+
y = self.resid_dropout(self.c_proj(y))
|
265 |
+
|
266 |
+
# Create full k,q,v for predicting all future steps.
|
267 |
+
# Just null-out the last num_steps_to_come-1 steps
|
268 |
+
pad_size = num_steps_to_come
|
269 |
+
self.current_cache_size = token_len
|
270 |
+
self.cached_k = torch.cat([k, torch.zeros(batch_size, self.n_head, pad_size, n_embd // self.n_head, device=k.device)], dim=2)
|
271 |
+
self.cached_v = torch.cat([v, torch.zeros(batch_size, self.n_head, pad_size, n_embd // self.n_head, device=v.device)], dim=2)
|
272 |
+
|
273 |
+
return y
|
274 |
+
|
275 |
+
class SelfAttention(nn.Module):
|
276 |
+
"""
|
277 |
+
Non-causal self-attention layer, the same as CausalSelfAttention but without the causal mask.
|
278 |
+
Duplicating the code to keep this separate for clarity.
|
279 |
+
"""
|
280 |
+
|
281 |
+
def __init__(self, config):
|
282 |
+
super().__init__()
|
283 |
+
assert config.n_embd % config.n_head == 0
|
284 |
+
# key, query, value projections for all heads, but in a batch
|
285 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
286 |
+
# output projection
|
287 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
288 |
+
# regularization
|
289 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
290 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
291 |
+
self.n_head = config.n_head
|
292 |
+
self.n_embd = config.n_embd
|
293 |
+
self.dropout = config.dropout
|
294 |
+
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
|
295 |
+
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
|
296 |
+
assert self.flash, "SelfAttention only supports flash attention for now."
|
297 |
+
|
298 |
+
self.register_buffer("attn_mask", torch.ones((config.block_size, config.block_size)).bool().unsqueeze(0).unsqueeze(0))
|
299 |
+
|
300 |
+
def forward(self, x):
|
301 |
+
batch_size, token_len, n_embd = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
302 |
+
|
303 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
304 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
305 |
+
k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
306 |
+
q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
307 |
+
v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
308 |
+
|
309 |
+
# self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
310 |
+
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout, is_causal=False)
|
311 |
+
y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd) # re-assemble all head outputs side by side
|
312 |
+
|
313 |
+
# output projection
|
314 |
+
y = self.resid_dropout(self.c_proj(y))
|
315 |
+
return y
|
316 |
+
|
317 |
+
class MLP(nn.Module):
|
318 |
+
def __init__(self, config):
|
319 |
+
super().__init__()
|
320 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
321 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
322 |
+
self.dropout = nn.Dropout(config.dropout)
|
323 |
+
|
324 |
+
def forward(self, x):
|
325 |
+
x = self.c_fc(x)
|
326 |
+
x = new_gelu(x)
|
327 |
+
x = self.c_proj(x)
|
328 |
+
x = self.dropout(x)
|
329 |
+
return x
|
330 |
+
|
331 |
+
class GELU_MLP(nn.Module):
|
332 |
+
"""MLP Block using PyTorch's native GELU activation function"""
|
333 |
+
def __init__(self, config):
|
334 |
+
super().__init__()
|
335 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
336 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
337 |
+
self.dropout = nn.Dropout(config.dropout)
|
338 |
+
|
339 |
+
def forward(self, x):
|
340 |
+
x = self.c_fc(x)
|
341 |
+
x = F.gelu(x, approximate="tanh")
|
342 |
+
x = self.c_proj(x)
|
343 |
+
x = self.dropout(x)
|
344 |
+
return x
|
345 |
+
|
346 |
+
|
347 |
+
class Block(nn.Module):
|
348 |
+
def __init__(self, config):
|
349 |
+
super().__init__()
|
350 |
+
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
|
351 |
+
self.attn = CausalSelfAttention(config)
|
352 |
+
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
|
353 |
+
self.mlp = MLP(config)
|
354 |
+
|
355 |
+
def forward(self, x, cache=False, reset_cache_with_num_steps_to_come=None):
|
356 |
+
"""
|
357 |
+
Args:
|
358 |
+
cache: If True, use the cache to predict the next token (assumes model was initialized with `reset_cache`).
|
359 |
+
reset_cache_with_num_steps_to_come:
|
360 |
+
If not None, reset and prepare the cache for cached prediction of the next `reset_cache_with_num_steps_to_come` tokens.
|
361 |
+
This is same as calling `reset_cache` with the same argument, but we include option here in `forward` to support torch hook functions (used to get embeddings from this module output).
|
362 |
+
|
363 |
+
Caching example:
|
364 |
+
```
|
365 |
+
# Initialize model with reset_cache_with_num_steps_to_come=10
|
366 |
+
outputs[0] = model(inputs, reset_cache_with_num_steps_to_come=10)
|
367 |
+
# Predict next 10 tokens using cache
|
368 |
+
for i in range(10):
|
369 |
+
outputs[i+1] = model(inputs, cache=True)
|
370 |
+
```
|
371 |
+
"""
|
372 |
+
if reset_cache_with_num_steps_to_come:
|
373 |
+
return self.reset_cache(x, num_steps_to_come=reset_cache_with_num_steps_to_come)
|
374 |
+
x = x + self.attn(self.ln_1(x), cache=cache)
|
375 |
+
x = x + self.mlp(self.ln_2(x))
|
376 |
+
return x
|
377 |
+
|
378 |
+
def reset_cache(self, x, num_steps_to_come):
|
379 |
+
x = x + self.attn.reset_cache(self.ln_1(x), num_steps_to_come=num_steps_to_come)
|
380 |
+
x = x + self.mlp(self.ln_2(x))
|
381 |
+
return x
|
382 |
+
|
383 |
+
class BlockV2(nn.Module):
|
384 |
+
"""
|
385 |
+
Compared to the Block in the original implementation, this one uses non-parametric LayerNorm and Pytorch's GELU.
|
386 |
+
These two changes save significant vram but are incompatible with previously trained models.
|
387 |
+
Hence the separate class.
|
388 |
+
"""
|
389 |
+
|
390 |
+
def __init__(self, config):
|
391 |
+
super().__init__()
|
392 |
+
self.ln_1 = LayerNormMinimal(config.n_embd, bias=config.bias)
|
393 |
+
self.attn = CausalSelfAttention(config)
|
394 |
+
self.ln_2 = LayerNormMinimal(config.n_embd, bias=config.bias)
|
395 |
+
self.mlp = GELU_MLP(config)
|
396 |
+
|
397 |
+
def forward(self, x, cache=False, reset_cache_with_num_steps_to_come=None):
|
398 |
+
if reset_cache_with_num_steps_to_come:
|
399 |
+
return self.reset_cache(x, num_steps_to_come=reset_cache_with_num_steps_to_come)
|
400 |
+
x = x + self.attn(self.ln_1(x), cache=cache)
|
401 |
+
x = x + self.mlp(self.ln_2(x))
|
402 |
+
return x
|
403 |
+
|
404 |
+
def reset_cache(self, x, num_steps_to_come):
|
405 |
+
x = x + self.attn.reset_cache(self.ln_1(x), num_steps_to_come=num_steps_to_come)
|
406 |
+
x = x + self.mlp(self.ln_2(x))
|
407 |
+
return x
|
408 |
+
|
409 |
+
class SelfAttentionBlock(nn.Module):
|
410 |
+
def __init__(self, config):
|
411 |
+
super().__init__()
|
412 |
+
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
|
413 |
+
self.attn = SelfAttention(config)
|
414 |
+
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
|
415 |
+
self.mlp = MLP(config)
|
416 |
+
|
417 |
+
def forward(self, x):
|
418 |
+
x = x + self.attn(self.ln_1(x))
|
419 |
+
x = x + self.mlp(self.ln_2(x))
|
420 |
+
return x
|
421 |
+
|
422 |
+
@dataclass
|
423 |
+
class GPTConfig:
|
424 |
+
block_size: int = 1024
|
425 |
+
vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
|
426 |
+
n_layer: int = 12
|
427 |
+
n_head: int = 12
|
428 |
+
n_embd: int = 768
|
429 |
+
dropout: float = 0.0
|
430 |
+
bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
|
431 |
+
version: int = 1 # Version 1 is the original GPT, Version 2 is the one with non-parametric LayerNorm and Pytorch's GELU
|
432 |
+
|
433 |
+
|
434 |
+
class GPT(nn.Module):
|
435 |
+
def __init__(self, config):
|
436 |
+
super().__init__()
|
437 |
+
assert config.vocab_size is not None
|
438 |
+
assert config.block_size is not None
|
439 |
+
self.config = config
|
440 |
+
|
441 |
+
self.version = config.version
|
442 |
+
|
443 |
+
print(f"[nanoGPT] creating model with version {self.version}")
|
444 |
+
|
445 |
+
if self.version == 1:
|
446 |
+
transformer_dict = dict(
|
447 |
+
wpe=nn.Embedding(config.block_size, config.n_embd),
|
448 |
+
drop=nn.Dropout(config.dropout),
|
449 |
+
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
|
450 |
+
ln_f=LayerNorm(config.n_embd, bias=config.bias),
|
451 |
+
)
|
452 |
+
elif self.version == 2:
|
453 |
+
transformer_dict = dict(
|
454 |
+
wpe=nn.Embedding(config.block_size, config.n_embd),
|
455 |
+
drop=nn.Dropout(config.dropout),
|
456 |
+
h=nn.ModuleList([BlockV2(config) for _ in range(config.n_layer)]),
|
457 |
+
ln_f=LayerNorm(config.n_embd, bias=config.bias), # This one is still parametric due to user error
|
458 |
+
)
|
459 |
+
|
460 |
+
transformer_dict["wte"] = nn.Embedding(config.vocab_size, config.n_embd)
|
461 |
+
self.transformer = nn.ModuleDict(transformer_dict)
|
462 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
463 |
+
# with weight tying when using torch.compile() some warnings get generated:
|
464 |
+
# "UserWarning: functional_call was passed multiple values for tied weights.
|
465 |
+
# This behavior is deprecated and will be an error in future versions"
|
466 |
+
# not 100% sure what this is, so far seems to be harmless.
|
467 |
+
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
|
468 |
+
|
469 |
+
# init all weights
|
470 |
+
self.apply(self._init_weights)
|
471 |
+
# apply special scaled init to the residual projections, per GPT-2 paper
|
472 |
+
for pn, p in self.named_parameters():
|
473 |
+
if pn.endswith("c_proj.weight"):
|
474 |
+
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
|
475 |
+
|
476 |
+
def get_num_params(self, non_embedding=True):
|
477 |
+
"""
|
478 |
+
Return the number of parameters in the model.
|
479 |
+
For non-embedding count (default), the position embeddings get subtracted.
|
480 |
+
The token embeddings would too, except due to the parameter sharing these
|
481 |
+
params are actually used as weights in the final layer, so we include them.
|
482 |
+
"""
|
483 |
+
n_params = sum(p.numel() for p in self.parameters())
|
484 |
+
if non_embedding:
|
485 |
+
n_params -= self.transformer.wpe.weight.numel()
|
486 |
+
return n_params
|
487 |
+
|
488 |
+
def _init_weights(self, module):
|
489 |
+
if isinstance(module, nn.Linear):
|
490 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
491 |
+
if module.bias is not None:
|
492 |
+
torch.nn.init.zeros_(module.bias)
|
493 |
+
elif isinstance(module, nn.Embedding):
|
494 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
495 |
+
|
496 |
+
def _apply_pos_encoding(self, x):
|
497 |
+
device = x.device
|
498 |
+
token_len = x.size(1)
|
499 |
+
pos = torch.arange(0, token_len, dtype=torch.long, device=device).unsqueeze(0)
|
500 |
+
pos_emb = self.transformer.wpe(pos)
|
501 |
+
x = x + pos_emb
|
502 |
+
return x
|
503 |
+
|
504 |
+
def original_forward(self, idx, targets=None, loss_mask=None, loss_reduction="mean"):
|
505 |
+
batch_size, seq_len = idx.shape[:2]
|
506 |
+
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
507 |
+
x = self.transformer.drop(self._apply_pos_encoding(tok_emb))
|
508 |
+
for block in self.transformer.h:
|
509 |
+
x = block(x)
|
510 |
+
x = self.transformer.ln_f(x)
|
511 |
+
|
512 |
+
if targets is not None:
|
513 |
+
# if we are given some desired targets also calculate the loss
|
514 |
+
logits = self.lm_head(x)
|
515 |
+
if loss_mask is not None:
|
516 |
+
# Feeding target = CROSS_ENTROPY_INVALID_CLASS_TARGET to cross_entropy will ignore the loss
|
517 |
+
# for that position. This is useful for padding tokens.
|
518 |
+
targets[loss_mask == 0] = CROSS_ENTROPY_INVALID_CLASS_TARGET
|
519 |
+
loss = F.cross_entropy(
|
520 |
+
logits.view(batch_size * seq_len, self.config.vocab_size), targets.view(-1), ignore_index=CROSS_ENTROPY_INVALID_CLASS_TARGET, reduction=loss_reduction
|
521 |
+
)
|
522 |
+
if loss_reduction == "none":
|
523 |
+
# Reshape back into batch_size and seq_len
|
524 |
+
loss = loss.view(batch_size, seq_len)
|
525 |
+
else:
|
526 |
+
# inference-time mini-optimization: only forward the lm_head on the very last position
|
527 |
+
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
|
528 |
+
loss = None
|
529 |
+
|
530 |
+
return logits, loss
|
531 |
+
|
532 |
+
def forward(self, x, targets=None, loss_mask=None, loss_reduction="mean"):
|
533 |
+
token_len = x.size(1)
|
534 |
+
assert token_len <= self.config.block_size, f"Cannot forward sequence of length {token_len}, block size is only {self.config.block_size}"
|
535 |
+
return self.original_forward(x, targets, loss_mask, loss_reduction)
|
536 |
+
|
537 |
+
@torch.no_grad()
|
538 |
+
def generate(self, idx, max_new_tokens, valid_token_range=None, temperature=1.0, top_k=None, raise_cropping=False, deterministic=False):
|
539 |
+
"""
|
540 |
+
valid_token_range should be a tuple, specifying start and end indices we'd like to sample from (inclusive).
|
541 |
+
if None, we'll sample from the full vocab.
|
542 |
+
|
543 |
+
If raise_cropping is True, we'll raise an error if we need to crop the sequence context.
|
544 |
+
"""
|
545 |
+
if valid_token_range is None:
|
546 |
+
valid_token_range = (0, self.config.vocab_size - 1)
|
547 |
+
assert len(valid_token_range) == 2
|
548 |
+
assert valid_token_range[0] < valid_token_range[1]
|
549 |
+
for _ in range(max_new_tokens):
|
550 |
+
# if the sequence context is growing too long we must crop it at block_size
|
551 |
+
idx_cond = idx
|
552 |
+
if idx.size(1) > self.config.block_size:
|
553 |
+
if raise_cropping:
|
554 |
+
raise ValueError("Tried to crop idxs but flag told to raise this")
|
555 |
+
else:
|
556 |
+
idx_cond = idx[:, -self.config.block_size :]
|
557 |
+
# forward the model to get the logits for the index in the sequence
|
558 |
+
logits, _ = self(idx_cond)
|
559 |
+
# pluck the logits at the final step and scale by desired temperature
|
560 |
+
logits = logits[:, -1, :] / temperature # logits is B T Vocabsize -> B Vocabsize
|
561 |
+
# optionally crop the logits to only the top k options
|
562 |
+
if top_k is not None:
|
563 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
564 |
+
logits[logits < v[:, [-1]]] = NEGATIVE_INFINITE_FLOAT
|
565 |
+
|
566 |
+
# Crop out the logits we don't want to sample from
|
567 |
+
if valid_token_range is not None:
|
568 |
+
limit_logits_to_valid_range(logits, valid_token_range)
|
569 |
+
|
570 |
+
# apply softmax to convert logits to (normalized) probabilities
|
571 |
+
probs = F.softmax(logits, dim=-1)
|
572 |
+
|
573 |
+
if deterministic:
|
574 |
+
# Take max of the results
|
575 |
+
idx_next = torch.argmax(probs, dim=-1, keepdim=True)
|
576 |
+
else:
|
577 |
+
# sample from the distribution
|
578 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
579 |
+
# append sampled index to the running sequence and continue
|
580 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
581 |
+
|
582 |
+
return idx
|
583 |
+
|
584 |
+
@torch.no_grad()
|
585 |
+
def optimized_generate(
|
586 |
+
self,
|
587 |
+
idx,
|
588 |
+
num_new_tokens,
|
589 |
+
valid_token_ranges=None,
|
590 |
+
temperature=1.0,
|
591 |
+
deterministic=False,
|
592 |
+
raise_cropping=False,
|
593 |
+
top_k=None,
|
594 |
+
top_p=None,
|
595 |
+
min_tokens_to_keep=1,
|
596 |
+
):
|
597 |
+
"""
|
598 |
+
Generate function but optimized by caching the results in transformer blocks (think this is referred to as "attention caching").
|
599 |
+
The higher the num_new_tokens, the more the speedup compared to original generate.
|
600 |
+
|
601 |
+
Caveat: the context length + num_new_tokens must be less than the block size. This means that the first
|
602 |
+
generated tokens do not have full context length.
|
603 |
+
|
604 |
+
valid_token_ranges should be None or list of length num_new_tokens, specifying valid range for tokens for every step
|
605 |
+
"""
|
606 |
+
# Properly compile the modules used and/or quantize for improved speed.
|
607 |
+
logit_layer = self.lm_head
|
608 |
+
embedder_fn = self.transformer.wte
|
609 |
+
|
610 |
+
if valid_token_ranges is None:
|
611 |
+
valid_token_ranges = [[0, self.config.vocab_size] for _ in range(num_new_tokens)]
|
612 |
+
assert len(valid_token_ranges) == num_new_tokens, "valid_token_ranges should be list of length num_new_tokens or None"
|
613 |
+
|
614 |
+
_, token_len = idx.size()
|
615 |
+
if token_len + num_new_tokens > self.config.block_size:
|
616 |
+
raise ValueError("Can't use optimized generation with num_new_tokens + context_length > block_size")
|
617 |
+
new_idxs = torch.zeros(idx.size(0), num_new_tokens, dtype=torch.long, device=idx.device)
|
618 |
+
# First, we need to cull the sequence to the block size
|
619 |
+
# and remove first max_new_tokens so we can reuse same position embeddings
|
620 |
+
# and not have to recompute them
|
621 |
+
num_original_tokens = idx.size(1)
|
622 |
+
original_idx = idx
|
623 |
+
if (num_original_tokens + num_new_tokens) > self.config.block_size:
|
624 |
+
if raise_cropping:
|
625 |
+
raise ValueError("Tried to crop idxs but flag told to raise this")
|
626 |
+
original_idx = idx[:, -self.config.block_size + num_new_tokens :]
|
627 |
+
original_pos = torch.arange(0, original_idx.size(1), dtype=torch.long, device=idx.device).unsqueeze(0)
|
628 |
+
# Now cache results with the original context
|
629 |
+
original_tok_emb = embedder_fn(original_idx)
|
630 |
+
original_pos_emb = self.transformer.wpe(original_pos)
|
631 |
+
original_x = original_tok_emb + original_pos_emb
|
632 |
+
for block in self.transformer.h:
|
633 |
+
# Reset the cache for each block, and cache new result
|
634 |
+
original_x = block(original_x, reset_cache_with_num_steps_to_come=num_new_tokens)
|
635 |
+
|
636 |
+
# Sample the first token
|
637 |
+
original_x = self.transformer.ln_f(original_x)
|
638 |
+
last_logit = logit_layer(original_x[:, [-1], :])
|
639 |
+
new_idxs[:, 0] = default_sample_token(
|
640 |
+
last_logit[:, -1, :], valid_token_ranges[0], temperature, deterministic, top_k=top_k, top_p=top_p, min_tokens_to_keep=min_tokens_to_keep
|
641 |
+
)
|
642 |
+
|
643 |
+
# Generate rest of the steps
|
644 |
+
for generation_idx in range(1, num_new_tokens):
|
645 |
+
# forward the model to get the logits for the index in the sequence
|
646 |
+
# This is the position of the latest generated token, not the currently going-to-be-generated token
|
647 |
+
latest_token_pos = num_original_tokens + generation_idx - 1
|
648 |
+
# We only need to pass in the latest token
|
649 |
+
newest_idx = new_idxs[:, generation_idx - 1].unsqueeze(-1)
|
650 |
+
newest_tok_emb = embedder_fn(newest_idx)
|
651 |
+
newest_pos_emb = self.transformer.wpe(torch.tensor(latest_token_pos, dtype=torch.long, device=idx.device).unsqueeze(0))
|
652 |
+
newest_x = newest_tok_emb + newest_pos_emb
|
653 |
+
for block in self.transformer.h:
|
654 |
+
newest_x = block(newest_x, cache=True)
|
655 |
+
|
656 |
+
newest_x = self.transformer.ln_f(newest_x)
|
657 |
+
newest_logit = logit_layer(newest_x)
|
658 |
+
# Check this function isn't slowing things down noticeably
|
659 |
+
new_idxs[:, generation_idx] = default_sample_token(
|
660 |
+
newest_logit[:, -1, :], valid_token_ranges[generation_idx], temperature, deterministic, top_k=top_k, top_p=top_p, min_tokens_to_keep=min_tokens_to_keep
|
661 |
+
)
|
662 |
+
|
663 |
+
# Combine indices
|
664 |
+
new_idxs = torch.cat((idx, new_idxs), dim=1)
|
665 |
+
return new_idxs
|
wham/models/pl/__init__.py
ADDED
File without changes
|
wham/models/pl/pl_base_model.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
|
3 |
+
class BaseTrainingModel(pl.LightningModule):
|
4 |
+
def __init__(self, **kwargs):
|
5 |
+
super().__init__(**kwargs)
|
wham/models/vqgan/taming/LICENSE
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
All files under this directory are originally from the taming-transformers repository:
|
2 |
+
https://github.com/CompVis/taming-transformers
|
3 |
+
|
4 |
+
Below is a copy of the original license
|
5 |
+
------------------------------------------------------------------------------
|
6 |
+
Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
|
7 |
+
|
8 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
9 |
+
of this software and associated documentation files (the "Software"), to deal
|
10 |
+
in the Software without restriction, including without limitation the rights
|
11 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
12 |
+
copies of the Software, and to permit persons to whom the Software is
|
13 |
+
furnished to do so, subject to the following conditions:
|
14 |
+
|
15 |
+
The above copyright notice and this permission notice shall be included in all
|
16 |
+
copies or substantial portions of the Software.
|
17 |
+
|
18 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
19 |
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
20 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
21 |
+
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
22 |
+
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
23 |
+
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
24 |
+
OR OTHER DEALINGS IN THE SOFTWARE./
|
wham/models/vqgan/taming/model.py
ADDED
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# All files under this directory are originally from the taming-transformers repository:
|
2 |
+
# https://github.com/CompVis/taming-transformers
|
3 |
+
|
4 |
+
# MIT License
|
5 |
+
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
|
6 |
+
# 2023 Microsoft Research
|
7 |
+
|
8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
10 |
+
# in the Software without restriction, including without limitation the rights
|
11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
13 |
+
# furnished to do so, subject to the following conditions:
|
14 |
+
|
15 |
+
# The above copyright notice and this permission notice shall be included in all
|
16 |
+
# copies or substantial portions of the Software.
|
17 |
+
|
18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
19 |
+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
20 |
+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
21 |
+
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
22 |
+
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
23 |
+
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
24 |
+
# OR OTHER DEALINGS IN THE SOFTWARE.
|
25 |
+
|
26 |
+
import math
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
import numpy as np
|
30 |
+
|
31 |
+
|
32 |
+
def get_timestep_embedding(timesteps, embedding_dim):
|
33 |
+
"""
|
34 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
35 |
+
From Fairseq.
|
36 |
+
Build sinusoidal embeddings.
|
37 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
38 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
39 |
+
"""
|
40 |
+
assert len(timesteps.shape) == 1
|
41 |
+
|
42 |
+
half_dim = embedding_dim // 2
|
43 |
+
emb = math.log(10000) / (half_dim - 1)
|
44 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
45 |
+
emb = emb.to(device=timesteps.device)
|
46 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
47 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
48 |
+
if embedding_dim % 2 == 1: # zero pad
|
49 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
50 |
+
return emb
|
51 |
+
|
52 |
+
|
53 |
+
def nonlinearity(x):
|
54 |
+
# swish
|
55 |
+
return x * torch.sigmoid(x)
|
56 |
+
|
57 |
+
|
58 |
+
def Normalize(in_channels):
|
59 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
60 |
+
|
61 |
+
|
62 |
+
class Upsample(nn.Module):
|
63 |
+
def __init__(self, in_channels, with_conv):
|
64 |
+
super().__init__()
|
65 |
+
self.with_conv = with_conv
|
66 |
+
if self.with_conv:
|
67 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
71 |
+
if self.with_conv:
|
72 |
+
x = self.conv(x)
|
73 |
+
return x
|
74 |
+
|
75 |
+
|
76 |
+
class Downsample(nn.Module):
|
77 |
+
def __init__(self, in_channels, with_conv):
|
78 |
+
super().__init__()
|
79 |
+
self.with_conv = with_conv
|
80 |
+
if self.with_conv:
|
81 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
82 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
if self.with_conv:
|
86 |
+
pad = (0, 1, 0, 1)
|
87 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
88 |
+
x = self.conv(x)
|
89 |
+
else:
|
90 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
91 |
+
return x
|
92 |
+
|
93 |
+
|
94 |
+
class ResnetBlock(nn.Module):
|
95 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
96 |
+
super().__init__()
|
97 |
+
self.in_channels = in_channels
|
98 |
+
out_channels = in_channels if out_channels is None else out_channels
|
99 |
+
self.out_channels = out_channels
|
100 |
+
self.use_conv_shortcut = conv_shortcut
|
101 |
+
|
102 |
+
self.norm1 = Normalize(in_channels)
|
103 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
104 |
+
if temb_channels > 0:
|
105 |
+
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
106 |
+
self.norm2 = Normalize(out_channels)
|
107 |
+
self.dropout = torch.nn.Dropout(dropout)
|
108 |
+
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
109 |
+
if self.in_channels != self.out_channels:
|
110 |
+
if self.use_conv_shortcut:
|
111 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
112 |
+
else:
|
113 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
114 |
+
|
115 |
+
def forward(self, x, temb):
|
116 |
+
h = x
|
117 |
+
h = self.norm1(h)
|
118 |
+
h = nonlinearity(h)
|
119 |
+
h = self.conv1(h)
|
120 |
+
|
121 |
+
if temb is not None:
|
122 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
123 |
+
|
124 |
+
h = self.norm2(h)
|
125 |
+
h = nonlinearity(h)
|
126 |
+
h = self.dropout(h)
|
127 |
+
h = self.conv2(h)
|
128 |
+
|
129 |
+
if self.in_channels != self.out_channels:
|
130 |
+
if self.use_conv_shortcut:
|
131 |
+
x = self.conv_shortcut(x)
|
132 |
+
else:
|
133 |
+
x = self.nin_shortcut(x)
|
134 |
+
|
135 |
+
return x + h
|
136 |
+
|
137 |
+
|
138 |
+
class AttnBlock(nn.Module):
|
139 |
+
def __init__(self, in_channels):
|
140 |
+
super().__init__()
|
141 |
+
self.in_channels = in_channels
|
142 |
+
|
143 |
+
self.norm = Normalize(in_channels)
|
144 |
+
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
145 |
+
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
146 |
+
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
147 |
+
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
148 |
+
|
149 |
+
def forward(self, x):
|
150 |
+
h_ = x
|
151 |
+
h_ = self.norm(h_)
|
152 |
+
q = self.q(h_)
|
153 |
+
k = self.k(h_)
|
154 |
+
v = self.v(h_)
|
155 |
+
|
156 |
+
# compute attention
|
157 |
+
b, c, h, w = q.shape
|
158 |
+
q = q.reshape(b, c, h * w)
|
159 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
160 |
+
k = k.reshape(b, c, h * w) # b,c,hw
|
161 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
162 |
+
w_ = w_ * (int(c) ** (-0.5))
|
163 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
164 |
+
|
165 |
+
# attend to values
|
166 |
+
v = v.reshape(b, c, h * w)
|
167 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
168 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
169 |
+
h_ = h_.reshape(b, c, h, w)
|
170 |
+
|
171 |
+
h_ = self.proj_out(h_)
|
172 |
+
|
173 |
+
return x + h_
|
174 |
+
|
175 |
+
|
176 |
+
class Model(nn.Module):
|
177 |
+
def __init__(
|
178 |
+
self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True
|
179 |
+
):
|
180 |
+
super().__init__()
|
181 |
+
self.ch = ch
|
182 |
+
self.temb_ch = self.ch * 4
|
183 |
+
self.num_resolutions = len(ch_mult)
|
184 |
+
self.num_res_blocks = num_res_blocks
|
185 |
+
self.resolution = resolution
|
186 |
+
self.in_channels = in_channels
|
187 |
+
|
188 |
+
self.use_timestep = use_timestep
|
189 |
+
if self.use_timestep:
|
190 |
+
# timestep embedding
|
191 |
+
self.temb = nn.Module()
|
192 |
+
self.temb.dense = nn.ModuleList(
|
193 |
+
[
|
194 |
+
torch.nn.Linear(self.ch, self.temb_ch),
|
195 |
+
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
196 |
+
]
|
197 |
+
)
|
198 |
+
|
199 |
+
# downsampling
|
200 |
+
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
201 |
+
|
202 |
+
curr_res = resolution
|
203 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
204 |
+
self.down = nn.ModuleList()
|
205 |
+
for i_level in range(self.num_resolutions):
|
206 |
+
block = nn.ModuleList()
|
207 |
+
attn = nn.ModuleList()
|
208 |
+
block_in = ch * in_ch_mult[i_level]
|
209 |
+
block_out = ch * ch_mult[i_level]
|
210 |
+
for i_block in range(self.num_res_blocks):
|
211 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
212 |
+
block_in = block_out
|
213 |
+
if curr_res in attn_resolutions:
|
214 |
+
attn.append(AttnBlock(block_in))
|
215 |
+
down = nn.Module()
|
216 |
+
down.block = block
|
217 |
+
down.attn = attn
|
218 |
+
if i_level != self.num_resolutions - 1:
|
219 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
220 |
+
curr_res = curr_res // 2
|
221 |
+
self.down.append(down)
|
222 |
+
|
223 |
+
# middle
|
224 |
+
self.mid = nn.Module()
|
225 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
226 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
227 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
228 |
+
|
229 |
+
# upsampling
|
230 |
+
self.up = nn.ModuleList()
|
231 |
+
for i_level in reversed(range(self.num_resolutions)):
|
232 |
+
block = nn.ModuleList()
|
233 |
+
attn = nn.ModuleList()
|
234 |
+
block_out = ch * ch_mult[i_level]
|
235 |
+
skip_in = ch * ch_mult[i_level]
|
236 |
+
for i_block in range(self.num_res_blocks + 1):
|
237 |
+
if i_block == self.num_res_blocks:
|
238 |
+
skip_in = ch * in_ch_mult[i_level]
|
239 |
+
block.append(ResnetBlock(in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
240 |
+
block_in = block_out
|
241 |
+
if curr_res in attn_resolutions:
|
242 |
+
attn.append(AttnBlock(block_in))
|
243 |
+
up = nn.Module()
|
244 |
+
up.block = block
|
245 |
+
up.attn = attn
|
246 |
+
if i_level != 0:
|
247 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
248 |
+
curr_res = curr_res * 2
|
249 |
+
self.up.insert(0, up) # prepend to get consistent order
|
250 |
+
|
251 |
+
# end
|
252 |
+
self.norm_out = Normalize(block_in)
|
253 |
+
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
254 |
+
|
255 |
+
def forward(self, x, t=None):
|
256 |
+
# assert x.shape[2] == x.shape[3] == self.resolution
|
257 |
+
|
258 |
+
if self.use_timestep:
|
259 |
+
# timestep embedding
|
260 |
+
assert t is not None
|
261 |
+
temb = get_timestep_embedding(t, self.ch)
|
262 |
+
temb = self.temb.dense[0](temb)
|
263 |
+
temb = nonlinearity(temb)
|
264 |
+
temb = self.temb.dense[1](temb)
|
265 |
+
else:
|
266 |
+
temb = None
|
267 |
+
|
268 |
+
# downsampling
|
269 |
+
hs = [self.conv_in(x)]
|
270 |
+
for i_level in range(self.num_resolutions):
|
271 |
+
for i_block in range(self.num_res_blocks):
|
272 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
273 |
+
if len(self.down[i_level].attn) > 0:
|
274 |
+
h = self.down[i_level].attn[i_block](h)
|
275 |
+
hs.append(h)
|
276 |
+
if i_level != self.num_resolutions - 1:
|
277 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
278 |
+
|
279 |
+
# middle
|
280 |
+
h = hs[-1]
|
281 |
+
h = self.mid.block_1(h, temb)
|
282 |
+
h = self.mid.attn_1(h)
|
283 |
+
h = self.mid.block_2(h, temb)
|
284 |
+
|
285 |
+
# upsampling
|
286 |
+
for i_level in reversed(range(self.num_resolutions)):
|
287 |
+
for i_block in range(self.num_res_blocks + 1):
|
288 |
+
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
|
289 |
+
if len(self.up[i_level].attn) > 0:
|
290 |
+
h = self.up[i_level].attn[i_block](h)
|
291 |
+
if i_level != 0:
|
292 |
+
h = self.up[i_level].upsample(h)
|
293 |
+
|
294 |
+
# end
|
295 |
+
h = self.norm_out(h)
|
296 |
+
h = nonlinearity(h)
|
297 |
+
h = self.conv_out(h)
|
298 |
+
return h
|
299 |
+
|
300 |
+
|
301 |
+
class Encoder(nn.Module):
|
302 |
+
def __init__(
|
303 |
+
self,
|
304 |
+
*,
|
305 |
+
ch,
|
306 |
+
out_ch,
|
307 |
+
ch_mult=(1, 2, 4, 8),
|
308 |
+
num_res_blocks,
|
309 |
+
attn_resolutions,
|
310 |
+
dropout=0.0,
|
311 |
+
resamp_with_conv=True,
|
312 |
+
in_channels,
|
313 |
+
resolution,
|
314 |
+
z_channels,
|
315 |
+
double_z=True,
|
316 |
+
**ignore_kwargs
|
317 |
+
):
|
318 |
+
super().__init__()
|
319 |
+
self.ch = ch
|
320 |
+
self.temb_ch = 0
|
321 |
+
self.num_resolutions = len(ch_mult)
|
322 |
+
self.num_res_blocks = num_res_blocks
|
323 |
+
self.resolution = resolution
|
324 |
+
self.in_channels = in_channels
|
325 |
+
|
326 |
+
# downsampling
|
327 |
+
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
328 |
+
|
329 |
+
curr_res = resolution
|
330 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
331 |
+
self.down = nn.ModuleList()
|
332 |
+
for i_level in range(self.num_resolutions):
|
333 |
+
block = nn.ModuleList()
|
334 |
+
attn = nn.ModuleList()
|
335 |
+
block_in = ch * in_ch_mult[i_level]
|
336 |
+
block_out = ch * ch_mult[i_level]
|
337 |
+
for i_block in range(self.num_res_blocks):
|
338 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
339 |
+
block_in = block_out
|
340 |
+
if curr_res in attn_resolutions:
|
341 |
+
attn.append(AttnBlock(block_in))
|
342 |
+
down = nn.Module()
|
343 |
+
down.block = block
|
344 |
+
down.attn = attn
|
345 |
+
if i_level != self.num_resolutions - 1:
|
346 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
347 |
+
curr_res = curr_res // 2
|
348 |
+
self.down.append(down)
|
349 |
+
|
350 |
+
# middle
|
351 |
+
self.mid = nn.Module()
|
352 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
353 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
354 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
355 |
+
|
356 |
+
# end
|
357 |
+
self.norm_out = Normalize(block_in)
|
358 |
+
self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1)
|
359 |
+
|
360 |
+
def forward(self, x):
|
361 |
+
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
362 |
+
|
363 |
+
# timestep embedding
|
364 |
+
temb = None
|
365 |
+
|
366 |
+
# downsampling
|
367 |
+
hs = [self.conv_in(x)]
|
368 |
+
for i_level in range(self.num_resolutions):
|
369 |
+
for i_block in range(self.num_res_blocks):
|
370 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
371 |
+
if len(self.down[i_level].attn) > 0:
|
372 |
+
h = self.down[i_level].attn[i_block](h)
|
373 |
+
hs.append(h)
|
374 |
+
if i_level != self.num_resolutions - 1:
|
375 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
376 |
+
|
377 |
+
# middle
|
378 |
+
h = hs[-1]
|
379 |
+
h = self.mid.block_1(h, temb)
|
380 |
+
h = self.mid.attn_1(h)
|
381 |
+
h = self.mid.block_2(h, temb)
|
382 |
+
|
383 |
+
# end
|
384 |
+
h = self.norm_out(h)
|
385 |
+
h = nonlinearity(h)
|
386 |
+
h = self.conv_out(h)
|
387 |
+
return h
|
388 |
+
|
389 |
+
|
390 |
+
class Decoder(nn.Module):
|
391 |
+
def __init__(
|
392 |
+
self,
|
393 |
+
*,
|
394 |
+
ch,
|
395 |
+
out_ch,
|
396 |
+
ch_mult=(1, 2, 4, 8),
|
397 |
+
num_res_blocks,
|
398 |
+
attn_resolutions,
|
399 |
+
dropout=0.0,
|
400 |
+
resamp_with_conv=True,
|
401 |
+
in_channels,
|
402 |
+
resolution,
|
403 |
+
z_channels,
|
404 |
+
give_pre_end=False,
|
405 |
+
**ignorekwargs
|
406 |
+
):
|
407 |
+
super().__init__()
|
408 |
+
self.ch = ch
|
409 |
+
self.temb_ch = 0
|
410 |
+
self.num_resolutions = len(ch_mult)
|
411 |
+
self.num_res_blocks = num_res_blocks
|
412 |
+
self.resolution = resolution
|
413 |
+
self.in_channels = in_channels
|
414 |
+
self.give_pre_end = give_pre_end
|
415 |
+
|
416 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
417 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
418 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
419 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
420 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
421 |
+
|
422 |
+
# z to block_in
|
423 |
+
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
424 |
+
|
425 |
+
# middle
|
426 |
+
self.mid = nn.Module()
|
427 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
428 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
429 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
430 |
+
|
431 |
+
# upsampling
|
432 |
+
self.up = nn.ModuleList()
|
433 |
+
for i_level in reversed(range(self.num_resolutions)):
|
434 |
+
block = nn.ModuleList()
|
435 |
+
attn = nn.ModuleList()
|
436 |
+
block_out = ch * ch_mult[i_level]
|
437 |
+
for i_block in range(self.num_res_blocks + 1):
|
438 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
439 |
+
block_in = block_out
|
440 |
+
if curr_res in attn_resolutions:
|
441 |
+
attn.append(AttnBlock(block_in))
|
442 |
+
up = nn.Module()
|
443 |
+
up.block = block
|
444 |
+
up.attn = attn
|
445 |
+
if i_level != 0:
|
446 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
447 |
+
curr_res = curr_res * 2
|
448 |
+
self.up.insert(0, up) # prepend to get consistent order
|
449 |
+
|
450 |
+
# end
|
451 |
+
self.norm_out = Normalize(block_in)
|
452 |
+
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
453 |
+
|
454 |
+
def forward(self, z):
|
455 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
456 |
+
self.last_z_shape = z.shape
|
457 |
+
|
458 |
+
# timestep embedding
|
459 |
+
temb = None
|
460 |
+
|
461 |
+
# z to block_in
|
462 |
+
h = self.conv_in(z)
|
463 |
+
|
464 |
+
# middle
|
465 |
+
h = self.mid.block_1(h, temb)
|
466 |
+
h = self.mid.attn_1(h)
|
467 |
+
h = self.mid.block_2(h, temb)
|
468 |
+
|
469 |
+
# upsampling
|
470 |
+
for i_level in reversed(range(self.num_resolutions)):
|
471 |
+
for i_block in range(self.num_res_blocks + 1):
|
472 |
+
h = self.up[i_level].block[i_block](h, temb)
|
473 |
+
if len(self.up[i_level].attn) > 0:
|
474 |
+
h = self.up[i_level].attn[i_block](h)
|
475 |
+
if i_level != 0:
|
476 |
+
h = self.up[i_level].upsample(h)
|
477 |
+
|
478 |
+
# end
|
479 |
+
if self.give_pre_end:
|
480 |
+
return h
|
481 |
+
|
482 |
+
h = self.norm_out(h)
|
483 |
+
h = nonlinearity(h)
|
484 |
+
h = self.conv_out(h)
|
485 |
+
return h
|
486 |
+
|
487 |
+
|
488 |
+
class VUNet(nn.Module):
|
489 |
+
def __init__(
|
490 |
+
self,
|
491 |
+
*,
|
492 |
+
ch,
|
493 |
+
out_ch,
|
494 |
+
ch_mult=(1, 2, 4, 8),
|
495 |
+
num_res_blocks,
|
496 |
+
attn_resolutions,
|
497 |
+
dropout=0.0,
|
498 |
+
resamp_with_conv=True,
|
499 |
+
in_channels,
|
500 |
+
c_channels,
|
501 |
+
resolution,
|
502 |
+
z_channels,
|
503 |
+
use_timestep=False,
|
504 |
+
**ignore_kwargs
|
505 |
+
):
|
506 |
+
super().__init__()
|
507 |
+
self.ch = ch
|
508 |
+
self.temb_ch = self.ch * 4
|
509 |
+
self.num_resolutions = len(ch_mult)
|
510 |
+
self.num_res_blocks = num_res_blocks
|
511 |
+
self.resolution = resolution
|
512 |
+
|
513 |
+
self.use_timestep = use_timestep
|
514 |
+
if self.use_timestep:
|
515 |
+
# timestep embedding
|
516 |
+
self.temb = nn.Module()
|
517 |
+
self.temb.dense = nn.ModuleList(
|
518 |
+
[
|
519 |
+
torch.nn.Linear(self.ch, self.temb_ch),
|
520 |
+
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
521 |
+
]
|
522 |
+
)
|
523 |
+
|
524 |
+
# downsampling
|
525 |
+
self.conv_in = torch.nn.Conv2d(c_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
526 |
+
|
527 |
+
curr_res = resolution
|
528 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
529 |
+
self.down = nn.ModuleList()
|
530 |
+
for i_level in range(self.num_resolutions):
|
531 |
+
block = nn.ModuleList()
|
532 |
+
attn = nn.ModuleList()
|
533 |
+
block_in = ch * in_ch_mult[i_level]
|
534 |
+
block_out = ch * ch_mult[i_level]
|
535 |
+
for i_block in range(self.num_res_blocks):
|
536 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
537 |
+
block_in = block_out
|
538 |
+
if curr_res in attn_resolutions:
|
539 |
+
attn.append(AttnBlock(block_in))
|
540 |
+
down = nn.Module()
|
541 |
+
down.block = block
|
542 |
+
down.attn = attn
|
543 |
+
if i_level != self.num_resolutions - 1:
|
544 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
545 |
+
curr_res = curr_res // 2
|
546 |
+
self.down.append(down)
|
547 |
+
|
548 |
+
self.z_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=1, stride=1, padding=0)
|
549 |
+
# middle
|
550 |
+
self.mid = nn.Module()
|
551 |
+
self.mid.block_1 = ResnetBlock(in_channels=2 * block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
552 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
553 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
554 |
+
|
555 |
+
# upsampling
|
556 |
+
self.up = nn.ModuleList()
|
557 |
+
for i_level in reversed(range(self.num_resolutions)):
|
558 |
+
block = nn.ModuleList()
|
559 |
+
attn = nn.ModuleList()
|
560 |
+
block_out = ch * ch_mult[i_level]
|
561 |
+
skip_in = ch * ch_mult[i_level]
|
562 |
+
for i_block in range(self.num_res_blocks + 1):
|
563 |
+
if i_block == self.num_res_blocks:
|
564 |
+
skip_in = ch * in_ch_mult[i_level]
|
565 |
+
block.append(ResnetBlock(in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
566 |
+
block_in = block_out
|
567 |
+
if curr_res in attn_resolutions:
|
568 |
+
attn.append(AttnBlock(block_in))
|
569 |
+
up = nn.Module()
|
570 |
+
up.block = block
|
571 |
+
up.attn = attn
|
572 |
+
if i_level != 0:
|
573 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
574 |
+
curr_res = curr_res * 2
|
575 |
+
self.up.insert(0, up) # prepend to get consistent order
|
576 |
+
|
577 |
+
# end
|
578 |
+
self.norm_out = Normalize(block_in)
|
579 |
+
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
580 |
+
|
581 |
+
def forward(self, x, z):
|
582 |
+
# assert x.shape[2] == x.shape[3] == self.resolution
|
583 |
+
|
584 |
+
if self.use_timestep:
|
585 |
+
# timestep embedding
|
586 |
+
assert t is not None
|
587 |
+
temb = get_timestep_embedding(t, self.ch)
|
588 |
+
temb = self.temb.dense[0](temb)
|
589 |
+
temb = nonlinearity(temb)
|
590 |
+
temb = self.temb.dense[1](temb)
|
591 |
+
else:
|
592 |
+
temb = None
|
593 |
+
|
594 |
+
# downsampling
|
595 |
+
hs = [self.conv_in(x)]
|
596 |
+
for i_level in range(self.num_resolutions):
|
597 |
+
for i_block in range(self.num_res_blocks):
|
598 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
599 |
+
if len(self.down[i_level].attn) > 0:
|
600 |
+
h = self.down[i_level].attn[i_block](h)
|
601 |
+
hs.append(h)
|
602 |
+
if i_level != self.num_resolutions - 1:
|
603 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
604 |
+
|
605 |
+
# middle
|
606 |
+
h = hs[-1]
|
607 |
+
z = self.z_in(z)
|
608 |
+
h = torch.cat((h, z), dim=1)
|
609 |
+
h = self.mid.block_1(h, temb)
|
610 |
+
h = self.mid.attn_1(h)
|
611 |
+
h = self.mid.block_2(h, temb)
|
612 |
+
|
613 |
+
# upsampling
|
614 |
+
for i_level in reversed(range(self.num_resolutions)):
|
615 |
+
for i_block in range(self.num_res_blocks + 1):
|
616 |
+
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
|
617 |
+
if len(self.up[i_level].attn) > 0:
|
618 |
+
h = self.up[i_level].attn[i_block](h)
|
619 |
+
if i_level != 0:
|
620 |
+
h = self.up[i_level].upsample(h)
|
621 |
+
|
622 |
+
# end
|
623 |
+
h = self.norm_out(h)
|
624 |
+
h = nonlinearity(h)
|
625 |
+
h = self.conv_out(h)
|
626 |
+
return h
|
627 |
+
|
628 |
+
|
629 |
+
class SimpleDecoder(nn.Module):
|
630 |
+
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
631 |
+
super().__init__()
|
632 |
+
self.model = nn.ModuleList(
|
633 |
+
[
|
634 |
+
nn.Conv2d(in_channels, in_channels, 1),
|
635 |
+
ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
|
636 |
+
ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0),
|
637 |
+
ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
|
638 |
+
nn.Conv2d(2 * in_channels, in_channels, 1),
|
639 |
+
Upsample(in_channels, with_conv=True),
|
640 |
+
]
|
641 |
+
)
|
642 |
+
# end
|
643 |
+
self.norm_out = Normalize(in_channels)
|
644 |
+
self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
645 |
+
|
646 |
+
def forward(self, x):
|
647 |
+
for i, layer in enumerate(self.model):
|
648 |
+
if i in [1, 2, 3]:
|
649 |
+
x = layer(x, None)
|
650 |
+
else:
|
651 |
+
x = layer(x)
|
652 |
+
|
653 |
+
h = self.norm_out(x)
|
654 |
+
h = nonlinearity(h)
|
655 |
+
x = self.conv_out(h)
|
656 |
+
return x
|
657 |
+
|
658 |
+
|
659 |
+
class UpsampleDecoder(nn.Module):
|
660 |
+
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
|
661 |
+
super().__init__()
|
662 |
+
# upsampling
|
663 |
+
self.temb_ch = 0
|
664 |
+
self.num_resolutions = len(ch_mult)
|
665 |
+
self.num_res_blocks = num_res_blocks
|
666 |
+
block_in = in_channels
|
667 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
668 |
+
self.res_blocks = nn.ModuleList()
|
669 |
+
self.upsample_blocks = nn.ModuleList()
|
670 |
+
for i_level in range(self.num_resolutions):
|
671 |
+
res_block = []
|
672 |
+
block_out = ch * ch_mult[i_level]
|
673 |
+
for i_block in range(self.num_res_blocks + 1):
|
674 |
+
res_block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
675 |
+
block_in = block_out
|
676 |
+
self.res_blocks.append(nn.ModuleList(res_block))
|
677 |
+
if i_level != self.num_resolutions - 1:
|
678 |
+
self.upsample_blocks.append(Upsample(block_in, True))
|
679 |
+
curr_res = curr_res * 2
|
680 |
+
|
681 |
+
# end
|
682 |
+
self.norm_out = Normalize(block_in)
|
683 |
+
self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
|
684 |
+
|
685 |
+
def forward(self, x):
|
686 |
+
# upsampling
|
687 |
+
h = x
|
688 |
+
for k, i_level in enumerate(range(self.num_resolutions)):
|
689 |
+
for i_block in range(self.num_res_blocks + 1):
|
690 |
+
h = self.res_blocks[i_level][i_block](h, None)
|
691 |
+
if i_level != self.num_resolutions - 1:
|
692 |
+
h = self.upsample_blocks[k](h)
|
693 |
+
h = self.norm_out(h)
|
694 |
+
h = nonlinearity(h)
|
695 |
+
h = self.conv_out(h)
|
696 |
+
return h
|
wham/models/vqgan/taming/quantize.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# All files under this directory are originally from the taming-transformers repository:
|
2 |
+
# https://github.com/CompVis/taming-transformers
|
3 |
+
|
4 |
+
# MIT License
|
5 |
+
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
|
6 |
+
# 2023 Microsoft Research
|
7 |
+
|
8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
10 |
+
# in the Software without restriction, including without limitation the rights
|
11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
13 |
+
# furnished to do so, subject to the following conditions:
|
14 |
+
|
15 |
+
# The above copyright notice and this permission notice shall be included in all
|
16 |
+
# copies or substantial portions of the Software.
|
17 |
+
|
18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
19 |
+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
20 |
+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
21 |
+
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
22 |
+
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
23 |
+
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
24 |
+
# OR OTHER DEALINGS IN THE SOFTWARE.
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
import numpy as np
|
29 |
+
from einops import rearrange
|
30 |
+
|
31 |
+
|
32 |
+
class VectorQuantizer2(nn.Module):
|
33 |
+
"""
|
34 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
35 |
+
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
36 |
+
"""
|
37 |
+
|
38 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
39 |
+
# backwards compatibility we use the buggy version by default, but you can
|
40 |
+
# specify legacy=False to fix it.
|
41 |
+
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
|
42 |
+
super().__init__()
|
43 |
+
self.n_e = n_e
|
44 |
+
self.e_dim = e_dim
|
45 |
+
self.beta = beta
|
46 |
+
self.legacy = legacy
|
47 |
+
|
48 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
49 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
50 |
+
|
51 |
+
self.remap = remap
|
52 |
+
if self.remap is not None:
|
53 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
54 |
+
self.re_embed = self.used.shape[0]
|
55 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
56 |
+
if self.unknown_index == "extra":
|
57 |
+
self.unknown_index = self.re_embed
|
58 |
+
self.re_embed = self.re_embed + 1
|
59 |
+
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " f"Using {self.unknown_index} for unknown indices.")
|
60 |
+
else:
|
61 |
+
self.re_embed = n_e
|
62 |
+
|
63 |
+
self.sane_index_shape = sane_index_shape
|
64 |
+
|
65 |
+
def remap_to_used(self, inds):
|
66 |
+
ishape = inds.shape
|
67 |
+
assert len(ishape) > 1
|
68 |
+
inds = inds.reshape(ishape[0], -1)
|
69 |
+
used = self.used.to(inds)
|
70 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
71 |
+
new = match.argmax(-1)
|
72 |
+
unknown = match.sum(2) < 1
|
73 |
+
if self.unknown_index == "random":
|
74 |
+
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
75 |
+
else:
|
76 |
+
new[unknown] = self.unknown_index
|
77 |
+
return new.reshape(ishape)
|
78 |
+
|
79 |
+
def unmap_to_all(self, inds):
|
80 |
+
ishape = inds.shape
|
81 |
+
assert len(ishape) > 1
|
82 |
+
inds = inds.reshape(ishape[0], -1)
|
83 |
+
used = self.used.to(inds)
|
84 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
85 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
86 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
87 |
+
return back.reshape(ishape)
|
88 |
+
|
89 |
+
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
90 |
+
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
91 |
+
assert rescale_logits == False, "Only for interface compatible with Gumbel"
|
92 |
+
assert return_logits == False, "Only for interface compatible with Gumbel"
|
93 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
94 |
+
z = rearrange(z, "b c h w -> b h w c").contiguous()
|
95 |
+
z_flattened = z.view(-1, self.e_dim)
|
96 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
97 |
+
|
98 |
+
d = (
|
99 |
+
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
100 |
+
+ torch.sum(self.embedding.weight**2, dim=1)
|
101 |
+
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
102 |
+
)
|
103 |
+
|
104 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
105 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
106 |
+
perplexity = None
|
107 |
+
min_encodings = None
|
108 |
+
|
109 |
+
# compute loss for embedding
|
110 |
+
if not self.legacy:
|
111 |
+
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
112 |
+
else:
|
113 |
+
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
114 |
+
|
115 |
+
# preserve gradients
|
116 |
+
z_q = z + (z_q - z).detach()
|
117 |
+
|
118 |
+
# reshape back to match original input shape
|
119 |
+
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
|
120 |
+
|
121 |
+
if self.remap is not None:
|
122 |
+
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
123 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
124 |
+
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
125 |
+
|
126 |
+
if self.sane_index_shape:
|
127 |
+
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
128 |
+
|
129 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
130 |
+
|
131 |
+
def get_codebook_entry(self, indices, shape):
|
132 |
+
# shape specifying (batch, height, width, channel)
|
133 |
+
if self.remap is not None:
|
134 |
+
indices = indices.reshape(shape[0], -1) # add batch axis
|
135 |
+
indices = self.unmap_to_all(indices)
|
136 |
+
indices = indices.reshape(-1) # flatten again
|
137 |
+
|
138 |
+
# get quantized latent vectors
|
139 |
+
z_q = self.embedding(indices)
|
140 |
+
|
141 |
+
if shape is not None:
|
142 |
+
z_q = z_q.view(shape)
|
143 |
+
# reshape back to match original input shape
|
144 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
145 |
+
|
146 |
+
return z_q
|
wham/models/vqgan/taming_vq_model.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Wrapper for the VQ models from the taming-transformers repo
|
2 |
+
# https://github.com/CompVis/taming-transformers
|
3 |
+
|
4 |
+
from typing import Any, Mapping
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from wham.models.vqgan.taming.model import Encoder, Decoder
|
10 |
+
from wham.models.vqgan.taming.quantize import VectorQuantizer2 as VectorQuantizer
|
11 |
+
|
12 |
+
from wham.models.wham_base.tensor_spaces import TensorSpace
|
13 |
+
from wham.models.wham_base.encoder_decoder import EncoderDecoderBase
|
14 |
+
|
15 |
+
|
16 |
+
HARDCODED_IMAGE_SIZE = 128
|
17 |
+
|
18 |
+
|
19 |
+
def taming_vq_preprocess_images(imgs):
|
20 |
+
"""Normalize images (as pytorch tensor uint8s) as in taming-transformers"""
|
21 |
+
return imgs.float() / 127.5 - 1.0
|
22 |
+
|
23 |
+
|
24 |
+
def taming_vq_revert_preprocess_images(imgs):
|
25 |
+
"""Revert preprocessing of images from taming to uint8 as in taming-transformers"""
|
26 |
+
# Clamp first
|
27 |
+
imgs = torch.clamp(imgs, -1.0, 1.0)
|
28 |
+
return ((imgs + 1) * 127.5).byte()
|
29 |
+
|
30 |
+
|
31 |
+
class _VQModelFromTamingRepository(pl.LightningModule):
|
32 |
+
"""
|
33 |
+
This aims to be the original VQ model from the taming-transformers repo with as little modifications as possible. This should not be used directly.
|
34 |
+
Source: https://github.com/CompVis/taming-transformers/blob/master/taming/models/vqgan.py
|
35 |
+
|
36 |
+
MIT License
|
37 |
+
Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
|
38 |
+
2023 Microsoft Research
|
39 |
+
|
40 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
41 |
+
of this software and associated documentation files (the "Software"), to deal
|
42 |
+
in the Software without restriction, including without limitation the rights
|
43 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
44 |
+
copies of the Software, and to permit persons to whom the Software is
|
45 |
+
furnished to do so, subject to the following conditions:
|
46 |
+
|
47 |
+
The above copyright notice and this permission notice shall be included in all
|
48 |
+
copies or substantial portions of the Software.
|
49 |
+
|
50 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
51 |
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
52 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
53 |
+
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
54 |
+
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
55 |
+
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
56 |
+
OR OTHER DEALINGS IN THE SOFTWARE.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
ddconfig,
|
62 |
+
n_embed,
|
63 |
+
embed_dim,
|
64 |
+
ckpt_path=None,
|
65 |
+
ignore_keys=[],
|
66 |
+
image_key="image",
|
67 |
+
colorize_nlabels=None,
|
68 |
+
monitor=None,
|
69 |
+
remap=None,
|
70 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
71 |
+
):
|
72 |
+
super().__init__()
|
73 |
+
self.image_key = image_key
|
74 |
+
self.encoder = Encoder(**ddconfig)
|
75 |
+
self.decoder = Decoder(**ddconfig)
|
76 |
+
# NOTE: Loss is disabled for this repo (we only want inference)
|
77 |
+
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape)
|
78 |
+
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
79 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
80 |
+
# Note: the '!= "None"' check is for checkpoints that mistakenly stored the None as a string
|
81 |
+
if ckpt_path is not None and ckpt_path != "None":
|
82 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
83 |
+
self.image_key = image_key
|
84 |
+
if colorize_nlabels is not None:
|
85 |
+
assert type(colorize_nlabels) == int
|
86 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
87 |
+
if monitor is not None:
|
88 |
+
self.monitor = monitor
|
89 |
+
|
90 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
91 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
92 |
+
keys = list(sd.keys())
|
93 |
+
for k in keys:
|
94 |
+
for ik in ignore_keys:
|
95 |
+
if k.startswith(ik):
|
96 |
+
print("Deleting key {} from state_dict.".format(k))
|
97 |
+
del sd[k]
|
98 |
+
self.load_state_dict(sd, strict=False)
|
99 |
+
print(f"Restored from {path}")
|
100 |
+
|
101 |
+
def encode(self, x):
|
102 |
+
h = self.encoder(x)
|
103 |
+
h = self.quant_conv(h)
|
104 |
+
quant, emb_loss, info = self.quantize(h)
|
105 |
+
return quant, emb_loss, info
|
106 |
+
|
107 |
+
def decode(self, quant):
|
108 |
+
quant = self.post_quant_conv(quant)
|
109 |
+
dec = self.decoder(quant)
|
110 |
+
return dec
|
111 |
+
|
112 |
+
def forward(self, input):
|
113 |
+
quant, diff, _ = self.encode(input)
|
114 |
+
dec = self.decode(quant)
|
115 |
+
return dec, diff
|
116 |
+
|
117 |
+
def get_input(self, batch, k):
|
118 |
+
x = batch[k]
|
119 |
+
if len(x.shape) == 3:
|
120 |
+
x = x[..., None]
|
121 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
122 |
+
return x.float()
|
123 |
+
|
124 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
125 |
+
raise NotImplementedError("This copy of the model code does not support training")
|
126 |
+
|
127 |
+
def validation_step(self, batch, batch_idx):
|
128 |
+
raise NotImplementedError("This copy of the model code does not support training")
|
129 |
+
|
130 |
+
def configure_optimizers(self):
|
131 |
+
raise NotImplementedError("This copy of the model code does not support training")
|
132 |
+
|
133 |
+
def get_last_layer(self):
|
134 |
+
return self.decoder.conv_out.weight
|
135 |
+
|
136 |
+
def log_images(self, batch, **kwargs):
|
137 |
+
log = dict()
|
138 |
+
x = self.get_input(batch, self.image_key)
|
139 |
+
x = x.to(self.device)
|
140 |
+
xrec, _ = self(x)
|
141 |
+
if x.shape[1] > 3:
|
142 |
+
# colorize with random projection
|
143 |
+
assert xrec.shape[1] > 3
|
144 |
+
x = self.to_rgb(x)
|
145 |
+
xrec = self.to_rgb(xrec)
|
146 |
+
log["inputs"] = x
|
147 |
+
log["reconstructions"] = xrec
|
148 |
+
return log
|
149 |
+
|
150 |
+
def to_rgb(self, x):
|
151 |
+
assert self.image_key == "segmentation"
|
152 |
+
if not hasattr(self, "colorize"):
|
153 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
154 |
+
x = F.conv2d(x, weight=self.colorize)
|
155 |
+
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
156 |
+
return x
|
157 |
+
|
158 |
+
|
159 |
+
class TamingVQModel(EncoderDecoderBase):
|
160 |
+
|
161 |
+
__DEBUG_CREATION_KWARGS__ = {
|
162 |
+
"ckpt_path": None,
|
163 |
+
"model_spec": {
|
164 |
+
"taming_n_embed": 16,
|
165 |
+
"taming_embed_dim": 8,
|
166 |
+
"taming_num_indices_per_axis": 8,
|
167 |
+
"taming_ddconfig": {
|
168 |
+
"double_z": False,
|
169 |
+
"z_channels": 16,
|
170 |
+
"resolution": 128,
|
171 |
+
"in_channels": 3,
|
172 |
+
"out_ch": 3,
|
173 |
+
"ch": 128,
|
174 |
+
"ch_mult": [1, 1, 1, 1, 1],
|
175 |
+
"num_res_blocks": 1,
|
176 |
+
"attn_resolutions": [16],
|
177 |
+
"dropout": 0.0,
|
178 |
+
},
|
179 |
+
},
|
180 |
+
}
|
181 |
+
|
182 |
+
def __init__(self, model_spec, ckpt_path, **kwargs):
|
183 |
+
super().__init__()
|
184 |
+
self._vocab_size = model_spec["taming_n_embed"]
|
185 |
+
self.num_indices_per_axis = model_spec["taming_num_indices_per_axis"]
|
186 |
+
self.num_indices_total = self.num_indices_per_axis**2
|
187 |
+
self.taming_embed_dim = model_spec["taming_embed_dim"]
|
188 |
+
taming_ddconfig = model_spec.get("taming_ddconfig", None)
|
189 |
+
if taming_ddconfig is None:
|
190 |
+
raise ValueError("To run TamingVQModel, specify model_spec.taming_ddconfig, which should match the ddconfig used when training the model")
|
191 |
+
|
192 |
+
self.vq_model = _VQModelFromTamingRepository(taming_ddconfig, self._vocab_size, self.taming_embed_dim, ckpt_path=ckpt_path)
|
193 |
+
|
194 |
+
resolution = taming_ddconfig["resolution"]
|
195 |
+
in_channels = taming_ddconfig["in_channels"]
|
196 |
+
self.world_space = TensorSpace((in_channels, resolution, resolution), dtype=torch.uint8, low=0, high=255)
|
197 |
+
self.encoder_space = TensorSpace((self.num_indices_total,), dtype=torch.long, low=0, high=self.vocab_size - 1)
|
198 |
+
|
199 |
+
@property
|
200 |
+
def vocab_size(self):
|
201 |
+
"""Return the number of entries in the codebook."""
|
202 |
+
return self._vocab_size
|
203 |
+
|
204 |
+
@property
|
205 |
+
def encoded_bottleneck_dim(self):
|
206 |
+
"""Return the dimensionality of the latent vector encoded into codebook indices."""
|
207 |
+
return self.num_indices_total
|
208 |
+
|
209 |
+
def _preprocess_images(self, images):
|
210 |
+
"""Preprocess images (B, C, H, W)"""
|
211 |
+
return taming_vq_preprocess_images(images)
|
212 |
+
|
213 |
+
def _revert_image_preprocess(self, x_batch):
|
214 |
+
"""Revert the preprocessing done in _preprocess_images"""
|
215 |
+
return taming_vq_revert_preprocess_images(x_batch)
|
216 |
+
|
217 |
+
def decode_from_encoding_indices(self, encoding_indices, return_vq_embeddings=False):
|
218 |
+
"""Return decoded images (B, C, H, W) for a batch of encoding indices (B, self.encoded_bottleneck_dim)"""
|
219 |
+
batch_size = encoding_indices.shape[0]
|
220 |
+
z = self.vq_model.quantize.get_codebook_entry(encoding_indices, shape=(batch_size, self.num_indices_per_axis, self.num_indices_per_axis, self.taming_embed_dim))
|
221 |
+
data_recon = self.vq_model.decode(z)
|
222 |
+
# Denormalize and cast to uint8
|
223 |
+
data_recon = self._revert_image_preprocess(data_recon)
|
224 |
+
if return_vq_embeddings:
|
225 |
+
return data_recon, z
|
226 |
+
return data_recon
|
227 |
+
|
228 |
+
def get_encoding_indices_for_images(self, images):
|
229 |
+
"""
|
230 |
+
Return encoding indices (B, self.encoded_bottleneck_dim) for a batch of images (B, C, H, W).
|
231 |
+
Useful auxiliary method for testing.
|
232 |
+
"""
|
233 |
+
x_batch = self._preprocess_images(images)
|
234 |
+
_, _, (_, _, encoding_indices) = self.vq_model.encode(x_batch)
|
235 |
+
# Split back into (B, self.encoded_bottleneck_dim)
|
236 |
+
encoding_indices = encoding_indices.view(images.shape[0], -1)
|
237 |
+
return encoding_indices
|
238 |
+
|
239 |
+
def forward_returning_action_and_embedding(self, states, actions_input, timesteps, attention_mask, images):
|
240 |
+
seq_len_dim = 1
|
241 |
+
assert images.shape[seq_len_dim] == 1, f"We require seq_len==1, but provided {images.shape[seq_len_dim]}."
|
242 |
+
images = images.squeeze(dim=seq_len_dim) # get rid of timestep dimension
|
243 |
+
x_batch = self._preprocess_images(images)
|
244 |
+
quant, _, (_, _, encoding_indices) = self.vq_model.encode(x_batch)
|
245 |
+
# Split back into (B, self.encoded_bottleneck_dim)
|
246 |
+
encoding_indices = encoding_indices.reshape(quant.shape[0], 1, quant.shape[2], quant.shape[3])
|
247 |
+
quant = quant.unsqueeze(seq_len_dim)
|
248 |
+
return None, {"quantized": quant, "encoding_indices": encoding_indices}
|
249 |
+
|
250 |
+
def _encode(self, world_space_tensor: torch.tensor) -> torch.tensor:
|
251 |
+
batch, time = world_space_tensor.shape[:2]
|
252 |
+
world_space_tensor = world_space_tensor.view(batch * time, *world_space_tensor.shape[2:])
|
253 |
+
encodings = self.get_encoding_indices_for_images(world_space_tensor)
|
254 |
+
# Reshape back to (batch, time, ...)
|
255 |
+
encodings = encodings.view(batch, time, -1)
|
256 |
+
return encodings
|
257 |
+
|
258 |
+
def _decode(self, encoder_space_tensor: torch.tensor) -> torch.tensor:
|
259 |
+
batch, time = encoder_space_tensor.shape[:2]
|
260 |
+
encoder_space_tensor = encoder_space_tensor.view(batch * time, *encoder_space_tensor.shape[2:])
|
261 |
+
decoded = self.decode_from_encoding_indices(encoder_space_tensor)
|
262 |
+
# Reshape back to (batch, time, ...)
|
263 |
+
decoded = decoded.view(batch, time, *decoded.shape[1:])
|
264 |
+
return decoded
|
wham/models/vqgan/vqgan.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from wham.models.wham_base.tensor_spaces import TensorSpace
|
7 |
+
from wham.models.wham_base.encoder_decoder import EncoderDecoderBase
|
8 |
+
|
9 |
+
from wham.models.vqgan import vqgan_models as vqgan
|
10 |
+
from wham.models.vqvae.vqvae_utils import make_grid, normalise_rgb, rev_normalise_rgb
|
11 |
+
|
12 |
+
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
13 |
+
from pytorch_lightning.loggers.wandb import WandbLogger
|
14 |
+
|
15 |
+
TARGET_GAN_UPDATE = 5
|
16 |
+
GAN_DWEIGHT_MAX = 250
|
17 |
+
GAN_LOGIT_CAP = 5.0
|
18 |
+
MAX_PIXEL_WEIGHTING = 0.1
|
19 |
+
|
20 |
+
# The GAN parts are from Taming Transformers (https://github.com/CompVis/taming-transformers)
|
21 |
+
"""
|
22 |
+
ViT-VQGAN is based on:
|
23 |
+
Yu, Jiahui, et al. "Vector-quantized image modeling with improved vqgan."
|
24 |
+
ICLR 2022
|
25 |
+
"""
|
26 |
+
|
27 |
+
|
28 |
+
def create_vqgan_model_for_training(variant):
|
29 |
+
return VQGANModel(variant=variant)
|
30 |
+
|
31 |
+
|
32 |
+
class VQGANModel(EncoderDecoderBase):
|
33 |
+
@classmethod
|
34 |
+
def create_from_variant(cls, variant):
|
35 |
+
return VQGANModel(variant=variant)
|
36 |
+
|
37 |
+
def __init__(self, variant=None, ckpt_path=None, model_spec=None):
|
38 |
+
super().__init__()
|
39 |
+
self.save_hyperparameters()
|
40 |
+
self.variant = variant
|
41 |
+
if model_spec is not None:
|
42 |
+
self.model_spec = model_spec
|
43 |
+
else:
|
44 |
+
self.model_spec = variant["model_spec"]
|
45 |
+
|
46 |
+
# Batches of images we will use for logging
|
47 |
+
self.reference_x_batch = None # Same images used throughout training to see progress of the model
|
48 |
+
self.random_batch = None # Different images every iteration
|
49 |
+
|
50 |
+
if variant is None and "image_size_per_y_axis" in self.model_spec:
|
51 |
+
self.image_size_x = self.model_spec["image_size_per_x_axis"]
|
52 |
+
self.image_size_y = self.model_spec["image_size_per_y_axis"]
|
53 |
+
else:
|
54 |
+
assert "image_size_per_x_axis" in variant and "image_size_per_y_axis" in variant, "Please provide the image size as separate x and y for the VQGAN model"
|
55 |
+
self.image_size_x = variant["image_size_per_x_axis"]
|
56 |
+
self.image_size_y = variant["image_size_per_y_axis"]
|
57 |
+
|
58 |
+
self._embedding_dim = self.model_spec["embedding_dim"]
|
59 |
+
self.encoder = vqgan.ViTEncoder(
|
60 |
+
patch_size=self.model_spec["patch_size"],
|
61 |
+
transf_dim=self.model_spec["transf_dim"],
|
62 |
+
embedding_dim=self.model_spec["embedding_dim"],
|
63 |
+
image_size_x=self.image_size_x,
|
64 |
+
image_size_y=self.image_size_y,
|
65 |
+
num_layers=self.model_spec["num_layers"],
|
66 |
+
head_size=self.model_spec["head_size"],
|
67 |
+
)
|
68 |
+
self._bottleneck_size = self.encoder.bottleneck
|
69 |
+
|
70 |
+
self.vq_vae = vqgan.ViTVectorQuantizer(
|
71 |
+
self.model_spec["vocab_size"],
|
72 |
+
self.model_spec["embedding_dim"],
|
73 |
+
self.model_spec["commitment_cost"],
|
74 |
+
)
|
75 |
+
|
76 |
+
self.decoder = vqgan.ViTDecoder(
|
77 |
+
patch_size=self.model_spec["patch_size"],
|
78 |
+
transf_dim=self.model_spec["transf_dim"],
|
79 |
+
embedding_dim=self.model_spec["embedding_dim"],
|
80 |
+
image_size_x=self.image_size_x,
|
81 |
+
image_size_y=self.image_size_y,
|
82 |
+
num_layers=self.model_spec["num_layers"],
|
83 |
+
head_size=self.model_spec["head_size"],
|
84 |
+
expected_bottleneck=self._bottleneck_size,
|
85 |
+
)
|
86 |
+
|
87 |
+
self.is_perceptual = self.model_spec["is_perceptual"]
|
88 |
+
assert self.is_perceptual # This should be on
|
89 |
+
|
90 |
+
# Keep track of the usage of the codebook indices
|
91 |
+
self.codebook_index_usage = np.zeros(self.model_spec["vocab_size"], dtype=np.int64)
|
92 |
+
|
93 |
+
self.gan = self.model_spec.get("use_gan", False)
|
94 |
+
if self.gan:
|
95 |
+
# Only make the patchgan if we are using it. This makes it easier to experiment with GAN settings after pretraining the VQ-VAE for instance
|
96 |
+
self.patch_gan = vqgan.PatchGan(channel_start=self.model_spec["gan_channel_start"])
|
97 |
+
# Make a copy of the patchgan since we are only using a single optimizer
|
98 |
+
self.target_patchgan = vqgan.PatchGan(channel_start=self.model_spec["gan_channel_start"])
|
99 |
+
self.target_patchgan.requires_grad_(False)
|
100 |
+
self.target_patchgan.load_state_dict(self.patch_gan.state_dict())
|
101 |
+
self.target_update = TARGET_GAN_UPDATE
|
102 |
+
|
103 |
+
# At which iteration to start using the GAN loss
|
104 |
+
self.gan_start = self.model_spec["gan_start"]
|
105 |
+
# How much weight to give to the GAN loss gradients compared to the vq autoencoder loss
|
106 |
+
self.gan_weight = self.model_spec["gan_weight"]
|
107 |
+
# How many steps to train the discriminator before applying the gan loss.
|
108 |
+
self.gan_discrim_pretrain = self.model_spec["gan_discrim_pretrain"]
|
109 |
+
# How many steps to warmup the gan loss
|
110 |
+
self.gan_discrim_warmup = self.model_spec["gan_discrim_warmup"]
|
111 |
+
# Keeping track of the number of updates
|
112 |
+
self.updates = 0
|
113 |
+
print(f"Using GAN with weight {self.gan_weight} and target update {self.target_update} and gan start {self.gan_start} over {self.gan_discrim_warmup} steps")
|
114 |
+
|
115 |
+
self.lpips_model = None
|
116 |
+
# We don't need this for using the encoder/decoder
|
117 |
+
# self.lpips_model = lpips.LPIPS(net=self.model_spec["lpips_model"]).eval()
|
118 |
+
# for param in self.lpips_model.parameters():
|
119 |
+
# param.requires_grad = False
|
120 |
+
|
121 |
+
if ckpt_path is not None and ckpt_path != "None":
|
122 |
+
print(f"Initing VQGAN model from {ckpt_path}")
|
123 |
+
loaded_ckpt = torch.load(ckpt_path, map_location="cpu")
|
124 |
+
# Can ignore stuff here
|
125 |
+
self.load_state_dict(loaded_ckpt["state_dict"], strict=False)
|
126 |
+
|
127 |
+
self.world_space = TensorSpace((3, self.image_size_y, self.image_size_x), dtype=torch.uint8, low=0, high=255)
|
128 |
+
self.encoder_space = TensorSpace((self._bottleneck_size,), dtype=torch.long, low=0, high=self.vocab_size - 1)
|
129 |
+
|
130 |
+
@property
|
131 |
+
def vocab_size(self):
|
132 |
+
"""Return the number of entries in the codebook."""
|
133 |
+
return self.vq_vae._vocab_size
|
134 |
+
|
135 |
+
@property
|
136 |
+
def encoded_bottleneck_dim(self):
|
137 |
+
"""Return the dimensionality of the latent vector encoded into codebook indices."""
|
138 |
+
return self._bottleneck_size
|
139 |
+
|
140 |
+
@property
|
141 |
+
def embedding_dim(self):
|
142 |
+
"""The dimensionality of quantized vectors (the dimension of codebook vectors)."""
|
143 |
+
return self.vq_vae._embedding_dim
|
144 |
+
|
145 |
+
def _get_last_layer(self):
|
146 |
+
"""
|
147 |
+
The last layer used for generating the image.
|
148 |
+
Used for balancing the gradients of the reconstruction and the GAN loss.
|
149 |
+
"""
|
150 |
+
return self.decoder.get_last_layer()
|
151 |
+
|
152 |
+
def _preprocess_images(self, images):
|
153 |
+
"""Preprocess images (B, C, H, W)"""
|
154 |
+
x_batch = images.float() / 255
|
155 |
+
x_batch = normalise_rgb(x_batch)
|
156 |
+
return x_batch
|
157 |
+
|
158 |
+
def _revert_image_preprocess(self, x_batch):
|
159 |
+
"""Revert the preprocessing done in _preprocess_images"""
|
160 |
+
normalized_imgs = rev_normalise_rgb(x_batch.clone())
|
161 |
+
x_batch = torch.clip(normalized_imgs, 0, 1)
|
162 |
+
images = (x_batch * 255).byte()
|
163 |
+
return images
|
164 |
+
|
165 |
+
def _get_latent_continuous(self, batch):
|
166 |
+
z = self.encoder(batch)
|
167 |
+
return z
|
168 |
+
|
169 |
+
def _get_latent_discretized(self, z):
|
170 |
+
z_quantized, vq_loss, perplexity, indices = self.vq_vae(z)
|
171 |
+
return z_quantized, vq_loss, perplexity, indices
|
172 |
+
|
173 |
+
def _encode_decode(self, x_batch):
|
174 |
+
z = self._get_latent_continuous(x_batch)
|
175 |
+
z_quantized, vq_loss, perplexity, indices = self._get_latent_discretized(z)
|
176 |
+
data_recon = self.decoder(z_quantized)
|
177 |
+
return vq_loss, perplexity, data_recon, indices
|
178 |
+
|
179 |
+
def _log_vars(self, log_vars):
|
180 |
+
prefix = "train" if self.training else "val"
|
181 |
+
for key, val in log_vars.items():
|
182 |
+
self.log(f"{prefix}/{key}", val, on_step=True, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
|
183 |
+
|
184 |
+
def decode_from_encoding_indices(self, encoding_indices):
|
185 |
+
"""Return decoded images (B, C, H, W) for a batch of encoding indices (B, self.encoded_bottleneck_dim)"""
|
186 |
+
z = self.vq_vae.convert_encoding_indices_to_quantized_embeddings(encoding_indices)
|
187 |
+
data_recon = self.decoder(z)
|
188 |
+
# Denormalize and cast to uint8
|
189 |
+
data_recon = self._revert_image_preprocess(data_recon)
|
190 |
+
return data_recon
|
191 |
+
|
192 |
+
def get_encoding_indices_for_images(self, images):
|
193 |
+
"""
|
194 |
+
Return encoding indices (B, self.encoded_bottleneck_dim) for a batch of images (B, C, H, W).
|
195 |
+
Useful auxiliary method for testing.
|
196 |
+
"""
|
197 |
+
x_batch = self._preprocess_images(images)
|
198 |
+
z = self._get_latent_continuous(x_batch)
|
199 |
+
encoding_indices = self.vq_vae(z, only_return_encoding_indices=True)
|
200 |
+
return encoding_indices
|
201 |
+
|
202 |
+
def forward_returning_action_and_embedding(self, states, actions_input, timesteps, attention_mask, images):
|
203 |
+
raise NotImplementedError
|
204 |
+
|
205 |
+
def get_encoding_output(self, images):
|
206 |
+
"""
|
207 |
+
Return outputs from the encoder for a batch of images (B, C, H, W).
|
208 |
+
Returns:
|
209 |
+
quantized_z: (B, self.encoded_bottleneck_dim, self.embedding_dim), quantized latent vectors with straight-through gradient estimator
|
210 |
+
vq_loss: (B, ), VQ loss for each image
|
211 |
+
perplexity: (B, ), perplexity for each image
|
212 |
+
encoding_indices: (B, self.encoded_bottleneck_dim), encoding indices for each image
|
213 |
+
"""
|
214 |
+
x_batch = self._preprocess_images(images)
|
215 |
+
z = self._get_latent_continuous(x_batch)
|
216 |
+
quantized_z, vq_loss, perplexity, encoding_indices = self.vq_vae(z)
|
217 |
+
quantized_z = quantized_z.view(quantized_z.shape[0], self.encoded_bottleneck_dim, self.embedding_dim)
|
218 |
+
return quantized_z, vq_loss, perplexity, encoding_indices
|
219 |
+
|
220 |
+
def _encode(self, world_space_tensor: torch.tensor) -> torch.tensor:
|
221 |
+
# Flatten time and batch dim into one
|
222 |
+
batch, time = world_space_tensor.shape[:2]
|
223 |
+
world_space_tensor = world_space_tensor.view(batch * time, *world_space_tensor.shape[2:])
|
224 |
+
encodings = self.get_encoding_indices_for_images(world_space_tensor)
|
225 |
+
# Reshape back to (batch, time, ...)
|
226 |
+
encodings = encodings.view(batch, time, -1)
|
227 |
+
return encodings
|
228 |
+
|
229 |
+
def _decode(self, encoder_space_tensor: torch.tensor) -> torch.tensor:
|
230 |
+
# Flatten time and batch dim into one
|
231 |
+
batch, time = encoder_space_tensor.shape[:2]
|
232 |
+
encoder_space_tensor = encoder_space_tensor.view(batch * time, *encoder_space_tensor.shape[2:])
|
233 |
+
decoded = self.decode_from_encoding_indices(encoder_space_tensor)
|
234 |
+
# Reshape back to (batch, time, ...)
|
235 |
+
decoded = decoded.view(batch, time, *decoded.shape[1:])
|
236 |
+
return decoded
|
wham/models/vqgan/vqgan_models.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
# Copyright (c) 2018 Zalando Research
|
3 |
+
# 2023 Microsoft Research
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
16 |
+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
17 |
+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
18 |
+
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
19 |
+
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
20 |
+
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
21 |
+
# OR OTHER DEALINGS IN THE SOFTWARE.
|
22 |
+
|
23 |
+
from math import sqrt
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.nn as nn
|
27 |
+
import torch.nn.functional as F
|
28 |
+
|
29 |
+
from wham.models.nn.nanoGPT import GPTConfig, SelfAttentionBlock
|
30 |
+
from wham.models.nn.model_blocks import ConvNextBlock, ConvNextDownsample, ConvNextDownsampleBig
|
31 |
+
|
32 |
+
# Mainly following https://github.com/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb
|
33 |
+
"""
|
34 |
+
ViT-VQGAN is based on:
|
35 |
+
Yu, Jiahui, et al. "Vector-quantized image modeling with improved vqgan."
|
36 |
+
ICLR 2022
|
37 |
+
"""
|
38 |
+
|
39 |
+
|
40 |
+
def _convert_encoding_indices_to_quantized_embeddings(encoding_indices, embedding_layer, vocab_size, embedding_dim):
|
41 |
+
"""
|
42 |
+
Args:
|
43 |
+
encoding_indices: tensor of integers (batch_size, bottleneck_size)
|
44 |
+
Each batch item represents a single image as a sequence of integers (indeces of codebook vectors)
|
45 |
+
Output:
|
46 |
+
quantized: tensor of floats (batch_size, bottleneck_size, embedding_dim)
|
47 |
+
"""
|
48 |
+
batch_dim, bottleneck_size = encoding_indices.shape[:2]
|
49 |
+
|
50 |
+
encoding_indices = encoding_indices.view(-1).unsqueeze(1)
|
51 |
+
one_hot_encoding_indices = torch.zeros(encoding_indices.shape[0], vocab_size, device=encoding_indices.device)
|
52 |
+
one_hot_encoding_indices.scatter_(1, encoding_indices, 1)
|
53 |
+
|
54 |
+
quantized = torch.matmul(one_hot_encoding_indices, embedding_layer)
|
55 |
+
quantized = quantized.view(batch_dim, bottleneck_size, embedding_dim).contiguous()
|
56 |
+
return quantized
|
57 |
+
|
58 |
+
|
59 |
+
class ViTVectorQuantizer(nn.Module):
|
60 |
+
"""
|
61 |
+
Vector Quantizer for a Vision Transformer based VQ model using normalised codebook embeddings as in https://arxiv.org/abs/2110.04627.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(self, vocab_size, embedding_dim, commitment_cost, epsilon=1e-5):
|
65 |
+
super().__init__()
|
66 |
+
|
67 |
+
self._embedding_dim = embedding_dim
|
68 |
+
self._vocab_size = vocab_size
|
69 |
+
self._epsilon = epsilon
|
70 |
+
|
71 |
+
self._embedding = nn.Embedding(self._vocab_size, self._embedding_dim)
|
72 |
+
self._embedding.weight.data.uniform_(-1 / self._vocab_size, 1 / self._vocab_size)
|
73 |
+
self._commitment_cost = commitment_cost
|
74 |
+
|
75 |
+
@property
|
76 |
+
def vocab_size(self):
|
77 |
+
"""Return the number of entries in the codebook."""
|
78 |
+
return self._vocab_size
|
79 |
+
|
80 |
+
def convert_encoding_indices_to_quantized_embeddings(self, encoding_indices):
|
81 |
+
"""
|
82 |
+
Args:
|
83 |
+
encoding_indices: tensor of integers (batch_size, bottleneck_size)
|
84 |
+
Each batch item represents a single image as a sequence of integers (indeces of codebook vectors)
|
85 |
+
Output:
|
86 |
+
quantized: tensor of floats (batch_size, self._embedding_dim, bottleneck_size)
|
87 |
+
"""
|
88 |
+
return _convert_encoding_indices_to_quantized_embeddings(encoding_indices, F.normalize(self._embedding.weight), self._vocab_size, self._embedding_dim)
|
89 |
+
|
90 |
+
def forward(self, inputs, only_return_encoding_indices=False):
|
91 |
+
"""
|
92 |
+
If only_return_encoding_indices is True, then only return the indices of codebook vectors
|
93 |
+
"""
|
94 |
+
input_shape = inputs.shape
|
95 |
+
|
96 |
+
# Flatten input from Batch Tokens Embedding to B*T E
|
97 |
+
flat_input = inputs.view(-1, self._embedding_dim)
|
98 |
+
# Normalize inputs
|
99 |
+
flat_input = F.normalize(flat_input)
|
100 |
+
|
101 |
+
# Embeddings are always normalized
|
102 |
+
embeddings_to_use = F.normalize(self._embedding.weight)
|
103 |
+
|
104 |
+
# Calculate distances
|
105 |
+
distances = torch.sum(flat_input**2, dim=1, keepdim=True) + torch.sum(embeddings_to_use**2, dim=1) - 2 * torch.matmul(flat_input, embeddings_to_use.t())
|
106 |
+
|
107 |
+
# Encoding
|
108 |
+
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
|
109 |
+
if only_return_encoding_indices:
|
110 |
+
# Add back batch dimension
|
111 |
+
return encoding_indices.view(input_shape[0], -1)
|
112 |
+
one_hot_encoding_indices = torch.zeros(encoding_indices.shape[0], self._vocab_size, device=inputs.device)
|
113 |
+
one_hot_encoding_indices.scatter_(1, encoding_indices, 1)
|
114 |
+
|
115 |
+
# Quantize and unflatten
|
116 |
+
quantized = torch.matmul(one_hot_encoding_indices, embeddings_to_use).view(input_shape)
|
117 |
+
|
118 |
+
# Loss
|
119 |
+
e_latent_loss = F.mse_loss(quantized.detach(), inputs)
|
120 |
+
q_latent_loss = F.mse_loss(quantized, inputs.detach())
|
121 |
+
loss = q_latent_loss + self._commitment_cost * e_latent_loss
|
122 |
+
|
123 |
+
quantized = inputs + (quantized - inputs).detach()
|
124 |
+
avg_probs = torch.mean(one_hot_encoding_indices, dim=0)
|
125 |
+
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + self._epsilon)))
|
126 |
+
|
127 |
+
return quantized, loss, perplexity, encoding_indices.view(input_shape[0], -1)
|
128 |
+
|
129 |
+
|
130 |
+
class ViTEncoder(nn.Module):
|
131 |
+
def __init__(self, patch_size, transf_dim, embedding_dim, image_size_x, image_size_y, num_layers, head_size):
|
132 |
+
super().__init__()
|
133 |
+
|
134 |
+
self.image_size_x = image_size_x
|
135 |
+
self.image_size_y = image_size_y
|
136 |
+
# We will pad the image to make it divisible by patch_size
|
137 |
+
self.x_pad = (patch_size - (self.image_size_x % patch_size)) % patch_size
|
138 |
+
self.y_pad = (patch_size - (self.image_size_y % patch_size)) % patch_size
|
139 |
+
assert (self.image_size_x + self.x_pad) % patch_size == 0 and (
|
140 |
+
self.image_size_y + self.y_pad
|
141 |
+
) % patch_size == 0, "image_size_x and image_size_y must be divisible by patch_size"
|
142 |
+
|
143 |
+
self.vit_tokens = ((image_size_x + self.x_pad) // patch_size) * ((image_size_y + self.y_pad) // patch_size)
|
144 |
+
self._bottleneck = self.vit_tokens
|
145 |
+
print(f"Bottleneck is {self.bottleneck} for image size {image_size_x}x{image_size_y} with ViT Encoder and patch size {patch_size}")
|
146 |
+
|
147 |
+
self.patch_size = patch_size
|
148 |
+
self.transf_dim = transf_dim
|
149 |
+
self.embedding_dim = embedding_dim
|
150 |
+
|
151 |
+
self.proj1 = nn.Linear(3 * patch_size * patch_size, transf_dim)
|
152 |
+
self.pos_embeds = nn.Embedding(self.vit_tokens, transf_dim)
|
153 |
+
|
154 |
+
assert self.transf_dim % head_size == 0, "transf_dim must be divisible by head_size"
|
155 |
+
n_heads = self.transf_dim // head_size
|
156 |
+
transformer_config = GPTConfig(block_size=self.vit_tokens, n_layer=num_layers, n_head=n_heads, n_embd=transf_dim, bias=False, dropout=0)
|
157 |
+
self.vit = nn.Sequential(*[SelfAttentionBlock(transformer_config) for _ in range(num_layers)])
|
158 |
+
|
159 |
+
self.output_ln = nn.LayerNorm(transf_dim)
|
160 |
+
self.output_proj = nn.Linear(transf_dim, embedding_dim)
|
161 |
+
|
162 |
+
# init all weights
|
163 |
+
self.apply(self._init_weights)
|
164 |
+
# apply special scaled init to the residual projections, per GPT-2 paper
|
165 |
+
for pn, p in self.named_parameters():
|
166 |
+
if pn.endswith("c_proj.weight"):
|
167 |
+
torch.nn.init.normal_(p, mean=0.0, std=0.02 / sqrt(2 * transformer_config.n_layer))
|
168 |
+
|
169 |
+
@property
|
170 |
+
def bottleneck(self):
|
171 |
+
return self._bottleneck
|
172 |
+
|
173 |
+
def _init_weights(self, module):
|
174 |
+
if isinstance(module, nn.Linear):
|
175 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
176 |
+
if module.bias is not None:
|
177 |
+
torch.nn.init.zeros_(module.bias)
|
178 |
+
elif isinstance(module, nn.Embedding):
|
179 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
180 |
+
|
181 |
+
def forward(self, inputs):
|
182 |
+
# inputs: (batch_size, 3, image_size_x, image_size_y)
|
183 |
+
|
184 |
+
# Patch input images
|
185 |
+
batch_size = inputs.shape[0]
|
186 |
+
padded_inputs = F.pad(inputs, (0, self.x_pad, 0, self.y_pad), mode="constant", value=0)
|
187 |
+
x = padded_inputs.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
|
188 |
+
num_x_patches = (self.image_size_x + self.x_pad) // self.patch_size
|
189 |
+
num_y_patches = (self.image_size_y + self.y_pad) // self.patch_size
|
190 |
+
|
191 |
+
# inputs is of shape (batch_size, 3, num_x_patches, num_y_patches, patch_size, patch_size)
|
192 |
+
# Turn it into (batch_size, patches, input_dim)
|
193 |
+
patches = x.permute(0, 2, 3, 1, 4, 5).contiguous().view(batch_size, num_x_patches * num_y_patches, 3 * self.patch_size * self.patch_size)
|
194 |
+
|
195 |
+
proj_patches = self.proj1(patches)
|
196 |
+
|
197 |
+
pos_embeds = self.pos_embeds.weight.unsqueeze(0).repeat(batch_size, 1, 1)
|
198 |
+
vit_input = proj_patches + pos_embeds
|
199 |
+
vit_output = self.vit(vit_input)
|
200 |
+
|
201 |
+
vit_output = self.output_ln(vit_output)
|
202 |
+
embeddings = self.output_proj(vit_output)
|
203 |
+
normalised_embeddings = F.normalize(embeddings, dim=-1)
|
204 |
+
|
205 |
+
return normalised_embeddings
|
206 |
+
|
207 |
+
|
208 |
+
class ViTDecoder(nn.Module):
|
209 |
+
def __init__(self, patch_size, transf_dim, embedding_dim, image_size_x, image_size_y, num_layers, head_size, expected_bottleneck=None):
|
210 |
+
super().__init__()
|
211 |
+
|
212 |
+
self.image_size_x = image_size_x
|
213 |
+
self.image_size_y = image_size_y
|
214 |
+
self.x_pad = (patch_size - (self.image_size_x % patch_size)) % patch_size
|
215 |
+
self.y_pad = (patch_size - (self.image_size_y % patch_size)) % patch_size
|
216 |
+
|
217 |
+
assert (self.image_size_x + self.x_pad) % patch_size == 0 and (
|
218 |
+
self.image_size_y + self.y_pad
|
219 |
+
) % patch_size == 0, "image_size_x and image_size_y must be divisible by patch_size"
|
220 |
+
|
221 |
+
self.vit_tokens = ((image_size_x + self.x_pad) // patch_size) * ((image_size_y + self.y_pad) // patch_size)
|
222 |
+
if expected_bottleneck is not None:
|
223 |
+
assert (
|
224 |
+
self.vit_tokens == expected_bottleneck
|
225 |
+
), f"Expected bottleneck of {expected_bottleneck} but got {self.vit_tokens} for image size {image_size_x}x{image_size_y} with ViT Decoder and patch size {patch_size}"
|
226 |
+
|
227 |
+
self.patch_size = patch_size
|
228 |
+
self.transf_dim = transf_dim
|
229 |
+
self.embedding_dim = embedding_dim
|
230 |
+
|
231 |
+
self.proj1 = nn.Linear(embedding_dim, transf_dim)
|
232 |
+
self.pos_embeds = nn.Embedding(self.vit_tokens, transf_dim)
|
233 |
+
|
234 |
+
assert self.transf_dim % head_size == 0, "transf_dim must be divisible by head_size"
|
235 |
+
n_heads = self.transf_dim // head_size
|
236 |
+
transformer_config = GPTConfig(block_size=self.vit_tokens, n_layer=num_layers, n_head=n_heads, n_embd=transf_dim, bias=False, dropout=0)
|
237 |
+
self.vit = nn.Sequential(*[SelfAttentionBlock(transformer_config) for _ in range(num_layers)])
|
238 |
+
|
239 |
+
self.output_ln = nn.LayerNorm(transf_dim)
|
240 |
+
self.output_proj = nn.Linear(transf_dim, 3 * patch_size * patch_size)
|
241 |
+
|
242 |
+
# Couldn't resist the name
|
243 |
+
self.folder = nn.Fold(
|
244 |
+
output_size=(self.image_size_y + self.y_pad, self.image_size_x + self.x_pad),
|
245 |
+
kernel_size=(self.patch_size, self.patch_size),
|
246 |
+
stride=(self.patch_size, self.patch_size),
|
247 |
+
)
|
248 |
+
|
249 |
+
# init all weights
|
250 |
+
self.apply(self._init_weights)
|
251 |
+
# apply special scaled init to the residual projections, per GPT-2 paper
|
252 |
+
for pn, p in self.named_parameters():
|
253 |
+
if pn.endswith("c_proj.weight"):
|
254 |
+
torch.nn.init.normal_(p, mean=0.0, std=0.02 / sqrt(2 * transformer_config.n_layer))
|
255 |
+
|
256 |
+
def _init_weights(self, module):
|
257 |
+
if isinstance(module, nn.Linear):
|
258 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
259 |
+
if module.bias is not None:
|
260 |
+
torch.nn.init.zeros_(module.bias)
|
261 |
+
elif isinstance(module, nn.Embedding):
|
262 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
263 |
+
|
264 |
+
def forward(self, inputs):
|
265 |
+
# Patch input images
|
266 |
+
batch_size = inputs.shape[0]
|
267 |
+
|
268 |
+
# Unproject the embeddings from the VQ embedding space to the transformer space
|
269 |
+
proj_patches = self.proj1(inputs).reshape(batch_size, self.vit_tokens, self.transf_dim)
|
270 |
+
|
271 |
+
pos_embeds = self.pos_embeds.weight.unsqueeze(0).repeat(batch_size, 1, 1)
|
272 |
+
vit_input = proj_patches + pos_embeds
|
273 |
+
vit_output = self.vit(vit_input)
|
274 |
+
|
275 |
+
vit_output = self.output_ln(vit_output)
|
276 |
+
|
277 |
+
predictions = self.output_proj(vit_output) # (batch, patches, 3 * patch_size * patch_size)
|
278 |
+
|
279 |
+
# Reassemble the image into (batch, 3, image_size_x, image_size_y)
|
280 |
+
fold_inputs = predictions.permute(0, 2, 1).contiguous()
|
281 |
+
image_pred = self.folder(fold_inputs)
|
282 |
+
|
283 |
+
unpadded_image_pred = image_pred[:, :, : self.image_size_y, : self.image_size_x] # Remove padding in the same way it was applied in the encoder
|
284 |
+
|
285 |
+
# Anything on the output?
|
286 |
+
return unpadded_image_pred
|
287 |
+
|
288 |
+
def get_last_layer(self):
|
289 |
+
"""
|
290 |
+
Return the last layer weights of the model, to use for loss balancing.
|
291 |
+
"""
|
292 |
+
return self.output_proj.weight
|
293 |
+
|
294 |
+
|
295 |
+
class PatchGan(nn.Module):
|
296 |
+
def __init__(self, channel_start):
|
297 |
+
super().__init__()
|
298 |
+
x = channel_start
|
299 |
+
self.downsample1 = ConvNextDownsampleBig(3, x)
|
300 |
+
self.block1 = ConvNextBlock(x)
|
301 |
+
self.downsample2 = ConvNextDownsampleBig(x, x)
|
302 |
+
self.block2 = ConvNextBlock(x)
|
303 |
+
self.last = nn.Conv2d(x, 1, kernel_size=1, stride=1, padding=0)
|
304 |
+
|
305 |
+
def forward(self, x):
|
306 |
+
batch_size = x.shape[0]
|
307 |
+
y = torch.nn.functional.gelu(self.downsample1(x))
|
308 |
+
y = self.block1(y)
|
309 |
+
z = torch.nn.functional.gelu(self.downsample2(y))
|
310 |
+
z = self.block2(z)
|
311 |
+
return self.last(z).reshape(batch_size, -1)
|
wham/models/vqvae/vqvae_utils.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def normalise_rgb(X, channels_first=True):
|
8 |
+
"""
|
9 |
+
Take in an image tensor of shape [ ... , 3], which is assumed to have already been divided by
|
10 |
+
255 so X \in [0,1]. These functions do additional normalisation, roughly ending up with mean
|
11 |
+
of zero and unit variance. The constants appeared in most vision repos, and are supposedly the
|
12 |
+
'right' constants to use based on imagenet statistics.
|
13 |
+
assert X.shape[-1] == 3
|
14 |
+
"""
|
15 |
+
channel_dim = 1 if channels_first else -1
|
16 |
+
assert X.shape[channel_dim] == 3
|
17 |
+
if channels_first:
|
18 |
+
X[:, 0, ...] -= 0.485
|
19 |
+
X[:, 0, ...] /= 0.229
|
20 |
+
X[:, 1, ...] -= 0.456
|
21 |
+
X[:, 1, ...] /= 0.224
|
22 |
+
X[:, 2, ...] -= 0.406
|
23 |
+
X[:, 2, ...] /= 0.225
|
24 |
+
else:
|
25 |
+
X[..., 0] -= 0.485
|
26 |
+
X[..., 0] /= 0.229
|
27 |
+
X[..., 1] -= 0.456
|
28 |
+
X[..., 1] /= 0.224
|
29 |
+
X[..., 2] -= 0.406
|
30 |
+
X[..., 2] /= 0.225
|
31 |
+
return X
|
32 |
+
|
33 |
+
|
34 |
+
def rev_normalise_rgb(X, channels_first=True):
|
35 |
+
"""
|
36 |
+
Reverse `normalise_rgb`, so the output lives in [0,1]. This function is needed for
|
37 |
+
reconstruction visualisation, etc.
|
38 |
+
"""
|
39 |
+
channel_dim = 1 if channels_first else -1
|
40 |
+
assert X.shape[channel_dim] == 3
|
41 |
+
if channels_first:
|
42 |
+
X[:, 0, ...] *= 0.229
|
43 |
+
X[:, 0, ...] += 0.485
|
44 |
+
X[:, 1, ...] *= 0.224
|
45 |
+
X[:, 1, ...] += 0.456
|
46 |
+
X[:, 2, ...] *= 0.225
|
47 |
+
X[:, 2, ...] += 0.406
|
48 |
+
else:
|
49 |
+
X[..., 0] *= 0.229
|
50 |
+
X[..., 0] += 0.485
|
51 |
+
X[..., 1] *= 0.224
|
52 |
+
X[..., 1] += 0.456
|
53 |
+
X[..., 2] *= 0.225
|
54 |
+
X[..., 2] += 0.406
|
55 |
+
return X
|
56 |
+
|
57 |
+
|
58 |
+
@torch.no_grad()
|
59 |
+
def make_grid(
|
60 |
+
tensor: Union[torch.Tensor, List[torch.Tensor]],
|
61 |
+
nrow: int = 8,
|
62 |
+
padding: int = 2,
|
63 |
+
normalize: bool = False,
|
64 |
+
value_range: Optional[Tuple[int, int]] = None,
|
65 |
+
scale_each: bool = False,
|
66 |
+
pad_value: float = 0.0,
|
67 |
+
**kwargs,
|
68 |
+
) -> torch.Tensor:
|
69 |
+
"""
|
70 |
+
Make a grid of images.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
|
74 |
+
or a list of images all of the same size.
|
75 |
+
nrow (int, optional): Number of images displayed in each row of the grid.
|
76 |
+
The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
|
77 |
+
padding (int, optional): amount of padding. Default: ``2``.
|
78 |
+
normalize (bool, optional): If True, shift the image to the range (0, 1),
|
79 |
+
by the min and max values specified by ``value_range``. Default: ``False``.
|
80 |
+
value_range (tuple, optional): tuple (min, max) where min and max are numbers,
|
81 |
+
then these numbers are used to normalize the image. By default, min and max
|
82 |
+
are computed from the tensor.
|
83 |
+
scale_each (bool, optional): If ``True``, scale each image in the batch of
|
84 |
+
images separately rather than the (min, max) over all images. Default: ``False``.
|
85 |
+
pad_value (float, optional): Value for the padded pixels. Default: ``0``.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
grid (Tensor): the tensor containing grid of images.
|
89 |
+
"""
|
90 |
+
if not torch.is_tensor(tensor):
|
91 |
+
if isinstance(tensor, list):
|
92 |
+
for t in tensor:
|
93 |
+
if not torch.is_tensor(t):
|
94 |
+
raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}")
|
95 |
+
else:
|
96 |
+
raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
|
97 |
+
|
98 |
+
# if list of tensors, convert to a 4D mini-batch Tensor
|
99 |
+
if isinstance(tensor, list):
|
100 |
+
tensor = torch.stack(tensor, dim=0)
|
101 |
+
|
102 |
+
if tensor.dim() == 2: # single image H x W
|
103 |
+
tensor = tensor.unsqueeze(0)
|
104 |
+
if tensor.dim() == 3: # single image
|
105 |
+
if tensor.size(0) == 1: # if single-channel, convert to 3-channel
|
106 |
+
tensor = torch.cat((tensor, tensor, tensor), 0)
|
107 |
+
tensor = tensor.unsqueeze(0)
|
108 |
+
|
109 |
+
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
|
110 |
+
tensor = torch.cat((tensor, tensor, tensor), 1)
|
111 |
+
|
112 |
+
if normalize is True:
|
113 |
+
tensor = tensor.clone() # avoid modifying tensor in-place
|
114 |
+
if value_range is not None and not isinstance(value_range, tuple):
|
115 |
+
raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")
|
116 |
+
|
117 |
+
def norm_ip(img, low, high):
|
118 |
+
img.clamp_(min=low, max=high)
|
119 |
+
img.sub_(low).div_(max(high - low, 1e-5))
|
120 |
+
|
121 |
+
def norm_range(t, value_range):
|
122 |
+
if value_range is not None:
|
123 |
+
norm_ip(t, value_range[0], value_range[1])
|
124 |
+
else:
|
125 |
+
norm_ip(t, float(t.min()), float(t.max()))
|
126 |
+
|
127 |
+
if scale_each is True:
|
128 |
+
for t in tensor: # loop over mini-batch dimension
|
129 |
+
norm_range(t, value_range)
|
130 |
+
else:
|
131 |
+
norm_range(tensor, value_range)
|
132 |
+
|
133 |
+
if not isinstance(tensor, torch.Tensor):
|
134 |
+
raise TypeError("tensor should be of type torch.Tensor")
|
135 |
+
if tensor.size(0) == 1:
|
136 |
+
return tensor.squeeze(0)
|
137 |
+
|
138 |
+
# make the mini-batch of images into a grid
|
139 |
+
nmaps = tensor.size(0)
|
140 |
+
xmaps = min(nrow, nmaps)
|
141 |
+
ymaps = int(math.ceil(float(nmaps) / xmaps))
|
142 |
+
height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
|
143 |
+
num_channels = tensor.size(1)
|
144 |
+
grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
|
145 |
+
k = 0
|
146 |
+
for y in range(ymaps):
|
147 |
+
for x in range(xmaps):
|
148 |
+
if k >= nmaps:
|
149 |
+
break
|
150 |
+
# Tensor.copy_() is a valid method but seems to be missing from the stubs
|
151 |
+
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
|
152 |
+
grid.narrow(1, y * height + padding, height - padding).narrow(2, x * width + padding, width - padding).copy_(tensor[k]) # type: ignore[attr-defined]
|
153 |
+
k = k + 1
|
154 |
+
return grid
|
wham/models/wham_base/__init__.py
ADDED
File without changes
|
wham/models/wham_base/encode_predict_decode_base.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Union, Type, Callable, Tuple, Mapping, Optional
|
2 |
+
|
3 |
+
import torch as th
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
from tensordict import TensorDict # type: ignore # requires installing stubs for tensordict
|
6 |
+
|
7 |
+
from .tensor_spaces import TensorDictSpace
|
8 |
+
from .encoder_decoder import EncoderDecoderBase
|
9 |
+
from .pl_creation_args import LightningModuleCreationArgs
|
10 |
+
|
11 |
+
|
12 |
+
def create_encoder_args_from_config_dict(
|
13 |
+
config_dict: dict[str, Union[dict[str, Any], tuple]], class_name_to_model: Callable[[str], Type[pl.LightningModule]]
|
14 |
+
) -> Mapping[str, Union[LightningModuleCreationArgs, Tuple[LightningModuleCreationArgs, LightningModuleCreationArgs]]]:
|
15 |
+
"""
|
16 |
+
Given a dictionary mapping modality names to their encoder-decoder arguments, create the corresponding
|
17 |
+
creation args (LightningModuleCreationArgs) for each modality.
|
18 |
+
|
19 |
+
See LightningModuleCreationArgs.from_dict for more details.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
config_dict: A dictionary mapping modality names to their encoder-decoder arguments.
|
23 |
+
Root level of this dictionary should be modality names we expect.
|
24 |
+
class_name_to_model: A function mapping class names to their corresponding model classes.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
A dictionary mapping modality names to their encoder-decoder creation args.
|
28 |
+
Each value may be a LightningModuleCreationArgs, or a tuple of two LightningModuleCreationArgs.
|
29 |
+
If value is a LightningModuleCreationArgs, then same model is used for encoding and decoding.
|
30 |
+
If value is a tuple of two LightningModuleCreationArgs, then first is used for encoding and second for decoding.
|
31 |
+
"""
|
32 |
+
# Giving explicit type hint here to make mypy happy
|
33 |
+
modalities: dict[str, Any] = {}
|
34 |
+
for modality_name, modality_config in config_dict.items():
|
35 |
+
if isinstance(modality_config, (list, tuple)):
|
36 |
+
assert len(modality_config) == 2, f"Expected two entries for modality {modality_name}, got {len(modality_config)}"
|
37 |
+
modalities[modality_name] = (
|
38 |
+
LightningModuleCreationArgs.from_dict(modality_config[0], class_name_to_model),
|
39 |
+
LightningModuleCreationArgs.from_dict(modality_config[1], class_name_to_model),
|
40 |
+
)
|
41 |
+
else:
|
42 |
+
modalities[modality_name] = LightningModuleCreationArgs.from_dict(modality_config, class_name_to_model)
|
43 |
+
return modalities
|
44 |
+
|
45 |
+
|
46 |
+
def create_encoder_modules_from_args(
|
47 |
+
encoders: Mapping[str, Union[LightningModuleCreationArgs, Tuple[LightningModuleCreationArgs, LightningModuleCreationArgs]]], remove_checkpoint_path: bool = True
|
48 |
+
) -> th.nn.ModuleDict:
|
49 |
+
"""
|
50 |
+
Create the encoder modules from given creation args (LightningModuleCreationArgs).
|
51 |
+
|
52 |
+
Args:
|
53 |
+
encoders: A dictionary mapping modality names to their encoder-decoder creation args.
|
54 |
+
If value is a LightningModuleCreationArgs, then same model is used for encoding and decoding.
|
55 |
+
If value is a tuple of two LightningModuleCreationArgs, then first is used for encoding and second for decoding.
|
56 |
+
remove_checkpoint_path: If True, then remove the checkpoint_path from the creation args. This prepares the
|
57 |
+
created moduled to be properly saved and loaded as part of the bigger model
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
A dictionary mapping modality names to their encoder-decoder modules.
|
61 |
+
"""
|
62 |
+
modalities = {}
|
63 |
+
for modality_name, modality_args in encoders.items():
|
64 |
+
if isinstance(modality_args, (list, tuple)):
|
65 |
+
modalities[modality_name] = th.nn.ModuleList(
|
66 |
+
[
|
67 |
+
modality_args[0].create_module(remove_checkpoint_path=remove_checkpoint_path),
|
68 |
+
modality_args[1].create_module(remove_checkpoint_path=remove_checkpoint_path),
|
69 |
+
]
|
70 |
+
)
|
71 |
+
else:
|
72 |
+
modalities[modality_name] = modality_args.create_module(remove_checkpoint_path=remove_checkpoint_path)
|
73 |
+
return th.nn.ModuleDict(modalities)
|
74 |
+
|
75 |
+
|
76 |
+
class EncodePredictDecodeModule(pl.LightningModule):
|
77 |
+
"""
|
78 |
+
Base-class for models that encode, predict and decode.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
context_encoders: A dictionary mapping modality names to their encoder-decoders.
|
82 |
+
If value is a pl.LightningModule, then same model is used for encoding and decoding.
|
83 |
+
If value is a tuple of two pl.LightningModule, then first is used for encoding and second for decoding.
|
84 |
+
condition_encoders: Same as `context_encoders`, but for conditions.
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
predictor_args: LightningModuleCreationArgs,
|
90 |
+
context_encoders: th.nn.ModuleDict,
|
91 |
+
condition_encoders: Optional[th.nn.ModuleDict] = None,
|
92 |
+
):
|
93 |
+
if condition_encoders is None:
|
94 |
+
condition_encoders = th.nn.ModuleDict(dict())
|
95 |
+
self._assert_encoders(context_encoders)
|
96 |
+
self._assert_encoders(condition_encoders)
|
97 |
+
super().__init__()
|
98 |
+
|
99 |
+
self.context_encoders = context_encoders
|
100 |
+
self.condition_encoders = condition_encoders
|
101 |
+
|
102 |
+
self.context_world_space, self.context_encoder_space = self._get_spaces_from_encoders(context_encoders)
|
103 |
+
self.condition_world_space, self.condition_encoder_space = self._get_spaces_from_encoders(condition_encoders)
|
104 |
+
|
105 |
+
self.predictor = predictor_args.create_module(context_space=self.context_encoder_space, condition_space=self.condition_encoder_space)
|
106 |
+
|
107 |
+
def _assert_encoders(self, encoders: th.nn.ModuleDict) -> None:
|
108 |
+
"""Check that encoder dictionary is valid"""
|
109 |
+
assert isinstance(encoders, th.nn.ModuleDict), f"Invalid type for encoders: {type(encoders)}. Expected th.nn.ModuleDict"
|
110 |
+
for modality_name, encoder in encoders.items():
|
111 |
+
assert isinstance(encoder, EncoderDecoderBase) or isinstance(
|
112 |
+
encoder, th.nn.ModuleList
|
113 |
+
), f"Invalid type for modality {modality_name}: {type(encoder)}. Expected EncoderDecoderBase or Tuple[EncoderDecoderBase]"
|
114 |
+
if isinstance(encoder, th.nn.ModuleList):
|
115 |
+
assert len(encoder) == 2, f"Invalid number of arguments for modality {modality_name}: {len(encoder)}. Expected two (encoder, decoder)"
|
116 |
+
assert isinstance(
|
117 |
+
encoder[0], EncoderDecoderBase
|
118 |
+
), f"Invalid type for encoder of modality {modality_name}: {type(encoder[0])}. Expected EncoderDecoderBase"
|
119 |
+
assert isinstance(
|
120 |
+
encoder[1], EncoderDecoderBase
|
121 |
+
), f"Invalid type for decoder of modality {modality_name}: {type(encoder[1])}. Expected EncoderDecoderBase"
|
122 |
+
|
123 |
+
def _get_spaces_from_encoders(self, encoders: th.nn.ModuleDict) -> Tuple[TensorDictSpace, TensorDictSpace]:
|
124 |
+
"""
|
125 |
+
Given a modality dictionary mapping modality names to their encoders and decoders,
|
126 |
+
extract the world space and encoder space,
|
127 |
+
"""
|
128 |
+
world_spaces = {}
|
129 |
+
encoder_spaces = {}
|
130 |
+
for modality_name, modality in encoders.items():
|
131 |
+
if isinstance(modality, EncoderDecoderBase):
|
132 |
+
encoder_spaces[modality_name] = modality.encoder_space
|
133 |
+
world_spaces[modality_name] = modality.world_space
|
134 |
+
elif isinstance(modality, th.nn.ModuleList):
|
135 |
+
assert len(modality) == 2, f"Invalid number of modules for modality {modality_name}: {len(modality)}. Expected 2."
|
136 |
+
# Make sure that both encoder and decoder spaces match the expected space
|
137 |
+
encoder_encoder_space = modality[0].encoder_space
|
138 |
+
decoder_encoder_space = modality[1].encoder_space
|
139 |
+
assert (
|
140 |
+
encoder_encoder_space == decoder_encoder_space
|
141 |
+
), f"Encoder and decoder spaces for modality {modality_name} do not match: {encoder_encoder_space} != {decoder_encoder_space}"
|
142 |
+
encoder_world_space = modality[0].world_space
|
143 |
+
decoder_world_space = modality[1].world_space
|
144 |
+
assert (
|
145 |
+
encoder_world_space == decoder_world_space
|
146 |
+
), f"Encoder and decoder world spaces for modality {modality_name} do not match: {encoder_world_space} != {decoder_world_space}"
|
147 |
+
encoder_spaces[modality_name] = encoder_encoder_space
|
148 |
+
world_spaces[modality_name] = encoder_world_space
|
149 |
+
else:
|
150 |
+
raise TypeError(f"Invalid type for modality {modality_name}: {type(modality)}. Expected EncoderDecoderBase or th.nn.ModuleList")
|
151 |
+
return TensorDictSpace(world_spaces), TensorDictSpace(encoder_spaces)
|
152 |
+
|
153 |
+
def _encode(self, input_td: TensorDict, encoders: th.nn.ModuleDict, space: TensorDictSpace) -> TensorDict:
|
154 |
+
"""
|
155 |
+
Encode input_td into encoder space using the given encoders.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
input_td: A tensordict mapping modality names to their inputs.
|
159 |
+
encoders: A dictionary mapping modality names to their encoders.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
An encoded tensordict.
|
163 |
+
"""
|
164 |
+
encoded_context = {}
|
165 |
+
preceding_dims = space.get_preceding_dimensions(input_td, allow_key_subset=True)
|
166 |
+
for modality_name in input_td.keys():
|
167 |
+
encoder = encoders[modality_name]
|
168 |
+
if isinstance(encoder, EncoderDecoderBase):
|
169 |
+
encoded_context[modality_name] = encoder.encode(input_td[modality_name])
|
170 |
+
elif isinstance(encoder, th.nn.ModuleList):
|
171 |
+
encoded_context[modality_name] = encoder[0].encode(input_td[modality_name])
|
172 |
+
else:
|
173 |
+
raise TypeError(f"Invalid type for modality {modality_name}: {type(encoder)}. Expected EncoderDecoderBase or th.nn.ModuleList")
|
174 |
+
return TensorDict(encoded_context, batch_size=preceding_dims)
|
175 |
+
|
176 |
+
def _decode(self, input_td: TensorDict, encoders: th.nn.ModuleDict, space: TensorDictSpace) -> TensorDict:
|
177 |
+
"""
|
178 |
+
Decode input_td into the original space using the given encoders.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
input_td: A tensordict mapping modality names to their encoded inputs.
|
182 |
+
encoders: A dictionary mapping modality names to their encoders.
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
A decoded tensordict.
|
186 |
+
"""
|
187 |
+
decoded_context = {}
|
188 |
+
preceding_dims = space.get_preceding_dimensions(input_td, allow_key_subset=True)
|
189 |
+
for modality_name in input_td.keys():
|
190 |
+
encoder = encoders[modality_name]
|
191 |
+
if isinstance(encoder, EncoderDecoderBase):
|
192 |
+
decoded_context[modality_name] = encoder.decode(input_td[modality_name])
|
193 |
+
elif isinstance(encoder, th.nn.ModuleList):
|
194 |
+
decoded_context[modality_name] = encoder[1].decode(input_td[modality_name])
|
195 |
+
else:
|
196 |
+
raise TypeError(f"Invalid type for modality {modality_name}: {type(encoder)}. Expected EncoderDecoderBase or th.nn.ModuleList")
|
197 |
+
return TensorDict(decoded_context, batch_size=preceding_dims)
|
198 |
+
|
199 |
+
def encode_context(self, context: TensorDict) -> TensorDict:
|
200 |
+
"""
|
201 |
+
Encode the given context into the encoder space.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
context: A tensordict mapping modality names to their inputs.
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
An encoded tensordict.
|
208 |
+
"""
|
209 |
+
assert self.context_world_space.contains(context, allow_key_subset=True), f"Context {context} is not contained in context world space {self.context_world_space}"
|
210 |
+
return self._encode(context, self.context_encoders, self.context_world_space)
|
211 |
+
|
212 |
+
def decode_context(self, encoded_context: TensorDict) -> TensorDict:
|
213 |
+
"""
|
214 |
+
Decode the given encoded context into the original space.
|
215 |
+
|
216 |
+
Args:
|
217 |
+
encoded_context: A tensordict mapping modality names to their encoded inputs.
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
A decoded tensordict.
|
221 |
+
"""
|
222 |
+
assert self.context_encoder_space.contains(
|
223 |
+
encoded_context,
|
224 |
+
allow_key_subset=True,
|
225 |
+
), f"Encoded context {encoded_context} is not contained in context encoder space {self.context_encoder_space}"
|
226 |
+
return self._decode(encoded_context, self.context_encoders, self.context_encoder_space)
|
227 |
+
|
228 |
+
def encode_condition(self, condition: TensorDict) -> TensorDict:
|
229 |
+
"""
|
230 |
+
Encode the given condition into the encoder space.
|
231 |
+
|
232 |
+
Args:
|
233 |
+
condition: A tensordict mapping modality names to their inputs.
|
234 |
+
|
235 |
+
Returns:
|
236 |
+
An encoded tensordict.
|
237 |
+
"""
|
238 |
+
assert self.condition_world_space.contains(
|
239 |
+
condition, allow_key_subset=True
|
240 |
+
), f"Condition {condition} is not contained in condition world space {self.condition_world_space}"
|
241 |
+
return self._encode(condition, self.condition_encoders, self.condition_world_space)
|
242 |
+
|
243 |
+
def decode_condition(self, encoded_condition: TensorDict) -> TensorDict:
|
244 |
+
"""
|
245 |
+
Decode the given encoded condition into the original space.
|
246 |
+
|
247 |
+
Args:
|
248 |
+
encoded_condition: A tensordict mapping modality names to their encoded inputs.
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
A decoded tensordict.
|
252 |
+
"""
|
253 |
+
assert self.condition_encoder_space.contains(
|
254 |
+
encoded_condition, allow_key_subset=True
|
255 |
+
), f"Encoded condition {encoded_condition} is not contained in condition encoder space {self.condition_encoder_space}"
|
256 |
+
return self._decode(encoded_condition, self.condition_encoders, self.condition_encoder_space)
|