Commit · 012c9b1
Parent(s): e06ad17

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +2 -0
- av_hubert/.gitmodules +3 -0
- av_hubert/CODE_OF_CONDUCT.md +80 -0
- av_hubert/CONTRIBUTING.md +31 -0
- av_hubert/LICENSE +159 -0
- av_hubert/README.md +164 -0
- av_hubert/assets/lipreading.gif +3 -0
- av_hubert/avhubert/__init__.py +10 -0
- av_hubert/avhubert/clustering/README.md +100 -0
- av_hubert/avhubert/clustering/dump_hubert_feature.py +177 -0
- av_hubert/avhubert/clustering/dump_km_label.py +99 -0
- av_hubert/avhubert/clustering/dump_mfcc_feature.py +117 -0
- av_hubert/avhubert/clustering/learn_kmeans.py +147 -0
- av_hubert/avhubert/clustering/requirements.txt +6 -0
- av_hubert/avhubert/clustering/submit_cluster.py +132 -0
- av_hubert/avhubert/conf/av-finetune/base_noise_pt_noise_ft_30h.yaml +121 -0
- av_hubert/avhubert/conf/av-finetune/base_noise_pt_noise_ft_433h.yaml +121 -0
- av_hubert/avhubert/conf/av-finetune/large_noise_pt_noise_ft_30h.yaml +124 -0
- av_hubert/avhubert/conf/av-finetune/large_noise_pt_noise_ft_433h.yaml +124 -0
- av_hubert/avhubert/conf/finetune/base_lrs3_30h.yaml +118 -0
- av_hubert/avhubert/conf/finetune/base_lrs3_433h.yaml +118 -0
- av_hubert/avhubert/conf/finetune/base_vox_30h.yaml +118 -0
- av_hubert/avhubert/conf/finetune/base_vox_433h.yaml +118 -0
- av_hubert/avhubert/conf/finetune/large_lrs3_30h.yaml +121 -0
- av_hubert/avhubert/conf/finetune/large_lrs3_433h.yaml +121 -0
- av_hubert/avhubert/conf/finetune/large_vox_30h.yaml +121 -0
- av_hubert/avhubert/conf/finetune/large_vox_433h.yaml +121 -0
- av_hubert/avhubert/conf/finetune/self_large_vox_30h.yaml +121 -0
- av_hubert/avhubert/conf/finetune/self_large_vox_433h.yaml +121 -0
- av_hubert/avhubert/conf/pretrain/base_lrs3_iter1.yaml +112 -0
- av_hubert/avhubert/conf/pretrain/base_lrs3_iter2.yaml +112 -0
- av_hubert/avhubert/conf/pretrain/base_lrs3_iter3.yaml +112 -0
- av_hubert/avhubert/conf/pretrain/base_lrs3_iter4.yaml +112 -0
- av_hubert/avhubert/conf/pretrain/base_lrs3_iter5.yaml +112 -0
- av_hubert/avhubert/conf/pretrain/base_vox_iter1.yaml +113 -0
- av_hubert/avhubert/conf/pretrain/base_vox_iter2.yaml +113 -0
- av_hubert/avhubert/conf/pretrain/base_vox_iter3.yaml +113 -0
- av_hubert/avhubert/conf/pretrain/base_vox_iter4.yaml +112 -0
- av_hubert/avhubert/conf/pretrain/base_vox_iter5.yaml +113 -0
- av_hubert/avhubert/conf/pretrain/large_lrs3_iter5.yaml +117 -0
- av_hubert/avhubert/conf/pretrain/large_vox_iter5.yaml +117 -0
- av_hubert/avhubert/conf/pretrain/noise_base_vox_iter5.yaml +115 -0
- av_hubert/avhubert/conf/pretrain/noise_large_vox_iter5.yaml +119 -0
- av_hubert/avhubert/conf/s2s_decode.yaml +23 -0
- av_hubert/avhubert/decoder.py +243 -0
- av_hubert/avhubert/hubert.py +779 -0
- av_hubert/avhubert/hubert_asr.py +521 -0
- av_hubert/avhubert/hubert_criterion.py +169 -0
- av_hubert/avhubert/hubert_dataset.py +529 -0
- av_hubert/avhubert/hubert_pretraining.py +400 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+av_hubert/assets/lipreading.gif filter=lfs diff=lfs merge=lfs -text
+av_hubert/fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text
av_hubert/.gitmodules ADDED
@@ -0,0 +1,3 @@
+[submodule "fairseq"]
+	path = fairseq
+	url = https://github.com/pytorch/fairseq
av_hubert/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@

# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <[email protected]>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
av_hubert/CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@

# Contributing to av_hubert
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to av_hubert, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
av_hubert/LICENSE ADDED
@@ -0,0 +1,159 @@

AV-HuBERT LICENSE AGREEMENT

This License Agreement (as may be amended in accordance with this License
Agreement, “License”), between you (“Licensee” or “you”) and Meta Platforms,
Inc. (“Meta” or “we”) applies to your use of any computer program, algorithm,
source code, object code, or software that is made available by Meta under this
License (“Software”) and any specifications, manuals, documentation, and other
written information provided by Meta related to the Software (“Documentation”).

By clicking “I Accept” below or by using the Software, you agree to the terms
of this License. If you do not agree to this License, then you do not have any
rights to use the Software or Documentation (collectively, the “Software
Products”), and you must immediately cease using the Software Products.

1. LICENSE GRANT

a. Subject to your compliance with the Documentation and Sections 2, 3, and 5,
Meta grants you a non-exclusive, worldwide, non-transferable, non-sublicensable,
revocable, royalty free and limited license under Meta’s copyright interests to
reproduce, distribute, and create derivative works of the Software solely for
your non-commercial research purposes. The foregoing license is personal to you,
and you may not assign or sublicense this License or any other rights or
obligations under this License without Meta’s prior written consent; any such
assignment or sublicense will be void and will automatically and immediately
terminate this License.

b. You may make a reasonable number of copies of the Documentation solely for
use in connection with the license to the Software granted above.

c. The grant of rights expressly set forth in this Section 1 (License Grant)
are the complete grant of rights to you in the Software Products, and no other
licenses are granted, whether by waiver, estoppel, implication, equity or
otherwise. Meta and its licensors reserve all rights not expressly granted by
this License.

2. RESTRICTIONS

You will not, and will not permit, assist or cause any third party to:

a. use, modify, copy, reproduce, create derivative works of, or distribute the
Software Products (or any derivative works thereof, works incorporating the
Software Products, or any data produced by the Software), in whole or in part,
for (i) any commercial or production purposes, (ii) military purposes or in the
service of nuclear technology, (iii) purposes of surveillance, including any
research or development relating to surveillance, (iv) biometric processing,
(v) in any manner that infringes, misappropriates, or otherwise violates any
third-party rights, or (vi) in any manner that violates any applicable law,
including any privacy or security laws, rules, regulations, directives, or
governmental requirements (including the General Data Privacy Regulation
(Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and
all laws governing the processing of biometric information), as well as all
amendments and successor laws to any of the foregoing;

b. decompile, disassemble, or reverse-engineer the Software, in whole or in
part;

c. alter or remove copyright and other proprietary notices which appear on or
in the Software Products;

d. utilize any equipment, device, software, or other means to circumvent or
remove any security or protection used by Meta in connection with the Software,
or to circumvent or remove any usage restrictions, or to enable functionality
disabled by Meta; or

e. offer or impose any terms on the Software Products that alter, restrict, or
are inconsistent with the terms of this License.

3. ATTRIBUTION

Together with any copies of the Software Products (as well as derivative works
thereof or works incorporating the Software Products) that you distribute, you
must provide (i) a copy of this License, and (ii) the following attribution
notice: “AV-HuBERT is licensed under the AV-HuBERT license, Copyright (c) Meta
Platforms, Inc. All Rights Reserved.”

4. DISCLAIMERS

THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” and “WITH ALL FAULTS” WITH NO
WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. META EXPRESSLY DISCLAIMS ALL
REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM,
USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS,
INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR
NON-INFRINGEMENT. META MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE
PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR
PRODUCE ANY PARTICULAR RESULTS.

5. LIMITATION OF LIABILITY

TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL META BE LIABLE TO YOU
(A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE,
STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY
INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR
LOST PROFITS, EVEN IF META HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT
(COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN
ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS
COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON,
INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY
RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A
“HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A
HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT
APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN
CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT
IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY
THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR
THE FIELD OF THE HIGH-RISK USE.

6. TERMINATION; SURVIVAL

a. This License will automatically terminate upon any breach by you of the
terms of this License.

b. We may terminate this License, in whole or in part, at any time upon notice
(including electronic) to you.

c. The following sections survive termination of this License: 2
(Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation on Liability),
6 (Termination; Survival), 7 (Third Party Materials), 8 (Trademarks), 9
(Applicable Law; Dispute Resolution), and 10 (Miscellaneous).

7. THIRD PARTY MATERIALS

The Software Products may contain third-party software or other components
(including free and open source software) (all of the foregoing, “Third Party
Materials”), which are subject to the license terms of the respective
third-party licensors. Your dealings or correspondence with third parties and
your use of or interaction with any Third Party Materials are solely between
you and the third party. Meta does not control or endorse, and makes no
representations or warranties regarding, any Third Party Materials, and your
access to and use of such Third Party Materials are at your own risk.

8. TRADEMARKS

Licensee has not been granted any trademark license as part of this License and
may not use any name or mark associated with Meta without the prior written
permission of Meta, except to the extent necessary to make the reference
required by the “ATTRIBUTION” section of this Agreement.

9. APPLICABLE LAW; DISPUTE RESOLUTION

This License will be governed and construed under the laws of the State of
California without regard to conflicts of law provisions. Any suit or
proceeding arising out of or relating to this License will be brought in the
federal or state courts, as applicable, in San Mateo County, California, and
each party irrevocably submits to the jurisdiction and venue of such courts.

10. MISCELLANEOUS

If any provision or part of a provision of this License is unlawful, void or
unenforceable, that provision or part of the provision is deemed severed from
this License, and will not affect the validity and enforceability of any
remaining provisions. The failure of Meta to exercise or enforce any right or
provision of this License will not operate as a waiver of such right or
provision. This License does not confer any third-party beneficiary rights upon
any other person or entity. This License, together with the Documentation,
contains the entire understanding between you and Meta regarding the subject
matter of this License, and supersedes all other written or oral agreements and
understandings between you and Meta regarding such subject matter. No change or
addition to any provision of this License will be binding unless it is in
writing and signed by an authorized representative of both you and Meta.
av_hubert/README.md ADDED
@@ -0,0 +1,164 @@

# AV-HuBERT (Audio-Visual Hidden Unit BERT)
[Learning Audio-Visual Speech Representation by Masked Multimodal Cluster Prediction](https://arxiv.org/abs/2201.02184)

[Robust Self-Supervised Audio-Visual Speech Recognition](https://arxiv.org/abs/2201.01763)

![lip-reading](assets/lipreading.gif)

## Introduction
AV-HuBERT is a self-supervised representation learning framework for audio-visual speech. It achieves state-of-the-art results in lip reading, ASR, and audio-visual speech recognition on the LRS3 audio-visual speech benchmark.

If you find AV-HuBERT useful in your research, please use the following BibTeX entries for citation.
```BibTeX
@article{shi2022avhubert,
    author  = {Bowen Shi and Wei-Ning Hsu and Kushal Lakhotia and Abdelrahman Mohamed},
    title   = {Learning Audio-Visual Speech Representation by Masked Multimodal Cluster Prediction},
    journal = {arXiv preprint arXiv:2201.02184},
    year    = {2022}
}

@article{shi2022avsr,
    author  = {Bowen Shi and Wei-Ning Hsu and Abdelrahman Mohamed},
    title   = {Robust Self-Supervised Audio-Visual Speech Recognition},
    journal = {arXiv preprint arXiv:2201.01763},
    year    = {2022}
}
```

## License

AV-HuBERT LICENSE AGREEMENT

This License Agreement (as may be amended in accordance with this License
Agreement, “License”), between you (“Licensee” or “you”) and Meta Platforms,
Inc. (“Meta” or “we”) applies to your use of any computer program, algorithm,
source code, object code, or software that is made available by Meta under this
License (“Software”) and any specifications, manuals, documentation, and other
written information provided by Meta related to the Software (“Documentation”).

By using the Software, you agree to the terms of [this
License](https://github.com/facebookresearch/av_hubert/blob/main/LICENSE). If
you do not agree to this License, then you do not have any rights to use the
Software or Documentation (collectively, the “Software Products”), and you must
immediately cease using the Software Products.

## Pre-trained and fine-tuned models

Please find the checkpoints [here](http://facebookresearch.github.io/av_hubert).

## Demo
Run our lip-reading demo using Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1bNXkfpHiVHzXQH8WjGhzQ-fsDxolpUjD)

## Installation
First, create a conda virtual environment and activate it:
```
conda create -n avhubert python=3.8 -y
conda activate avhubert
```
Then, clone this repository:
```
git clone https://github.com/facebookresearch/av_hubert.git
cd av_hubert
git submodule init
git submodule update
```

Lastly, install Fairseq and the other packages:
```
pip install -r requirements.txt
cd fairseq
pip install --editable ./
```

## Load a pretrained model
```sh
$ cd avhubert
$ python
>>> import fairseq
>>> import hubert_pretraining, hubert
>>> ckpt_path = "/path/to/the/checkpoint.pt"
>>> models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
>>> model = models[0]
```
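To sanity-check the loaded checkpoint, you can run a forward pass through `extract_features`, mirroring the call made in `avhubert/clustering/dump_hubert_feature.py` later in this commit. The snippet below is a minimal sketch with random inputs; the shapes (104-D stacked filterbank audio, 88x88 single-channel mouth crops, 25 frames) are assumptions inferred from that script's defaults, not guaranteed for every config.

```python
# Sketch: one no-grad forward pass over ~1 s of dummy audio-visual input.
import torch

model.eval()
audio = torch.randn(1, 104, 25)        # (batch, feat_dim, time): 26 filterbanks x 4 stacked frames
video = torch.randn(1, 1, 25, 88, 88)  # (batch, channel, time, height, width)
with torch.no_grad():
    feat, _ = model.extract_features(
        source={"audio": audio, "video": video},
        padding_mask=None,
        mask=False,
        output_layer=12,               # transformer layer to read out, as in the clustering script
        ret_conv=False,
    )
print(feat.shape)                      # (1, 25, embed_dim)
```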
## Train a new model

### Data preparation

Follow the steps in [`preparation`](avhubert/preparation/) to pre-process:
- LRS3 and VoxCeleb2 datasets

Follow the steps in [`clustering`](avhubert/clustering/) (pre-train only) to create:
- `{train,valid}.km` frame-aligned pseudo-label files.

The `label_rate` is the same as the feature frame rate used for clustering,
which is 100 Hz for MFCC features and 25 Hz for AV-HuBERT features by default; a quick arithmetic check of these rates follows below.
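A back-of-the-envelope check of those defaults (a sketch assuming 16 kHz audio, the standard 10 ms filterbank hop, 25 fps video, and `stack_order_audio=4` as in the clustering scripts):

```python
# MFCC labels: one label per 10 ms frame -> label_rate = 100.
# AV-HuBERT labels: features align to 25 fps video frames -> label_rate = 25.
hop_s = 0.010                         # filterbank/MFCC frame shift in seconds
mfcc_rate = 1 / hop_s                 # 100 labels per second
stack_order_audio = 4
stacked_rate = mfcc_rate / stack_order_audio
assert stacked_rate == 25             # one label per 25 fps video frame
```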
### Pre-train an AV-HuBERT model

Suppose `{train,valid}.tsv` are saved at `/path/to/data`, `{train,valid}.km`
are saved at `/path/to/labels`, the configuration file is saved at `/path/to/conf/conf-name`, and the label rate is 100 Hz.

To train a model, run:
```sh
$ cd avhubert
$ fairseq-hydra-train --config-dir /path/to/conf/ --config-name conf-name \
  task.data=/path/to/data task.label_dir=/path/to/label \
  model.label_rate=100 hydra.run.dir=/path/to/experiment/pretrain/ \
  common.user_dir=`pwd`
```

### Finetune an AV-HuBERT model with Seq2Seq
Suppose `{train,valid}.tsv` are saved at `/path/to/data`, `{train,valid}.wrd`
are saved at `/path/to/labels`, and the configuration file is saved at `/path/to/conf/conf-name`.

To fine-tune a pre-trained AV-HuBERT model at `/path/to/checkpoint`, run:
```sh
$ cd avhubert
$ fairseq-hydra-train --config-dir /path/to/conf/ --config-name conf-name \
  task.data=/path/to/data task.label_dir=/path/to/label \
  task.tokenizer_bpe_model=/path/to/tokenizer model.w2v_path=/path/to/checkpoint \
  hydra.run.dir=/path/to/experiment/finetune/ common.user_dir=`pwd`
```

### Decode an AV-HuBERT model
Suppose `test.tsv` and `test.wrd` are the video list and transcripts of
the split to be decoded, saved at `/path/to/data`, and the fine-tuned model is
saved at `/path/to/checkpoint`.

#### Seq2Seq decoding

`task.normalize` needs to be consistent with the value used during fine-tuning.
Decoding results will be saved at `/path/to/experiment/decode/s2s/test`.

```sh
$ cd avhubert
$ python -B infer_s2s.py --config-dir ./conf/ --config-name conf-name \
  dataset.gen_subset=test common_eval.path=/path/to/checkpoint \
  common_eval.results_path=/path/to/experiment/decode/s2s/test \
  override.modalities=['video'] common.user_dir=`pwd`
```

The command above uses the default decoding hyperparameters, which can be found
in `conf/s2s_decode.yaml`. `override.modalities` can be set to `['video']` (for lip reading),
`['audio']` (for ASR), or `['audio','video']` (for audio-visual speech recognition). These parameters can be
configured from the command line. For example, to search with a beam size of
20, append `generation.beam=20` to the command above.
Important parameters include:
- generation.beam
- generation.lenpen

#### Different test set
If your test data are stored in a different directory from the training data, append the following to the above command:

`+override.data=/path/to/test +override.label_dir=/path/to/test`

where `/path/to/test` contains `test.{tsv,wrd}`. This is useful when you want to test with the fine-tuned checkpoints we provide.

#### Test under a noisy environment
If you want to test your model in a noisy environment, append the following to the above command:

`+override.noise_wav=/path/to/noise override.noise_prob=1 override.noise_snr={snr}`

`{snr}` is the signal-to-noise ratio (SNR) and `/path/to/noise` is a folder containing noise manifest files (`/path/to/noise/{valid,test}.tsv`). See [`preparation`](avhubert/preparation/) for setting up this folder.
av_hubert/assets/lipreading.gif ADDED
(binary file, stored with Git LFS; not shown)
av_hubert/avhubert/__init__.py ADDED
@@ -0,0 +1,10 @@

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .hubert import *  # noqa
from .hubert_asr import *  # noqa
from .hubert_dataset import *
from .hubert_pretraining import *
from .hubert_criterion import *
av_hubert/avhubert/clustering/README.md ADDED
@@ -0,0 +1,100 @@

# AV-HuBERT Label Preparation

This folder contains scripts for preparing AV-HuBERT labels from tsv files. The
steps are:
1. feature extraction
2. k-means clustering
3. k-means application

## Installation
To prepare labels, you need some additional packages:
```
pip install -r requirements.txt
```

## Data preparation

`*.tsv` files contain a list of audio-visual utterances: the first line is the
root directory, and each following line gives the id, the video and audio
subpaths, and the number of video and audio frames of one utterance, separated
by tabs (an illustrative parsing sketch follows the example):
```
<root-dir>
<id-1> <video-path-1> <audio-path-1> <video-number-frames-1> <audio-number-frames-1>
<id-2> <video-path-2> <audio-path-2> <video-number-frames-2> <audio-number-frames-2>
...
```
See [here](../preparation/) for data preparation for LRS3 and VoxCeleb2.
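For concreteness, here is an illustrative parser for this layout (a sketch; `read_manifest` is a hypothetical helper, not part of this repository):

```python
# Hypothetical helper: yield absolute paths and frame counts from a manifest.
def read_manifest(tsv_path):
    with open(tsv_path) as f:
        root = f.readline().rstrip()        # first line: root directory
        for line in f:
            uid, video, audio, n_video, n_audio = line.rstrip("\n").split("\t")
            yield uid, f"{root}/{video}", f"{root}/{audio}", int(n_video), int(n_audio)
```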
## Feature extraction

### MFCC feature
Suppose the tsv file is at `${tsv_dir}/${split}.tsv`. To extract 39-D
MFCC+delta+ddelta features for the 1st-iteration AV-HuBERT training, run:
```sh
python dump_mfcc_feature.py ${tsv_dir} ${split} ${nshard} ${rank} ${feat_dir}
```
This shards the tsv file into `${nshard}` parts and extracts features for the
`${rank}`-th shard, where rank is an integer in `[0, nshard-1]`. Features are
saved at `${feat_dir}/${split}_${rank}_${nshard}.{npy,len}`.

### AV-HuBERT feature
To extract features from the `${layer}`-th transformer layer of a trained
AV-HuBERT model saved at `${ckpt_path}`, run:
```sh
python dump_hubert_feature.py ${tsv_dir} ${split} ${ckpt_path} ${layer} ${nshard} ${rank} ${feat_dir} --user_dir `pwd`/../
```
Features are likewise saved at `${feat_dir}/${split}_${rank}_${nshard}.{npy,len}`.

- If you run out of memory, decrease the chunk size with `--max_chunk`.

## K-means clustering
To fit a k-means model with `${n_clusters}` clusters on 10% of the `${split}` data, run
```sh
python learn_kmeans.py ${feat_dir} ${split} ${nshard} ${km_path} ${n_cluster} --percent 0.1
```
This saves the k-means model to `${km_path}`.

- Set `--percent -1` to use all data.
- More k-means options can be found with the `-h` flag.

## K-means application
To apply a trained k-means model `${km_path}` to obtain labels for `${split}`, run
```sh
python dump_km_label.py ${feat_dir} ${split} ${km_path} ${nshard} ${rank} ${lab_dir}
```
This extracts labels for the `${rank}`-th shard out of `${nshard}` shards
and dumps them to `${lab_dir}/${split}_${rank}_${nshard}.km`.

Finally, merge shards for `${split}` by running
```sh
for rank in $(seq 0 $((nshard - 1))); do
  cat $lab_dir/${split}_${rank}_${nshard}.km
done > $lab_dir/${split}.km
```
and create a dictionary of cluster indexes by running
```sh
for i in $(seq 1 $((n_cluster-1))); do
  echo $i 10000
done > $lab_dir/dict.{mfcc,km}.txt
```

## Clustering on Slurm
If you are on Slurm, you can combine the above steps (feature extraction + k-means clustering + k-means application) with:

- MFCC feature clustering:
```sh
python submit_cluster.py --tsv ${tsv_dir} --output ${lab_dir} --ncluster ${n_cluster} \
  --nshard ${nshard} --mfcc --percent 0.1
```

- AV-HuBERT feature clustering:
```sh
python submit_cluster.py --tsv ${tsv_dir} --output ${lab_dir} --ckpt ${ckpt_path} --nlayer ${layer} \
  --ncluster ${n_cluster} --nshard ${nshard} --percent 0.1
```

This dumps labels to `${lab_dir}/{train,valid}.km`.
av_hubert/avhubert/clustering/dump_hubert_feature.py ADDED
@@ -0,0 +1,177 @@

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import os
import sys

import fairseq
import soundfile as sf
import torch
import torch.nn.functional as F
import tqdm
from npy_append_array import NpyAppendArray
import numpy as np
from python_speech_features import logfbank
from scipy.io import wavfile

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("dump_hubert_feature")


class HubertFeatureReader(object):
    def __init__(self, ckpt_path, layer, max_chunk=1600000, custom_utils=None):
        (
            model,
            cfg,
            task,
        ) = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        self.model = model[0].eval().cuda()
        self.task = task
        self.layer = layer
        self.max_chunk = max_chunk
        self.stack_order_audio = self.task.cfg.stack_order_audio
        image_crop_size, image_mean, image_std = (
            self.task.cfg.image_crop_size,
            self.task.cfg.image_mean,
            self.task.cfg.image_std,
        )
        self.transform = custom_utils.Compose([
            custom_utils.Normalize(0.0, 255.0),
            custom_utils.CenterCrop((image_crop_size, image_crop_size)),
            custom_utils.Normalize(image_mean, image_std),
        ])
        self.custom_utils = custom_utils
        logger.info(f"TASK CONFIG:\n{self.task.cfg}")
        logger.info(f" max_chunk = {self.max_chunk}")
        logger.info(f"Transform: {self.transform}")

    def load_feature(self, mix_name, ref_len=None):
        def stacker(feats, stack_order):
            # Concatenate every `stack_order` consecutive frames, zero-padding
            # the tail so the frame count is a multiple of `stack_order`.
            feat_dim = feats.shape[1]
            if len(feats) % stack_order != 0:
                res = stack_order - len(feats) % stack_order
                res = np.zeros([res, feat_dim]).astype(feats.dtype)
                feats = np.concatenate([feats, res], axis=0)
            feats = feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order * feat_dim)
            return feats

        video_fn, audio_fn = mix_name
        video_feats = self.load_image(video_fn)

        audio_fn = audio_fn.split(':')[0]
        sample_rate, wav_data = wavfile.read(audio_fn)
        assert sample_rate == 16_000 and len(wav_data.shape) == 1
        audio_feats = logfbank(wav_data, samplerate=sample_rate).astype(np.float32)
        audio_feats = stacker(audio_feats, self.stack_order_audio)

        # Pad or trim the audio features so they align 1:1 with video frames.
        diff = len(audio_feats) - len(video_feats)
        if diff < 0:
            audio_feats = np.concatenate([audio_feats, np.zeros([-diff, audio_feats.shape[-1]], dtype=audio_feats.dtype)])
        elif diff > 0:
            audio_feats = audio_feats[:-diff]
        return video_feats, audio_feats

    def load_image(self, audio_name):
        feats = self.custom_utils.load_video(audio_name)
        feats = self.transform(feats)
        feats = np.expand_dims(feats, axis=-1)
        return feats

    def get_feats(self, path, ref_len=None):
        video_feats, audio_feats = self.load_feature(path, ref_len)
        with torch.no_grad():
            audio_feats = torch.from_numpy(audio_feats.astype(np.float32)).cuda()
            video_feats = torch.from_numpy(video_feats.astype(np.float32)).cuda()
            if self.task.cfg.normalize:
                audio_feats = F.layer_norm(audio_feats, audio_feats.shape[1:])
            video_feats = video_feats.unsqueeze(dim=0).permute((0, 4, 1, 2, 3)).contiguous()
            audio_feats = audio_feats.unsqueeze(dim=0).transpose(1, 2)
            source = {'audio': audio_feats, 'video': video_feats}
            if self.layer == 0:
                ret_conv, output_layer = True, None
            else:
                ret_conv, output_layer = False, self.layer
            feat, _ = self.model.extract_features(
                source=source,
                padding_mask=None,
                mask=False,
                output_layer=output_layer,
                ret_conv=ret_conv,
            )
            return feat.squeeze(dim=0)


def get_path_iterator(tsv, nshard, rank):
    with open(tsv, "r") as f:
        root = f.readline().rstrip()
        lines = [line.rstrip() for line in f]
        tot = len(lines)
        shard_size = math.ceil(tot / nshard)
        start, end = rank * shard_size, min((rank + 1) * shard_size, tot)
        assert start < end, f"start={start}, end={end}"
        logger.info(
            f"rank {rank} of {nshard}, process {end-start} "
            f"({start}-{end}) out of {tot}"
        )

        lines = lines[start:end]

        def iterate():
            for line in lines:
                items = line.strip().split("\t")
                yield (items[1], items[2] + ':' + items[0]), int(items[3])

        return iterate, len(lines)


def dump_feature(
    tsv_dir, split, ckpt_path, layer, nshard, rank, feat_dir, max_chunk, custom_utils=None, **kwargs
):
    reader = HubertFeatureReader(ckpt_path, layer, max_chunk, custom_utils=custom_utils)
    generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank)
    iterator = generator()

    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"

    os.makedirs(feat_dir, exist_ok=True)
    if os.path.exists(feat_path):
        os.remove(feat_path)

    feat_f = NpyAppendArray(feat_path)
    with open(leng_path, "w") as leng_f:
        for path, nsample in tqdm.tqdm(iterator, total=num):
            feat = reader.get_feats(path, nsample)
            feat_f.append(feat.cpu().numpy())
            leng_f.write(f"{len(feat)}\n")
    logger.info("finished successfully")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("tsv_dir")
    parser.add_argument("split")
    parser.add_argument("ckpt_path")
    parser.add_argument("layer", type=int)
    parser.add_argument("nshard", type=int)
    parser.add_argument("rank", type=int)
    parser.add_argument("feat_dir")
    parser.add_argument("--max_chunk", type=int, default=1600000)
    parser.add_argument("--user_dir", type=str, default=None)

    args = parser.parse_args()
    logger.info(args)
    fairseq.utils.import_user_module(args)
    sys.path.append(args.user_dir)
    import utils as custom_utils
    kwargs = vars(args)
    kwargs.update({'custom_utils': custom_utils})
    dump_feature(**kwargs)
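As a quick illustration of what the `stacker` helper above does to shapes (synthetic numbers; 26-D log filterbanks are `logfbank`'s default output):

```python
# 103 filterbank frames, stack_order=4: pad one zero frame to reach 104, then
# fuse every 4 consecutive frames into one 104-D vector at a quarter the rate.
import numpy as np

feats = np.zeros((103, 26), dtype=np.float32)
stack_order = 4
pad = stack_order - len(feats) % stack_order
feats = np.concatenate([feats, np.zeros((pad, 26), dtype=np.float32)], axis=0)
stacked = feats.reshape(-1, stack_order, 26).reshape(-1, stack_order * 26)
assert stacked.shape == (26, 104)      # 100 Hz frames -> 25 Hz frames, 4x wider
```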
av_hubert/avhubert/clustering/dump_km_label.py ADDED
@@ -0,0 +1,99 @@

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys

import numpy as np

import joblib
import torch
import tqdm

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("dump_km_label")


class ApplyKmeans(object):
    def __init__(self, km_path):
        self.km_model = joblib.load(km_path)
        self.C_np = self.km_model.cluster_centers_.transpose()
        self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)

        self.C = torch.from_numpy(self.C_np)
        self.Cnorm = torch.from_numpy(self.Cnorm_np)
        if torch.cuda.is_available():
            self.C = self.C.cuda()
            self.Cnorm = self.Cnorm.cuda()

    def __call__(self, x):
        # Nearest centroid via ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2.
        if isinstance(x, torch.Tensor):
            dist = (
                x.pow(2).sum(1, keepdim=True)
                - 2 * torch.matmul(x, self.C)
                + self.Cnorm
            )
            return dist.argmin(dim=1).cpu().numpy()
        else:
            dist = (
                (x ** 2).sum(1, keepdims=True)
                - 2 * np.matmul(x, self.C_np)
                + self.Cnorm_np
            )
            return np.argmin(dist, axis=1)


def get_feat_iterator(feat_dir, split, nshard, rank):
    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"
    with open(leng_path, "r") as f:
        lengs = [int(line.rstrip()) for line in f]
        offsets = [0] + np.cumsum(lengs[:-1]).tolist()

    def iterate():
        feat = np.load(feat_path, mmap_mode="r")
        assert feat.shape[0] == (offsets[-1] + lengs[-1])
        for offset, leng in zip(offsets, lengs):
            yield feat[offset: offset + leng]

    return iterate, len(lengs)


def dump_label(feat_dir, split, km_path, nshard, rank, lab_dir):
    apply_kmeans = ApplyKmeans(km_path)
    generator, num = get_feat_iterator(feat_dir, split, nshard, rank)
    iterator = generator()

    lab_path = f"{lab_dir}/{split}_{rank}_{nshard}.km"
    os.makedirs(lab_dir, exist_ok=True)
    with open(lab_path, "w") as f:
        for feat in tqdm.tqdm(iterator, total=num):
            # feat = torch.from_numpy(feat).cuda()
            lab = apply_kmeans(feat).tolist()
            f.write(" ".join(map(str, lab)) + "\n")
    logger.info("finished successfully")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("feat_dir")
    parser.add_argument("split")
    parser.add_argument("km_path")
    parser.add_argument("nshard", type=int)
    parser.add_argument("rank", type=int)
    parser.add_argument("lab_dir")
    args = parser.parse_args()
    logging.info(str(args))

    dump_label(**vars(args))
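`ApplyKmeans.__call__` assigns each frame to its nearest centroid via the expansion ||x - c||^2 = ||x||^2 - 2 x·c + ||c||^2, which lets the centroid norms be precomputed and the cross term batched as one matrix product. A self-contained check on synthetic data (illustration only, not repository code) that this matches scikit-learn's own assignment:

```python
# Verify the expanded-distance argmin equals MiniBatchKMeans.predict().
import numpy as np
from sklearn.cluster import MiniBatchKMeans

x = np.random.rand(50, 8)
km = MiniBatchKMeans(n_clusters=5, n_init=3).fit(x)

C = km.cluster_centers_.T              # (dim, n_clusters)
dist = (x ** 2).sum(1, keepdims=True) - 2 * x @ C + (C ** 2).sum(0, keepdims=True)
assert (dist.argmin(axis=1) == km.predict(x)).all()
```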
av_hubert/avhubert/clustering/dump_mfcc_feature.py ADDED
@@ -0,0 +1,117 @@

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import os
import sys

import soundfile as sf
import torch
import torchaudio
import tqdm
from npy_append_array import NpyAppendArray

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("dump_mfcc_feature")


class MfccFeatureReader(object):
    def __init__(self, sample_rate):
        self.sample_rate = sample_rate

    def read_audio(self, path, ref_len=None):
        wav, sr = sf.read(path)
        assert sr == self.sample_rate, sr
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logging.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        x = self.read_audio(path, ref_len)
        with torch.no_grad():
            x = torch.from_numpy(x).float()
            x = x.view(1, -1)

            mfccs = torchaudio.compliance.kaldi.mfcc(
                waveform=x,
                sample_frequency=self.sample_rate,
                use_energy=False,
            )  # (time, freq)
            mfccs = mfccs.transpose(0, 1)  # (freq, time)
            deltas = torchaudio.functional.compute_deltas(mfccs)
            ddeltas = torchaudio.functional.compute_deltas(deltas)
            concat = torch.cat([mfccs, deltas, ddeltas], dim=0)
            concat = concat.transpose(0, 1).contiguous()  # (time, freq)
            return concat


def get_path_iterator(tsv, nshard, rank):
    with open(tsv, "r") as f:
        root = f.readline().rstrip()
        lines = [line.rstrip() for line in f]
        tot = len(lines)
        shard_size = math.ceil(tot / nshard)
        start, end = rank * shard_size, min((rank + 1) * shard_size, tot)
        assert start < end, f"start={start}, end={end}"
        logger.info(
            f"rank {rank} of {nshard}, process {end-start} "
            f"({start}-{end}) out of {tot}"
        )

        lines = lines[start:end]

        def iterate():
            for line in lines:
                _, video_path, wav_path, nsample_video, nsample_wav = line.split("\t")
                yield f"{root}/{wav_path}", int(nsample_wav)

        return iterate, len(lines)


def dump_feature(tsv_dir, split, nshard, rank, feat_dir, sample_rate=16_000):
    reader = MfccFeatureReader(sample_rate)
    generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank)
    iterator = generator()

    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"

    os.makedirs(feat_dir, exist_ok=True)
    if os.path.exists(feat_path):
        os.remove(feat_path)

    feat_f = NpyAppendArray(feat_path)
    with open(leng_path, "w") as leng_f:
        for path, nsample in tqdm.tqdm(iterator, total=num):
            feat = reader.get_feats(path, nsample)
            feat_f.append(feat.cpu().numpy())
            leng_f.write(f"{len(feat)}\n")
    logger.info("finished successfully")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("tsv_dir")
    parser.add_argument("split")
    parser.add_argument("nshard", type=int)
    parser.add_argument("rank", type=int)
    parser.add_argument("feat_dir")
    parser.add_argument("--sample_rate", type=int, default=16000)
    args = parser.parse_args()
    logger.info(args)

    dump_feature(**vars(args))
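To see where the 39 dimensions in `get_feats` come from (13 Kaldi-style MFCC coefficients plus their deltas and delta-deltas), here is a standalone shape check on a synthetic waveform (an illustration, not repository code):

```python
import torch
import torchaudio

wav = torch.randn(1, 16_000)                               # 1 s of fake 16 kHz audio
mfccs = torchaudio.compliance.kaldi.mfcc(
    waveform=wav, sample_frequency=16_000, use_energy=False
).transpose(0, 1)                                          # (13, time)
deltas = torchaudio.functional.compute_deltas(mfccs)
ddeltas = torchaudio.functional.compute_deltas(deltas)
feat = torch.cat([mfccs, deltas, ddeltas], dim=0).transpose(0, 1)
assert feat.shape[1] == 39                                 # (time, 13 + 13 + 13)
```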
av_hubert/avhubert/clustering/learn_kmeans.py ADDED
@@ -0,0 +1,147 @@

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys

import numpy as np
from sklearn.cluster import MiniBatchKMeans

import joblib

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("learn_kmeans")


def get_km_model(
    n_clusters,
    init,
    max_iter,
    batch_size,
    tol,
    max_no_improvement,
    n_init,
    reassignment_ratio,
):
    return MiniBatchKMeans(
        n_clusters=n_clusters,
        init=init,
        max_iter=max_iter,
        batch_size=batch_size,
        verbose=1,
        compute_labels=False,
        tol=tol,
        max_no_improvement=max_no_improvement,
        init_size=None,
        n_init=n_init,
        reassignment_ratio=reassignment_ratio,
    )


def load_feature_shard(feat_dir, split, nshard, rank, percent):
    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"
    with open(leng_path, "r") as f:
        lengs = [int(line.rstrip()) for line in f]
        offsets = [0] + np.cumsum(lengs[:-1]).tolist()

    if percent < 0:
        return np.load(feat_path, mmap_mode="r")
    else:
        nsample = int(np.ceil(len(lengs) * percent))
        indices = np.random.choice(len(lengs), nsample, replace=False)
        feat = np.load(feat_path, mmap_mode="r")
        sampled_feat = np.concatenate(
            [feat[offsets[i]: offsets[i] + lengs[i]] for i in indices], axis=0
        )
        logger.info(
            (
                f"sampled {nsample} utterances, {len(sampled_feat)} frames "
                f"from shard {rank}/{nshard}"
            )
        )
        return sampled_feat


def load_feature(feat_dir, split, nshard, seed, percent):
    assert percent <= 1.0
    feat = np.concatenate(
        [
            load_feature_shard(feat_dir, split, nshard, r, percent)
            for r in range(nshard)
        ],
        axis=0,
    )
    logging.info(f"loaded feature with dimension {feat.shape}")
    return feat


def learn_kmeans(
    feat_dir,
    split,
    nshard,
    km_path,
    n_clusters,
    seed,
    percent,
    init,
    max_iter,
    batch_size,
    tol,
    n_init,
    reassignment_ratio,
    max_no_improvement,
):
    np.random.seed(seed)
    feat = load_feature(feat_dir, split, nshard, seed, percent)
    km_model = get_km_model(
        n_clusters,
        init,
        max_iter,
        batch_size,
        tol,
        max_no_improvement,
        n_init,
        reassignment_ratio,
    )
    km_model.fit(feat)
    joblib.dump(km_model, km_path)

    inertia = -km_model.score(feat) / len(feat)
    logger.info("total inertia: %.5f", inertia)
    logger.info("finished successfully")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("feat_dir", type=str)
    parser.add_argument("split", type=str)
    parser.add_argument("nshard", type=int)
    parser.add_argument("km_path", type=str)
    parser.add_argument("n_clusters", type=int)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument(
        "--percent", default=-1, type=float, help="sample a subset; -1 for all"
    )
    parser.add_argument("--init", default="k-means++")
    parser.add_argument("--max_iter", default=100, type=int)
    parser.add_argument("--batch_size", default=10000, type=int)
    parser.add_argument("--tol", default=0.0, type=float)
    parser.add_argument("--max_no_improvement", default=100, type=int)
    parser.add_argument("--n_init", default=20, type=int)
    parser.add_argument("--reassignment_ratio", default=0.0, type=float)
    args = parser.parse_args()
    logging.info(str(args))

    learn_kmeans(**vars(args))
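What `learn_kmeans` persists is a plain scikit-learn estimator, which `dump_km_label.py` later reloads with `joblib`. A round-trip sketch on synthetic features (paths and sizes are placeholders, not repository defaults):

```python
import joblib
import numpy as np
from sklearn.cluster import MiniBatchKMeans

feat = np.random.rand(1000, 39).astype(np.float32)          # stand-in for dumped features
km = MiniBatchKMeans(n_clusters=20, batch_size=256, n_init=3).fit(feat)
joblib.dump(km, "/tmp/km.bin")                               # what learn_kmeans saves at ${km_path}
km2 = joblib.load("/tmp/km.bin")                             # what dump_km_label.py loads
assert (km2.predict(feat[:10]) == km.predict(feat[:10])).all()
```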
av_hubert/avhubert/clustering/requirements.txt
ADDED
@@ -0,0 +1,6 @@
soundfile
joblib
sklearn
torchaudio==0.10.1
npy-append-array==0.9.13
submitit==1.4.1
av_hubert/avhubert/clustering/submit_cluster.py
ADDED
@@ -0,0 +1,132 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import subprocess
from argparse import Namespace

import submitit


def dump_av_hubert(*args, **kwargs):
    # Runs on the worker: make the avhubert user dir importable, then dump features.
    from dump_hubert_feature import dump_feature
    import fairseq
    import sys
    av_hubert_dir = os.path.join(os.getcwd(), '..')
    fairseq.utils.import_user_module(Namespace(user_dir=av_hubert_dir))
    sys.path.append(av_hubert_dir)
    import utils as custom_utils
    kwargs.update({'custom_utils': custom_utils})
    args = args[0]  # map_array passes each per-shard argument list as one positional
    dump_feature(*args, **kwargs)


def dump_mfcc(*args, **kwargs):
    from dump_mfcc_feature import dump_feature
    args = args[0]
    dump_feature(*args, **kwargs)


def run_kmeans(*args, **kwargs):
    from learn_kmeans import learn_kmeans
    learn_kmeans(*args, **kwargs)


def apply_kmeans(*args, **kwargs):
    from dump_km_label import dump_label
    args = args[0]
    dump_label(*args, **kwargs)


def concatenate(*args, **kwargs):
    from concat import main as concat_fn
    args = args[0]
    concat_fn(*args, **kwargs)


def main():
    parser = argparse.ArgumentParser(description='clustering', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--tsv', type=str, help='tsv dir')
    parser.add_argument('--output', type=str, help='output dir (labels)')
    parser.add_argument('--ckpt', type=str, help='checkpoint of last iteration')
    parser.add_argument('--nlayer', type=int, default=12, help='layer index for clustering')
    parser.add_argument('--ncluster', type=int, default=500, help='number of clusters')
    parser.add_argument('--nshard', type=int, default=100, help='number of shards')
    parser.add_argument('--percent', type=float, default=0.05, help='percentage of data used for clustering')
    parser.add_argument('--mfcc', action='store_true', help='extract MFCC features')
    parser.add_argument('--slurm-partition', type=str, help='slurm partition')
    args = parser.parse_args()
    tsv_dir = args.tsv
    output_dir = args.output
    km_dir = output_dir
    feat_dir = output_dir
    ckpt_path = args.ckpt
    nlayer = args.nlayer
    nshard = args.nshard
    n_clusters = args.ncluster
    slurm_partition = args.slurm_partition
    is_mfcc = args.mfcc
    timeout_min = 240
    percent = args.percent  # fraction of utterances sampled for k-means training
    log_folder = "log_submit/%j"
    km_path = f"{km_dir}/kmeans.mdl"
    os.makedirs(output_dir, exist_ok=True)
    ext = submitit.AutoExecutor(folder=log_folder)

    # Stage 1: dump features, one SLURM array task per train shard plus one for valid.
    args_array = []
    if is_mfcc:
        print("Dump MFCC feature")
        for rank in range(nshard):
            args_array.append([tsv_dir, 'train', nshard, rank, output_dir])
        args_array.append([tsv_dir, 'valid', 1, 0, output_dir])
        ext.update_parameters(timeout_min=60, slurm_partition=slurm_partition, cpus_per_task=1, slurm_array_parallelism=100)
        jobs = ext.map_array(dump_mfcc, args_array)
    else:
        print("Dump AV-HuBERT feature")
        for rank in range(nshard):
            args_array.append([tsv_dir, 'train', ckpt_path, nlayer, nshard, rank, output_dir, 1600000])
        args_array.append([tsv_dir, 'valid', ckpt_path, nlayer, 1, 0, output_dir, 1600000])
        ext.update_parameters(timeout_min=60, slurm_partition=slurm_partition, cpus_per_task=1, gpus_per_node=1, slurm_array_parallelism=100)
        jobs = ext.map_array(dump_av_hubert, args_array)
    [job.result() for job in jobs]

    # Stage 2: fit a single k-means model on a sampled subset of the train features.
    print("Learn K-means")
    batch_size = 20000
    ext.update_parameters(timeout_min=timeout_min, slurm_partition=slurm_partition, cpus_per_task=8, mem_gb=128)
    km_args = [feat_dir, 'train', nshard, km_path, n_clusters]
    km_kwargs = vars(Namespace(seed=0, percent=percent, init="k-means++", max_iter=100, batch_size=batch_size, tol=0.0, n_init=20, reassignment_ratio=0.0, max_no_improvement=100))
    print(km_args, km_kwargs)
    job = ext.submit(run_kmeans, *km_args, **km_kwargs)
    job.result()

    # Stage 3: assign every frame of every shard to its nearest cluster.
    print("Apply K-means")
    args_array = []
    for rank in range(nshard):
        args_array.append([feat_dir, 'train', km_path, nshard, rank, output_dir])
    args_array.append([feat_dir, 'valid', km_path, 1, 0, output_dir])
    ext.update_parameters(timeout_min=10, slurm_partition=slurm_partition, cpus_per_task=1, slurm_array_parallelism=500)
    jobs = ext.map_array(apply_kmeans, args_array)
    [job.result() for job in jobs]

    # Stage 4: merge the per-shard label files and write a dummy-count dictionary.
    print("Concatenate labels")
    cont = f"for rank in $(seq 0 {nshard-1}); do cat {output_dir}/train_${{rank}}_{nshard}.km; done > {output_dir}/train.km"
    print(cont)
    subprocess.call(cont, shell=True)
    cont = f"cp {output_dir}/valid*.km {output_dir}/valid.km"
    print(cont)
    subprocess.call(cont, shell=True)
    with open(f"{output_dir}/dict.km.txt", 'w') as fo:
        for i in range(n_clusters):
            fo.write(f"{i} {10000}\n")
    print(f"Please delete intermediate files to save space: rm {output_dir}/*npy")


if __name__ == '__main__':
    main()
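Each stage above fans out over SLURM through submitit. A condensed sketch of that pattern, with a placeholder partition name and a toy worker standing in for dump_mfcc / dump_av_hubert / apply_kmeans:

    import submitit

    def worker(shard_args):  # receives one per-shard argument list, like args[0] above
        return sum(shard_args)

    ex = submitit.AutoExecutor(folder="log_submit/%j")
    ex.update_parameters(timeout_min=60, slurm_partition="dev",
                         slurm_array_parallelism=100)
    jobs = ex.map_array(worker, [[1, 2], [3, 4], [5, 6]])  # one array task per shard
    print([job.result() for job in jobs])                  # blocks until all finish

The pipeline itself is launched along the lines of python submit_cluster.py --tsv /path/to/tsv --output /path/to/out --ckpt /path/to/checkpoint.pt --nlayer 12 --slurm-partition <partition>, with the paths being placeholders.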
av_hubert/avhubert/conf/av-finetune/base_noise_pt_noise_ft_30h.yaml
ADDED
@@ -0,0 +1,121 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 8
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video","audio"]
  image_aug: true
  pad_audio: true
  random_crop: false
  noise_prob: 0.25
  noise_snr: 0
  noise_wav: ???

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 30000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 10000
  hold_steps: 0
  decay_steps: 20000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 6
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 24000
  share_decoder_input_output_embed: true
  decoder_normalize_before: true

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
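In these configs every ??? is OmegaConf's mandatory-missing marker: the value must be supplied as an override when the config is launched (typically via fairseq's Hydra entry point), covering task.data, task.label_dir, task.tokenizer_bpe_model, task.noise_wav, common.user_dir, model.w2v_path and the hydra run/sweep dirs. A small sketch of how the marker behaves, assuming omegaconf is installed and the working directory is avhubert/:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("conf/av-finetune/base_noise_pt_noise_ft_30h.yaml")
    print(OmegaConf.is_missing(cfg.task, "data"))  # True until overridden
    cfg.task.data = "/path/to/tsv_dir"             # placeholder override
    print(OmegaConf.is_missing(cfg.task, "data"))  # False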
av_hubert/avhubert/conf/av-finetune/base_noise_pt_noise_ft_433h.yaml
ADDED
@@ -0,0 +1,121 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 8
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video","audio"]
  image_aug: true
  pad_audio: true
  random_crop: false
  noise_prob: 0.25
  noise_snr: 0
  noise_wav: ???

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 60000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 20000
  hold_steps: 0
  decay_steps: 40000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 6
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 48000
  share_decoder_input_output_embed: true
  decoder_normalize_before: true

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
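Relative to the 30h variant above, this 433h config only stretches the schedule (max_update 30000 to 60000, warmup 10000 to 20000, decay 20000 to 40000, freeze_finetune_updates 24000 to 48000); task and model settings are unchanged. The tri_stage scheduler these configs rely on warms up linearly, optionally holds, then decays exponentially to final_lr_scale of the peak. A standalone sketch of that curve, assuming fairseq's default init_lr_scale of 0.01 (not set in the config) and not a drop-in for fairseq's implementation:

    import math

    def tri_stage_lr(step, peak=0.001, warmup=20000, hold=0, decay=40000,
                     init_scale=0.01, final_scale=0.05):
        if step < warmup:                              # linear warm-up
            init_lr = init_scale * peak
            return init_lr + (peak - init_lr) * step / warmup
        step -= warmup
        if step < hold:                                # constant plateau
            return peak
        step -= hold
        if step < decay:                               # exponential decay
            return peak * math.exp(math.log(final_scale) * step / decay)
        return final_scale * peak                      # floor for remaining updates

    print(tri_stage_lr(0), tri_stage_lr(20000), tri_stage_lr(60000))
    # approximately: 1e-05 0.001 5e-05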
av_hubert/avhubert/conf/av-finetune/large_noise_pt_noise_ft_30h.yaml
ADDED
@@ -0,0 +1,124 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 8
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video","audio"]
  image_aug: true
  pad_audio: true
  random_crop: false
  noise_prob: 0.25
  noise_snr: 0
  noise_wav: ???

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 18000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 6000
  hold_steps: 0
  decay_steps: 18000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 9
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 30000
  share_decoder_input_output_embed: true
  decoder_normalize_before: true
  decoder_embed_dim: 1024
  decoder_ffn_embed_dim: 4096
  decoder_attention_heads: 8

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/av-finetune/large_noise_pt_noise_ft_433h.yaml
ADDED
@@ -0,0 +1,124 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 8
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video","audio"]
  image_aug: true
  pad_audio: true
  random_crop: false
  noise_prob: 0.25
  noise_snr: 0
  noise_wav: ???

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 60000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 20000
  hold_steps: 0
  decay_steps: 40000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 9
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 48000
  share_decoder_input_output_embed: true
  decoder_normalize_before: true
  decoder_embed_dim: 1024
  decoder_ffn_embed_dim: 4096
  decoder_attention_heads: 8

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/finetune/base_lrs3_30h.yaml
ADDED
@@ -0,0 +1,118 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 8
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video"]
  image_aug: true
  pad_audio: true
  random_crop: false

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 30000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 10000
  hold_steps: 0
  decay_steps: 20000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 6
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 30000
  share_decoder_input_output_embed: true
  decoder_normalize_before: true

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/finetune/base_lrs3_433h.yaml
ADDED
@@ -0,0 +1,118 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 8
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video"]
  image_aug: true
  pad_audio: true
  random_crop: false

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 120000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 40000
  hold_steps: 0
  decay_steps: 80000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 6
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 60000
  share_decoder_input_output_embed: true
  decoder_normalize_before: true

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/finetune/base_vox_30h.yaml
ADDED
@@ -0,0 +1,118 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 8
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video"]
  image_aug: true
  pad_audio: true
  random_crop: false

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 30000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 10000
  hold_steps: 0
  decay_steps: 20000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 6
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 24000
  share_decoder_input_output_embed: true
  decoder_normalize_before: true

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/finetune/base_vox_433h.yaml
ADDED
@@ -0,0 +1,118 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 8
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video"]
  image_aug: true
  pad_audio: true
  random_crop: false

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 45000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 15000
  hold_steps: 0
  decay_steps: 30000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 6
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 22500
  share_decoder_input_output_embed: true
  decoder_normalize_before: true

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/finetune/large_lrs3_30h.yaml
ADDED
@@ -0,0 +1,121 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 8
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video"]
  image_aug: true
  pad_audio: true
  random_crop: false

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 18000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 6000
  hold_steps: 0
  decay_steps: 12000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 9
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 14400
  share_decoder_input_output_embed: true
  decoder_normalize_before: true
  decoder_embed_dim: 1024
  decoder_ffn_embed_dim: 4096
  decoder_attention_heads: 8

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/finetune/large_lrs3_433h.yaml
ADDED
@@ -0,0 +1,121 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 8
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video"]
  image_aug: true
  pad_audio: true
  random_crop: false

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 30000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 10000
  hold_steps: 0
  decay_steps: 20000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 9
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 18000
  share_decoder_input_output_embed: true
  decoder_normalize_before: true
  decoder_embed_dim: 1024
  decoder_ffn_embed_dim: 4096
  decoder_attention_heads: 8

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/finetune/large_vox_30h.yaml
ADDED
@@ -0,0 +1,121 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 8
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video"]
  image_aug: true
  pad_audio: true
  random_crop: false

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 30000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 10000
  hold_steps: 0
  decay_steps: 20000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 9
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 30000
  share_decoder_input_output_embed: true
  decoder_normalize_before: true
  decoder_embed_dim: 1024
  decoder_ffn_embed_dim: 4096
  decoder_attention_heads: 8

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/finetune/large_vox_433h.yaml
ADDED
@@ -0,0 +1,121 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 8
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video"]
  image_aug: true
  pad_audio: true
  random_crop: false

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 30000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 10000
  hold_steps: 0
  decay_steps: 20000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 9
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 30000
  share_decoder_input_output_embed: true
  decoder_normalize_before: true
  decoder_embed_dim: 1024
  decoder_ffn_embed_dim: 4096
  decoder_attention_heads: 8

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/finetune/self_large_vox_30h.yaml
ADDED
@@ -0,0 +1,121 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tblog
  seed: 1337
  user_dir: ???

checkpoint:
  save_interval: 2
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: accuracy
  maximize_best_checkpoint_metric: true

distributed_training:
  ddp_backend: c10d
  find_unused_parameters: true
  distributed_world_size: 32
  distributed_port: 29671
  nprocs_per_node: 8

task:
  _name: av_hubert_pretraining
  is_s2s: true
  data: ???
  label_dir: ???
  tokenizer_bpe_model: ???
  normalize: true  # must be consistent with pre-training
  labels: ["wrd"]
  single_target: true
  fine_tuning: true
  stack_order_audio: 4
  tokenizer_bpe_name: sentencepiece
  max_sample_size: 500
  modalities: ["video"]
  image_aug: true
  pad_audio: true
  random_crop: false

dataset:
  num_workers: 6
  max_tokens: 1000
  validate_after_updates: 0
  validate_interval: 2
  train_subset: train
  valid_subset: valid

criterion:
  _name: label_smoothed_cross_entropy
  report_accuracy: true
  label_smoothing: 0.1

optimization:
  max_update: 100000
  lr: [0.001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  warmup_steps: 10000
  hold_steps: 0
  decay_steps: 90000
  final_lr_scale: 0.05

model:
  _name: av_hubert_seq2seq
  w2v_path: ???
  apply_mask: false
  mask_selection: static
  mask_length: 10
  mask_other: 0
  mask_prob: 0.75
  mask_channel_selection: static
  mask_channel_length: 64
  mask_channel_other: 0
  mask_channel_prob: 0.5
  layerdrop: 0.1
  dropout: 0.0
  activation_dropout: 0.1
  attention_dropout: 0.0
  feature_grad_mult: 1.0
  decoder_layers: 9
  decoder_dropout: 0.1
  decoder_attention_dropout: 0.0
  decoder_activation_dropout: 0.1
  freeze_finetune_updates: 80000
  share_decoder_input_output_embed: true
  decoder_normalize_before: true
  decoder_embed_dim: 1024
  decoder_ffn_embed_dim: 4096
  decoder_attention_heads: 8

hydra:
  job:
    config:
      override_dirname:
        kv_sep: '-'
        item_sep: '__'
        exclude_keys:
          - run
          - task.data
          - task.label_dir
          - model.w2v_path
          - dataset.train_subset
          - dataset.valid_subset
          - criterion.wer_kenlm_model
          - criterion.wer_lexicon
  run:
    dir: ???
  sweep:
    dir: ???
    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/finetune/self_large_vox_433h.yaml
ADDED
@@ -0,0 +1,121 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  tensorboard_logdir: tblog
+  seed: 1337
+  user_dir: ???
+
+checkpoint:
+  save_interval: 2
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+  best_checkpoint_metric: accuracy
+  maximize_best_checkpoint_metric: true
+
+distributed_training:
+  ddp_backend: c10d
+  find_unused_parameters: true
+  distributed_world_size: 32
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  is_s2s: true
+  data: ???
+  label_dir: ???
+  tokenizer_bpe_model: ???
+  normalize: true  # must be consistent with pre-training
+  labels: ["wrd"]
+  single_target: true
+  fine_tuning: true
+  stack_order_audio: 4
+  tokenizer_bpe_name: sentencepiece
+  max_sample_size: 500
+  modalities: ["video"]
+  image_aug: true
+  pad_audio: true
+  random_crop: false
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  validate_after_updates: 0
+  validate_interval: 2
+  train_subset: train
+  valid_subset: valid
+
+criterion:
+  _name: label_smoothed_cross_entropy
+  report_accuracy: true
+  label_smoothing: 0.1
+
+optimization:
+  max_update: 100000
+  lr: [0.001]
+  sentence_avg: true
+  update_freq: [1]
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-08
+
+lr_scheduler:
+  _name: tri_stage
+  warmup_steps: 10000
+  hold_steps: 0
+  decay_steps: 90000
+  final_lr_scale: 0.05
+
+model:
+  _name: av_hubert_seq2seq
+  w2v_path: ???
+  apply_mask: false
+  mask_selection: static
+  mask_length: 10
+  mask_other: 0
+  mask_prob: 0.75
+  mask_channel_selection: static
+  mask_channel_length: 64
+  mask_channel_other: 0
+  mask_channel_prob: 0.5
+  layerdrop: 0.1
+  dropout: 0.0
+  activation_dropout: 0.1
+  attention_dropout: 0.0
+  feature_grad_mult: 1.0
+  decoder_layers: 9
+  decoder_dropout: 0.1
+  decoder_attention_dropout: 0.0
+  decoder_activation_dropout: 0.1
+  freeze_finetune_updates: 80000
+  share_decoder_input_output_embed: true
+  decoder_normalize_before: true
+  decoder_embed_dim: 1024
+  decoder_ffn_embed_dim: 4096
+  decoder_attention_heads: 8
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+          - model.w2v_path
+          - dataset.train_subset
+          - dataset.valid_subset
+          - criterion.wer_kenlm_model
+          - criterion.wer_lexicon
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
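A note on the `???` entries that recur throughout these configs (data, label_dir, w2v_path, hydra.run.dir, ...): OmegaConf/Hydra treat `???` as a mandatory-but-missing value, so these fields must be supplied as overrides at launch time. A minimal sketch of that behavior (not part of this commit; the path below is hypothetical):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"task": {"data": "???", "label_dir": "???"}})
assert OmegaConf.is_missing(cfg.task, "data")        # unset -> still missing
OmegaConf.update(cfg, "task.data", "/path/to/data")  # what an override does
print(cfg.task.data)                                 # /path/to/data

Reading a missing field before it is overridden raises an error, which is how these configs force the caller to point them at concrete data, label, and checkpoint locations.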
av_hubert/avhubert/conf/pretrain/base_lrs3_iter1.yaml
ADDED
@@ -0,0 +1,112 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 32
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["mfcc"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 500
+  min_sample_size: 5
+  pad_audio: true
+  random_crop: false
+  normalize: true
+  stack_order_audio: 4
+  # stack_order: 1
+  input_modality: image
+  image_aug: true
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 0.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 400000
+  lr: [0.0005]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 32000
+
+model:
+  _name: av_hubert
+  label_rate: 100
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: feature
+  mask_prob_image: 0.8
+  mask_length_image: 10
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
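The `label_rate: ${model.label_rate}` entry above is an OmegaConf interpolation: the task section reads its label rate from the model section, so the two cannot drift apart when one is overridden. A small sketch of how it resolves (not part of this commit):

from omegaconf import OmegaConf

cfg = OmegaConf.create("""
model:
  label_rate: 100
task:
  label_rate: ${model.label_rate}
""")
print(cfg.task.label_rate)  # 100, resolved at access time

In this iteration-1 config the rate is 100 targets per second (MFCC frames) against 25 fps video (sample_rate: 25), which is also why audio features are stacked four at a time (stack_order_audio: 4) to align the two streams.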
av_hubert/avhubert/conf/pretrain/base_lrs3_iter2.yaml
ADDED
@@ -0,0 +1,112 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 32
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["mfcc"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 500
+  min_sample_size: 5
+  pad_audio: true
+  random_crop: false
+  normalize: true
+  stack_order_audio: 4
+  # stack_order: 1
+  input_modality: image
+  image_aug: true
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 0.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 400000
+  lr: [0.0005]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 32000
+
+model:
+  _name: av_hubert
+  label_rate: 25
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: feature
+  mask_prob_image: 0.8
+  mask_length_image: 10
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/pretrain/base_lrs3_iter3.yaml
ADDED
@@ -0,0 +1,112 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 32
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["mfcc"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 500
+  min_sample_size: 5
+  pad_audio: true
+  random_crop: false
+  normalize: true
+  stack_order_audio: 4
+  # stack_order: 1
+  input_modality: image
+  image_aug: true
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 0.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 400000
+  lr: [0.0005]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 32000
+
+model:
+  _name: av_hubert
+  label_rate: 25
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: feature
+  mask_prob_image: 0.8
+  mask_length_image: 10
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/pretrain/base_lrs3_iter4.yaml
ADDED
@@ -0,0 +1,112 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 32
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["mfcc"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 500
+  min_sample_size: 5
+  pad_audio: true
+  random_crop: false
+  normalize: true
+  stack_order_audio: 4
+  # stack_order: 1
+  input_modality: image
+  image_aug: true
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 0.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 400000
+  lr: [0.0005]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 32000
+
+model:
+  _name: av_hubert
+  label_rate: 25
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: feature
+  mask_prob_image: 0.8
+  mask_length_image: 10
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/pretrain/base_lrs3_iter5.yaml
ADDED
@@ -0,0 +1,112 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 32
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["km"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 500
+  min_sample_size: 5
+  pad_audio: true
+  random_crop: false
+  normalize: true
+  stack_order_audio: 4
+  # stack_order: 1
+  input_modality: image
+  image_aug: true
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 0.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 400000
+  lr: [0.0005]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 32000
+
+model:
+  _name: av_hubert
+  label_rate: ???
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0.5
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: input
+  mask_prob_image: 0.3
+  mask_length_image: 5
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
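This iteration-5 config switches labels from "mfcc" to "km": the targets are k-means cluster IDs computed over features from an earlier iteration's model, which is also why model.label_rate becomes `???` (it depends on which layer's features were clustered). A minimal sketch of that labeling step, assuming scikit-learn is used for the clustering; the feature array here is a stand-in, not the repo's actual extraction pipeline:

import numpy as np
from sklearn.cluster import MiniBatchKMeans

feats = np.random.randn(10000, 256).astype(np.float32)  # stand-in frame features
km = MiniBatchKMeans(n_clusters=500, batch_size=10000, n_init=20)
km.fit(feats)
labels = km.predict(feats)  # one cluster ID per frame -> the "km" targets
print(labels[:10])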
av_hubert/avhubert/conf/pretrain/base_vox_iter1.yaml
ADDED
@@ -0,0 +1,113 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 32
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["mfcc"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 2000
+  min_sample_size: 5
+  pad_audio: false
+  random_crop: true
+  normalize: true
+  stack_order_audio: 4
+  # stack_order: 1
+  input_modality: image
+  image_aug: true
+  max_trim_sample_size: 400
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 0.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 800000
+  lr: [0.002]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 64000
+
+model:
+  _name: av_hubert
+  label_rate: 100
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: feature
+  mask_prob_image: 0.8
+  mask_length_image: 10
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/pretrain/base_vox_iter2.yaml
ADDED
@@ -0,0 +1,113 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 32
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["km"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 2000
+  min_sample_size: 5
+  pad_audio: false
+  random_crop: true
+  normalize: true
+  stack_order_audio: 4
+  # stack_order: 1
+  input_modality: image
+  image_aug: true
+  max_trim_sample_size: 400
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 0.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 800000
+  lr: [0.002]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 64000
+
+model:
+  _name: av_hubert
+  label_rate: 25
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0.5
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: feature
+  mask_prob_image: 0.8
+  mask_length_image: 10
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/pretrain/base_vox_iter3.yaml
ADDED
@@ -0,0 +1,113 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 32
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["km"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 2000
+  min_sample_size: 5
+  pad_audio: false
+  random_crop: true
+  normalize: true
+  stack_order_audio: 4
+  # stack_order: 1
+  input_modality: image
+  image_aug: true
+  max_trim_sample_size: 400
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 0.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 800000
+  lr: [0.002]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 64000
+
+model:
+  _name: av_hubert
+  label_rate: 25
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0.5
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: feature
+  mask_prob_image: 0.8
+  mask_length_image: 10
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/pretrain/base_vox_iter4.yaml
ADDED
@@ -0,0 +1,112 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 32
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["km"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 2000
+  min_sample_size: 5
+  pad_audio: false
+  random_crop: true
+  normalize: true
+  stack_order_audio: 4
+  # stack_order: 1
+  input_modality: image
+  image_aug: true
+  max_trim_sample_size: 400
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 0.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 800000
+  lr: [0.002]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 64000
+
+model:
+  _name: av_hubert
+  label_rate: 25
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0.5
+  audio_dropout: 0.5
+  modality_fuse: concat
+  masking_type: feature
+  mask_prob_image: 0.8
+  mask_length_image: 10
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/pretrain/base_vox_iter5.yaml
ADDED
@@ -0,0 +1,113 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 32
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["km"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 2000
+  min_sample_size: 5
+  pad_audio: false
+  random_crop: true
+  normalize: true
+  stack_order_audio: 4
+  # stack_order: 1
+  input_modality: image
+  image_aug: true
+  max_trim_sample_size: 400
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 0.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 800000
+  lr: [0.002]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 64000
+
+model:
+  _name: av_hubert
+  label_rate: ???
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0.5
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: input
+  mask_prob_image: 0.3
+  mask_length_image: 5
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/pretrain/large_lrs3_iter5.yaml
ADDED
@@ -0,0 +1,117 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 64
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["km"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 2000
+  min_sample_size: 5
+  pad_audio: false
+  random_crop: true
+  normalize: true
+  stack_order_audio: 4
+  # stack_order: 1
+  input_modality: image
+  image_aug: true
+  max_trim_sample_size: 400
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 1.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 400000
+  lr: [0.002]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 32000
+
+model:
+  _name: av_hubert
+  label_rate: 25
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0.5
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: input
+  mask_prob_image: 0.3
+  mask_length_image: 5
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+  encoder_layers: 24
+  encoder_embed_dim: 1024
+  encoder_ffn_embed_dim: 4096
+  encoder_attention_heads: 16
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
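A worked sketch (not part of this commit) of where `audio_feat_dim: 104` comes from, assuming 26-dimensional log filterbank features at 100 frames/s: stacking stack_order_audio = 4 consecutive frames gives 26 * 4 = 104 dimensions at 25 frames/s, aligned with the 25 fps video stream:

import numpy as np

def stack_frames(feats, stack_order=4):
    # (T, D) at 100 Hz -> (T // stack_order, D * stack_order) at 25 Hz
    t, d = feats.shape
    t = t - t % stack_order  # drop any ragged tail
    return feats[:t].reshape(t // stack_order, d * stack_order)

audio = np.zeros((400, 26))
print(stack_frames(audio).shape)  # (100, 104)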
av_hubert/avhubert/conf/pretrain/large_vox_iter5.yaml
ADDED
@@ -0,0 +1,117 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 64
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["km"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 2000
+  min_sample_size: 5
+  pad_audio: false
+  random_crop: true
+  normalize: true
+  stack_order_audio: 4
+  # stack_order: 1
+  input_modality: image
+  image_aug: true
+  max_trim_sample_size: 400
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 1.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 600000
+  lr: [0.002]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 48000
+
+model:
+  _name: av_hubert
+  label_rate: ???
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0.5
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: input
+  mask_prob_image: 0.3
+  mask_length_image: 5
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+  encoder_layers: 24
+  encoder_embed_dim: 1024
+  encoder_ffn_embed_dim: 4096
+  encoder_attention_heads: 16
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/pretrain/noise_base_vox_iter5.yaml
ADDED
@@ -0,0 +1,115 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 32
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["km"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 2000
+  min_sample_size: 5
+  pad_audio: false
+  random_crop: true
+  normalize: true
+  stack_order_audio: 4
+  input_modality: image
+  image_aug: true
+  max_trim_sample_size: 400
+  noise_prob: 0.25
+  noise_snr: 0
+  noise_wav: ???
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 0.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 800000
+  lr: [0.002]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 64000
+
+model:
+  _name: av_hubert
+  label_rate: ???
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0.5
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: input
+  mask_prob_image: 0.3
+  mask_length_image: 5
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/pretrain/noise_large_vox_iter5.yaml
ADDED
@@ -0,0 +1,119 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  seed: 1337
+  user_dir: ???
+  empty_cache_freq: 10000
+
+checkpoint:
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+
+distributed_training:
+  ddp_backend: no_c10d
+  distributed_backend: 'nccl'
+  distributed_world_size: 64
+  distributed_port: 29671
+  nprocs_per_node: 8
+
+task:
+  _name: av_hubert_pretraining
+  data: ???
+  label_dir: ???
+  labels: ["km"]
+  label_rate: ${model.label_rate}
+  sample_rate: 25
+  max_sample_size: 2000
+  min_sample_size: 5
+  pad_audio: false
+  random_crop: true
+  normalize: true
+  stack_order_audio: 4
+  input_modality: image
+  image_aug: true
+  max_trim_sample_size: 400
+  noise_prob: 0.25
+  noise_snr: 0
+  noise_wav: ???
+
+dataset:
+  num_workers: 6
+  max_tokens: 1000
+  skip_invalid_size_inputs_valid_test: true
+  validate_interval: 5
+  validate_interval_updates: 10000
+
+criterion:
+  _name: av_hubert
+  pred_masked_weight: 1.0
+  pred_nomask_weight: 1.0
+  loss_weights: [10,]
+
+optimization:
+  max_update: 600000
+  lr: [0.002]
+  clip_norm: 10.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.98)
+  adam_eps: 1e-06
+  weight_decay: 0.01
+
+lr_scheduler:
+  _name: polynomial_decay
+  warmup_updates: 48000
+
+model:
+  _name: av_hubert
+  label_rate: ???
+  skip_masked: false
+  skip_nomask: false
+  modality_dropout: 0.5
+  audio_dropout: 0.5
+  modality_fuse: concat
+  selection_type: same_seq
+  masking_type: input
+  mask_prob_image: 0.3
+  mask_length_image: 5
+  mask_prob_audio: 0.8
+  mask_length_audio: 10
+  extractor_mode: default
+  # conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
+  final_dim: 256
+  encoder_layerdrop: 0.05
+  dropout_input: 0.1
+  dropout_features: 0.1
+  dropout: 0.1
+  attention_dropout: 0.1
+  feature_grad_mult: 0.1
+  untie_final_proj: true
+  activation_dropout: 0.0
+  wav_input: false
+  layer_norm_first: true
+  audio_feat_dim: 104
+  encoder_layers: 24
+  encoder_embed_dim: 1024
+  encoder_ffn_embed_dim: 4096
+  encoder_attention_heads: 16
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: '-'
+        item_sep: '__'
+        exclude_keys:
+          - run
+          - task.data
+          - task.label_dir
+  run:
+    dir: ???
+  sweep:
+    dir: ???
+    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
av_hubert/avhubert/conf/s2s_decode.yaml
ADDED
@@ -0,0 +1,23 @@
+common:
+  user_dir: ???
+
+generation:
+  beam: 50
+  max_len_a: 1.0
+  max_len_b: 0
+  lenpen: 1.0
+  lm_weight: 0
+
+common_eval:
+  results_path: ???
+  path: ???
+
+dataset:
+  max_tokens: 1000
+  gen_subset: valid
+  num_workers: 0
+
+override:
+  noise_prob: 0.0
+  noise_snr: 0
+  modalities: ???
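A small sketch (not part of this commit) of fairseq's generation length bound as configured above: max_len = max_len_a * src_len + max_len_b, so with max_len_a: 1.0 and max_len_b: 0 the decoder may emit at most one token per source frame, searched with a beam of 50 and no length penalty (lenpen: 1.0) or language-model fusion (lm_weight: 0):

def max_target_len(src_len, max_len_a=1.0, max_len_b=0):
    # fairseq's documented formula for the per-utterance generation cap
    return int(max_len_a * src_len + max_len_b)

print(max_target_len(250))  # 250 tokens for a 250-frame utterance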
av_hubert/avhubert/decoder.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace
import contextlib
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
from omegaconf import MISSING, II, open_dict
from typing import Any, Optional

from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks import FairseqTask
from fairseq.models import (
    BaseFairseqModel,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
)
# from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
from fairseq.modules import (
    LayerNorm,
    PositionalEmbedding,
    TransformerDecoderLayer,
)


class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *cfg.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        cfg: model configuration carrying the decoder_* options
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)

        self.dropout = cfg.decoder_dropout
        self.share_input_output_embed = cfg.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = cfg.decoder_embed_dim
        self.output_embed_dim = cfg.decoder_embed_dim

        self.layerdrop = cfg.decoder_layerdrop

        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = cfg.max_target_positions

        self.embed_tokens = embed_tokens
        # self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim
        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_target_positions,
                embed_dim,
                padding_idx,
                learned=cfg.decoder_learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )

        # TODO: update this when transformer gets converted to dataclass configs
        transformer_cfg = copy.deepcopy(cfg)
        # with open_dict(transformer_cfg):
        transformer_cfg.dropout = transformer_cfg.decoder_dropout
        transformer_cfg.attention_dropout = (
            transformer_cfg.decoder_attention_dropout
        )
        transformer_cfg.activation_dropout = (
            transformer_cfg.decoder_activation_dropout
        )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerDecoderLayer(transformer_cfg, no_encoder_attn)
                for _ in range(transformer_cfg.decoder_layers)
            ]
        )

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)

        if transformer_cfg.decoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        prev_output_tokens = prev_output_tokens.long()
        x, extra = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state
        )
        x = self.output_layer(x)
        return x, extra

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Similar to *forward* but only returns features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """

        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        inner_states = [x]

        # decoder layers
        for layer in self.layers:
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, attn, _ = layer(
                    x,
                    encoder_out["encoder_out"] if encoder_out is not None else None,
                    encoder_out["padding_mask"] if encoder_out is not None else None,
                    incremental_state,
                    self_attn_mask=self.buffered_future_mask(x)
                    if incremental_state is None
                    else None,
                )
                inner_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {"attn": attn, "inner_states": inner_states}

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        # project back to size of vocabulary
        emb_mat = self.embed_tokens.weight if self.share_input_output_embed else self.embed_out
        return torch.matmul(features, emb_mat.transpose(0, 1))
        # if self.share_input_output_embed:
        #     return F.linear(features, self.embed_tokens.weight)
        # else:
        #     return F.linear(features, self.embed_out)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        # cache an upper-triangular -inf mask so causal self-attention can only
        # look at current and earlier positions; rebuilt when device/size changes
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
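One caveat about decoder.py above: `project_in_dim` calls a `Linear` helper that is never imported or defined in the hunk, so that branch would raise a `NameError` whenever `decoder_embed_dim` differs from the token-embedding width (the shipped configs keep the two equal, so the branch is normally skipped). A minimal sketch of the missing helper, assuming the Xavier-uniform initialization that fairseq model files conventionally use for it:

    import torch.nn as nn

    def Linear(in_features, out_features, bias=True):
        # fairseq-convention linear layer: Xavier-uniform weights, zero bias
        m = nn.Linear(in_features, out_features, bias)
        nn.init.xavier_uniform_(m.weight)
        if bias:
            nn.init.constant_(m.bias, 0.0)
        return m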
av_hubert/avhubert/hubert.py
ADDED
@@ -0,0 +1,779 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os, sys
import logging
from typing import Dict, List, Optional, Tuple

import numpy as np

import torch
import torch.nn as nn
from dataclasses import dataclass, field
from fairseq import utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec.wav2vec2 import (
    ConvFeatureExtractionModel,
    TransformerEncoder,
)
from fairseq.modules import GradMultiply, LayerNorm
from copy import deepcopy

DBG = len(sys.argv) == 1  # standalone run (no CLI args): use local imports

if DBG:
    from hubert_pretraining import (
        AVHubertPretrainingConfig,
        AVHubertPretrainingTask,
    )
    from resnet import ResEncoder
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=sys.stdout,
    )
    # overrides the fairseq version imported above; this variant also
    # returns the span starts/ends and batch indexes
    from utils import compute_mask_indices
    from decoder import TransformerDecoder

else:
    from .hubert_pretraining import (
        AVHubertPretrainingConfig,
        AVHubertPretrainingTask,
    )
    from .resnet import ResEncoder
    from .utils import compute_mask_indices
    from .decoder import TransformerDecoder

from omegaconf import II

logger = logging.getLogger(__name__)

EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(
    ["static", "uniform", "normal", "poisson"]
)


@dataclass
class AVHubertConfig(FairseqDataclass):
    label_rate: int = II("task.label_rate")
    input_modality: str = II("task.input_modality")
    extractor_mode: EXTRACTOR_MODE_CHOICES = field(
        default="default",
        metadata={
            "help": "mode for feature extractor. default has a single group "
            "norm with d groups in the first conv block, whereas layer_norm "
            "has layer norms in every block (meant to use with normalize=True)"
        },
    )
    encoder_layers: int = field(
        default=12, metadata={"help": "num encoder layers in the transformer"}
    )
    encoder_embed_dim: int = field(
        default=768, metadata={"help": "encoder embedding dimension"}
    )
    encoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "encoder embedding dimension for FFN"}
    )
    encoder_attention_heads: int = field(
        default=12, metadata={"help": "num encoder attention heads"}
    )
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="gelu", metadata={"help": "activation function to use"}
    )

    # dropouts
    dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for the transformer"},
    )
    attention_dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for attention weights"},
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability after activation in FFN"},
    )
    encoder_layerdrop: float = field(
        default=0.0,
        metadata={"help": "probability of dropping a transformer layer"},
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    dropout_features: float = field(
        default=0.0,
        metadata={
            "help": "dropout to apply to the features (after feat extr)"
        },
    )

    final_dim: int = field(
        default=0,
        metadata={
            "help": "project final representations and targets to this many "
            "dimensions. set to encoder_embed_dim if <= 0"
        },
    )
    untie_final_proj: bool = field(
        default=False,
        metadata={"help": "use separate projection for each target"},
    )
    layer_norm_first: bool = field(
        default=False,
        metadata={"help": "apply layernorm first in the transformer"},
    )
    conv_feature_layers: str = field(
        default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
        metadata={
            "help": "string describing convolutional feature extraction "
            "layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )
    conv_bias: bool = field(
        default=False, metadata={"help": "include bias in conv encoder"}
    )
    logit_temp: float = field(
        default=0.1, metadata={"help": "temperature to divide logits by"}
    )
    target_glu: bool = field(
        default=False, metadata={"help": "adds projection + glu to targets"}
    )
    feature_grad_mult: float = field(
        default=1.0,
        metadata={"help": "multiply feature extractor var grads by this"},
    )

    # masking
    mask_length_audio: int = field(default=10, metadata={"help": "mask length"})
    mask_prob_audio: float = field(
        default=0.65,
        metadata={"help": "probability of replacing a token with mask"},
    )
    mask_length_image: int = field(default=10, metadata={"help": "mask length"})
    mask_prob_image: float = field(
        default=0.65,
        metadata={"help": "probability of replacing a token with mask"},
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose mask length"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )
    mask_min_space: int = field(
        default=1,
        metadata={
            "help": "min space between spans (if no overlap is enabled)"
        },
    )

    # channel masking
    mask_channel_length: int = field(
        default=10,
        metadata={"help": "length of the mask for features (channels)"},
    )
    mask_channel_prob: float = field(
        default=0.0,
        metadata={"help": "probability of replacing a feature with 0"},
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to allow channel masks to overlap"},
    )
    mask_channel_min_space: int = field(
        default=1,
        metadata={
            "help": "min space between spans (if no overlap is enabled)"
        },
    )

    # positional embeddings
    conv_pos: int = field(
        default=128,
        metadata={
            "help": "number of filters for convolutional positional embeddings"
        },
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={
            "help": "number of groups for convolutional positional embedding"
        },
    )

    latent_temp: Tuple[float, float, float] = field(
        default=(2, 0.5, 0.999995),
        metadata={"help": "legacy (to be removed)"},
    )

    # loss computation
    skip_masked: bool = field(
        default=False,
        metadata={"help": "skip computing losses over masked frames"},
    )
    skip_nomask: bool = field(
        default=False,
        metadata={"help": "skip computing losses over unmasked frames"},
    )
    resnet_relu_type: str = field(default='prelu', metadata={"help": 'relu type for resnet'})
    resnet_weights: Optional[str] = field(default=None, metadata={"help": 'resnet weights'})
    sim_type: str = field(default='cosine', metadata={"help": 'similarity type'})

    sub_encoder_layers: int = field(default=0, metadata={'help': 'number of transformer layers for single modality'})
    audio_feat_dim: int = field(default=-1, metadata={'help': 'audio feature dimension'})
    modality_dropout: float = field(default=0, metadata={'help': 'drop one modality'})
    audio_dropout: float = field(default=0, metadata={'help': 'drop audio feature'})
    modality_fuse: str = field(default='concat', metadata={'help': 'fusing two modalities: add,concat'})
    selection_type: str = field(default='same_other_seq', metadata={'help': 'type of selecting images, same_other_seq: replace masked span with span from another sequence, same_seq: replace masked span with span of the same sequence'})
    masking_type: str = field(default='input', metadata={'help': 'input or feature masking'})

    decoder_embed_dim: int = field(
        default=768, metadata={"help": "decoder embedding dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(
        default=6, metadata={"help": "num of decoder layers"}
    )
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "decoder layerdrop chance"}
    )
    decoder_attention_heads: int = field(
        default=4, metadata={"help": "num decoder attention heads"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    decoder_normalize_before: bool = field(
        default=False,
        metadata={"help": "apply layernorm before each decoder block"},
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings "
            "(outside self attention)"
        },
    )
    decoder_dropout: float = field(
        default=0.1, metadata={"help": "dropout probability in the decoder"}
    )
    decoder_attention_dropout: float = field(
        default=0.1,
        metadata={
            "help": "dropout probability for attention weights "
            "inside the decoder"
        },
    )
    decoder_activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN "
            "inside the decoder"
        },
    )
    max_target_positions: int = field(
        default=2048, metadata={"help": "max target positions"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False,
        metadata={"help": "share decoder input and output embeddings"},
    )
    no_scale_embedding: bool = field(default=True, metadata={'help': 'do not scale embedding'})

class SubModel(nn.Module):
    def __init__(self, resnet=None, input_dim=None, cfg=None):
        super().__init__()
        self.resnet = resnet
        self.proj = nn.Linear(input_dim, cfg.encoder_embed_dim)
        self.encoder = TransformerEncoder(cfg) if cfg.encoder_layers > 0 else None

    def forward(self, x):
        if self.resnet is not None:
            x = self.resnet(x)
        x = self.proj(x.transpose(1, 2))
        if self.encoder is not None:
            x = self.encoder(x)[0].transpose(1, 2)
        else:
            x = x.transpose(1, 2)
        return x

@register_model("av_hubert", dataclass=AVHubertConfig)
class AVHubertModel(BaseFairseqModel):
    def __init__(
        self,
        cfg: AVHubertConfig,
        task_cfg: AVHubertPretrainingConfig,
        dictionaries: List[Dictionary],
        **kwargs
    ) -> None:
        super().__init__()
        logger.info(f"HubertModel Config: {cfg}")

        feature_ds_rate = 1
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
        sub_cfg = deepcopy(cfg)
        sub_cfg.encoder_layers = sub_cfg.sub_encoder_layers
        resnet = ResEncoder(relu_type=cfg.resnet_relu_type, weights=cfg.resnet_weights)
        self.feature_extractor_audio = SubModel(resnet=None, input_dim=cfg.audio_feat_dim, cfg=sub_cfg)
        self.feature_extractor_video = SubModel(resnet=resnet, input_dim=resnet.backend_out, cfg=sub_cfg)
        self.modality_dropout, self.audio_dropout = cfg.modality_dropout, cfg.audio_dropout
        self.modality_fuse = cfg.modality_fuse
        self.encoder_embed_dim = cfg.encoder_embed_dim
        if self.modality_fuse == 'concat':
            self.embed = cfg.encoder_embed_dim * 2
        elif self.modality_fuse == 'add':
            self.embed = cfg.encoder_embed_dim
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.mask_prob_image, self.mask_prob_audio = cfg.mask_prob_image, cfg.mask_prob_audio
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length_image, self.mask_length_audio = cfg.mask_length_image, cfg.mask_length_audio
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask
        self.sim_type = cfg.sim_type
        self.selection_type = cfg.selection_type
        self.masking_type = cfg.masking_type

        final_dim = (
            cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
        )

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.audio_feat_dim).uniform_() if self.masking_type == 'input' else torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.untie_final_proj = cfg.untie_final_proj
        if self.untie_final_proj:
            self.final_proj = nn.Linear(
                cfg.encoder_embed_dim, final_dim * len(dictionaries)
            )
        else:
            self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

        # modules below are not needed during fine-tuning
        if any([d is None for d in dictionaries]):
            logger.info(
                "cannot find dictionary. assume will be used for fine-tuning"
            )
        else:
            self.num_classes = [len(d) for d in dictionaries]
            self.label_embs_concat = nn.Parameter(
                torch.FloatTensor(sum(self.num_classes), final_dim)
            )
            nn.init.uniform_(self.label_embs_concat)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""

        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: AVHubertConfig, task: AVHubertPretrainingTask):
        """Build a new model instance."""

        kwargs = {}
        model = AVHubertModel(cfg, task.cfg, task.dictionaries, **kwargs)
        return model

    def apply_input_mask(self, x, padding_mask, target_list):
        B, C, T = x.shape[:3]
        is_audio = len(x.shape) == 3
        if is_audio:
            mask_prob, mask_length = self.mask_prob_audio, self.mask_length_audio
        else:
            mask_prob, mask_length = self.mask_prob_image, self.mask_length_image
        if mask_prob > 0:
            mask_indices, starts, ends, batch_indexes = compute_mask_indices(
                (B, T),
                padding_mask,
                mask_prob,
                mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices_np = mask_indices
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = x.transpose(1, 2).contiguous()  # [B, T, C, H, W]
            if B == 1:
                x[mask_indices] = 0
            elif is_audio:
                x[mask_indices] = self.mask_emb
            elif self.selection_type == 'same_other_seq':
                perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B
                x_perm = x[perm]
                x[mask_indices] = x_perm[mask_indices]
            elif self.selection_type == 'same_seq':
                batch_indexes_, other_indexes = [], []
                for batch_index, start, end in zip(batch_indexes, starts, ends):
                    length = end - start
                    other_start = np.setdiff1d(np.arange(T), np.arange(max(0, start - length), end))
                    if len(other_start) > 0:
                        other_start = np.random.choice(other_start, size=1)
                    else:
                        other_start = 0
                    other_end = other_start + length
                    other_indexes.append(np.arange(other_start, other_end).clip(max=T - 1))
                    batch_indexes_.append(np.zeros([length], dtype=np.int64) + batch_index)
                batch_indexes, other_indexes = np.concatenate(batch_indexes_), np.concatenate(other_indexes)
                x[mask_indices] = x[batch_indexes, other_indexes]

            x = x.transpose(1, 2).contiguous()
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            logger.info("No mask channel prob for input masking")
        return x, mask_indices

    def apply_feature_mask(self, x, padding_mask, target_list):
        B, T, C = x.shape
        assert self.mask_prob_audio == self.mask_prob_image and self.mask_length_audio == self.mask_length_image, "masking prob/length must be the same for image/audio for feature masking"
        mask_prob, mask_length = self.mask_prob_audio, self.mask_length_image
        if mask_prob > 0:
            mask_indices, _, _, _ = compute_mask_indices(
                (B, T),
                padding_mask,
                mask_prob,
                mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices, _, _, _ = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def forward_features(self, source: torch.Tensor, modality: str) -> torch.Tensor:
        extractor = getattr(self, f"feature_extractor_{modality}")
        if self.feature_grad_mult > 0:
            features = extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = extractor(source)
        return features

    def forward_targets(
        self, features: torch.Tensor, mask_indices: torch.Tensor, target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
            if mask_indices is not None:
                mask_indices = mask_indices[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, mask_indices, target_list

    def forward_padding_mask(
        self, features: torch.Tensor, padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def compute_logits(self, feats, emb_mat):
        # feats: [B, T, F], emb_mat: [V, F]
        if self.sim_type == 'dot':
            logits = torch.matmul(feats, emb_mat.transpose(0, 1))
        elif self.sim_type == 'cosine':
            batch_size, timesteps, emb_dim = feats.size()
            feats_ = feats.view(-1, emb_dim)
            nom = (feats_.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1)  # [B*T, V]
            denom = (feats_**2).sum(dim=-1).sqrt().unsqueeze(dim=1) * (emb_mat**2).sum(dim=-1).sqrt().unsqueeze(dim=0)  # [B*T, V]
            logits = (nom / denom.clamp(min=1e-6)).view(batch_size, timesteps, -1)
        else:
            raise NotImplementedError
        logits = logits / self.logit_temp
        return logits

    def forward(
        self,
        source: torch.Tensor,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None
    ) -> Dict[str, torch.Tensor]:
        """output layer is 1-based"""
        src_audio, src_video = source['audio'], source['video']
        if mask and self.masking_type == 'input':
            src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list)
            src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list)
            mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video)
        else:
            src_audio, src_video, mask_indices = src_audio, src_video, None

        features_audio = self.forward_features(src_audio, modality='audio')  # features: [B, F, T]
        features_video = self.forward_features(src_video, modality='video')
        modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random()
        if self.training:
            if modality_drop_prob < self.modality_dropout:
                if audio_drop_prob < self.audio_dropout:
                    features_audio = 0 * features_audio
                else:
                    features_video = 0 * features_video
        if self.modality_fuse == 'concat':
            features = torch.cat([features_audio, features_video], dim=1)
        elif self.modality_fuse == 'add':
            features = features_audio + features_video
        if target_list is not None:
            features, mask_indices, target_list = self.forward_targets(features, mask_indices, target_list)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        if self.masking_type == 'feature' and mask:
            x, mask_indices = self.apply_feature_mask(features, padding_mask, target_list)
        else:
            x = features

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1
        )

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
        proj_x = self.final_proj(x)
        if self.untie_final_proj:
            proj_x_list = proj_x.chunk(len(self.num_classes), dim=-1)
        else:
            proj_x_list = [proj_x for _ in self.num_classes]
        logit_list = [self.compute_logits(proj, emb).view(-1, num_class) for proj, emb, num_class in zip(proj_x_list, label_embs_list, self.num_classes)]  # [[B*T, V]]
        mask, unmask = torch.logical_and(mask_indices, ~padding_mask).view(-1), torch.logical_and(~mask_indices, ~padding_mask).view(-1)  # [B*T]
        logit_m_list, logit_u_list = [logit[mask] for logit in logit_list], [logit[unmask] for logit in logit_list]
        target_m_list, target_u_list = [target.view(-1)[mask].long() for target in target_list], [target.view(-1)[unmask].long() for target in target_list]
        result = {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "target_m_list": target_m_list,
            "target_u_list": target_u_list,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }
        return result

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def extract_finetune(self, source, padding_mask=None, mask=False, ret_conv=False, output_layer=None):
        src_audio, src_video = source['audio'], source['video']
        if mask and self.masking_type == 'input':
            src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list=None)
            src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list=None)
            mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video)  # mask_indices not used in fine-tuning
        else:
            src_audio, src_video, mask_indices = src_audio, src_video, None

        if src_audio is not None and src_video is None:
            features_audio = self.forward_features(src_audio, modality='audio')  # features: [B, F, T]
            features_video = features_audio.new_zeros(features_audio.size(0), self.encoder_embed_dim, features_audio.size(-1))
        elif src_audio is None and src_video is not None:
            features_video = self.forward_features(src_video, modality='video')
            features_audio = features_video.new_zeros(features_video.size(0), self.encoder_embed_dim, features_video.size(-1))
        elif src_audio is not None and src_video is not None:
            features_video = self.forward_features(src_video, modality='video')
            features_audio = self.forward_features(src_audio, modality='audio')  # features: [B, F, T]

        if self.modality_fuse == 'concat':
            features = torch.cat([features_audio, features_video], dim=1)
        elif self.modality_fuse == 'add':
            features = features_audio + features_video
        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)
        x = features
        mask_indices = None

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1
        )

        return x, padding_mask

    def get_extra_losses(self, net_output):
        extra_losses = []
        names = []
        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        return extra_losses, names

    def remove_pretraining_modules(self):
        self.target_glu = None
        self.final_proj = None

    def get_logits(self, net_output, is_masked=True):
        raise NotImplementedError

    def get_targets(self, net_output, is_masked=True):
        raise NotImplementedError

    def compute_nce(self, x, pos, negs):
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(
            x.float(), targets.float(), dim=-1
        ).type_as(x)
        logits /= self.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits
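A note on `compute_logits` in hubert.py above: the `cosine` branch computes the cosine similarity between frame features and the cluster-codeword embeddings by hand before dividing by `logit_temp`. A small self-contained sanity check (not part of the diff; the sizes `B, T, V, D` and the random tensors below are hypothetical, chosen only for illustration) showing that the manual nom/denom computation matches L2-normalizing both sides and taking a plain matmul:

    import torch
    import torch.nn.functional as F

    B, T, V, D = 2, 5, 7, 16          # hypothetical batch/time/vocab/feature sizes
    feats = torch.randn(B, T, D)      # stand-in for the final_proj output
    emb_mat = torch.randn(V, D)       # stand-in for one split of label_embs_concat

    # manual route, mirroring the sim_type == 'cosine' branch above
    feats_ = feats.view(-1, D)
    nom = (feats_.unsqueeze(1) * emb_mat.unsqueeze(0)).sum(dim=-1)            # [B*T, V]
    denom = (feats_**2).sum(-1).sqrt().unsqueeze(1) * (emb_mat**2).sum(-1).sqrt().unsqueeze(0)
    manual = (nom / denom.clamp(min=1e-6)).view(B, T, V)

    # equivalent route: normalize rows, then an ordinary matrix product
    via_norm = torch.matmul(F.normalize(feats, dim=-1),
                            F.normalize(emb_mat, dim=-1).transpose(0, 1))

    print(torch.allclose(manual, via_norm, atol=1e-5))                        # True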
av_hubert/avhubert/hubert_asr.py
ADDED
@@ -0,0 +1,521 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sys, logging
import contextlib
import tempfile
from argparse import Namespace
from typing import Any, Optional

import torch
import torch.nn as nn
from dataclasses import dataclass, field
from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import BaseFairseqModel, FairseqEncoder, FairseqEncoderDecoderModel, register_model
from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES
from fairseq.tasks import FairseqTask
from omegaconf import II, MISSING

DBG = len(sys.argv) == 1  # standalone run (no CLI args): use local imports

if DBG:
    from hubert import AVHubertModel
    from decoder import TransformerDecoder
else:
    from .hubert import AVHubertModel
    from .decoder import TransformerDecoder

logger = logging.getLogger(__name__)


@dataclass
class AVHubertAsrConfig(FairseqDataclass):
    w2v_path: str = field(
        default=MISSING, metadata={"help": "path to hubert model"}
    )
    no_pretrained_weights: bool = field(
        default=False,
        metadata={"help": "if true, does not load pretrained weights"},
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    final_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout after transformer and before final projection"
        },
    )
    dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability inside hubert model"},
    )
    attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights "
            "inside hubert model"
        },
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN "
            "inside hubert model"
        },
    )

    # masking
    apply_mask: bool = field(
        default=False, metadata={"help": "apply masking during fine-tuning"}
    )
    mask_length: int = field(
        default=10, metadata={"help": "repeat the mask indices multiple times"}
    )
    mask_prob: float = field(
        default=0.5,
        metadata={
            "help": "probability of replacing a token with mask "
            "(normalized by length)"
        },
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose masks"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )

    # channel masking
    mask_channel_length: int = field(
        default=10,
        metadata={"help": "length of the mask for features (channels)"},
    )
    mask_channel_prob: float = field(
        default=0.0,
        metadata={"help": "probability of replacing a feature with 0"},
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to allow channel masks to overlap"},
    )
    freeze_finetune_updates: int = field(
        default=0,
        metadata={"help": "don't fine-tune hubert for this many updates"},
    )
    feature_grad_mult: float = field(
        default=0.0,
        metadata={"help": "reset feature grad mult in hubert to this"},
    )
    layerdrop: float = field(
        default=0.0,
        metadata={"help": "probability of dropping a layer in hubert"},
    )
    normalize: bool = II("task.normalize")
    data: str = II("task.data")

    # this holds the loaded hubert args
    w2v_args: Any = None


@dataclass
class AVHubertCtcConfig(AVHubertAsrConfig):
    pass


@register_model("av_hubert_ctc", dataclass=AVHubertCtcConfig)
class AVHubertCtc(BaseFairseqModel):
    def __init__(self, cfg: AVHubertCtcConfig, w2v_encoder: BaseFairseqModel):
        super().__init__()
        self.cfg = cfg
        self.w2v_encoder = w2v_encoder

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: AVHubertCtcConfig, task: FairseqTask):
        """Build a new model instance."""
        w2v_encoder = HubertEncoder(cfg, task.target_dictionary)
        return cls(cfg, w2v_encoder)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""

        logits = net_output["encoder_out"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def get_logits(self, net_output):
        logits = net_output["encoder_out"]
        padding = net_output["encoder_padding_mask"]
        if padding is not None and padding.any():
            padding = padding.T
            # boolean indexing returns a copy, so chained writes such as
            # logits[padding][..., 0] = 0 would be silently dropped; build
            # the replacement row and assign it in one masked write instead
            masking_tensor = logits.new_full((logits.size(-1),), float("-inf"))
            masking_tensor[0] = 0
            logits[padding] = masking_tensor

        return logits

    def forward(self, **kwargs):
        x = self.w2v_encoder(**kwargs)
        return x


@dataclass
class AVHubertSeq2SeqConfig(AVHubertAsrConfig):
    decoder_embed_dim: int = field(
        default=768, metadata={"help": "decoder embedding dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(
        default=6, metadata={"help": "num of decoder layers"}
    )
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "decoder layerdrop chance"}
    )
    decoder_attention_heads: int = field(
        default=4, metadata={"help": "num decoder attention heads"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    decoder_normalize_before: bool = field(
        default=False,
        metadata={"help": "apply layernorm before each decoder block"},
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings "
            "(outside self attention)"
        },
    )
    decoder_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability in the decoder"}
    )
    decoder_attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights "
            "inside the decoder"
        },
    )
    decoder_activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN "
            "inside the decoder"
        },
    )
    max_target_positions: int = field(
        default=2048, metadata={"help": "max target positions"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False,
        metadata={"help": "share decoder input and output embeddings"},
    )
    no_scale_embedding: bool = field(default=True, metadata={'help': 'do not scale embedding'})

class HubertEncoder(FairseqEncoder):
    def __init__(self, cfg: AVHubertAsrConfig, tgt_dict=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(
                cfg.w2v_path, arg_overrides
            )
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
                    w2v_args
                )

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for "
            "both pre-training and here"
        )

        w2v_args.task.data = cfg.data

        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            # set strict=False because we omit some modules
            model.load_state_dict(state["model"], strict=False)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = model.encoder.embedding_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        if tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):

        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }
        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x = self.final_dropout(x)

        if self.proj:
            x = self.proj(x)

        return {
            "encoder_out": x,  # T x B x C
            "encoder_padding_mask": padding_mask,  # B x T
            "padding_mask": padding_mask,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out[
                "encoder_out"
            ].index_select(1, new_order)
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict


class HubertEncoderWrapper(FairseqEncoder):
    def __init__(self, w2v_model):
        super().__init__(None)
        self.w2v_model = w2v_model

    def forward(self, source, padding_mask, **kwargs):
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
        }

        x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        return {
            "encoder_out": x,  # T x B x C
            "encoder_padding_mask": padding_mask,  # B x T
            "padding_mask": padding_mask
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out[
                "encoder_out"
            ].index_select(1, new_order)
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        if encoder_out["padding_mask"] is not None:
            encoder_out["padding_mask"] = encoder_out[
                "padding_mask"
            ].index_select(0, new_order)
        return encoder_out

@register_model("av_hubert_seq2seq", dataclass=AVHubertSeq2SeqConfig)
class AVHubertSeq2Seq(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder, tgt_dict, cfg):
        super().__init__(encoder, decoder)
        self.cfg = cfg
        self.freeze_finetune_updates = cfg.freeze_finetune_updates

    @classmethod
    def build_model(cls, cfg, task):
        """Build a new model instance."""

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(
                cfg.w2v_path, arg_overrides
            )
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
|
448 |
+
cfg.w2v_args = w2v_args
|
449 |
+
else:
|
450 |
+
state = None
|
451 |
+
w2v_args = cfg.w2v_args
|
452 |
+
if isinstance(w2v_args, Namespace):
|
453 |
+
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
|
454 |
+
w2v_args
|
455 |
+
)
|
456 |
+
|
457 |
+
assert cfg.normalize == w2v_args.task.normalize, (
|
458 |
+
"Fine-tuning works best when data normalization is the same. "
|
459 |
+
"Please check that --normalize is set or unset for "
|
460 |
+
"both pre-training and here"
|
461 |
+
)
|
462 |
+
|
463 |
+
w2v_args.task.data = cfg.data
|
464 |
+
|
465 |
+
task_pretrain = tasks.setup_task(w2v_args.task)
|
466 |
+
if state is not None:
|
467 |
+
task_pretrain.load_state_dict(state['task_state'])
|
468 |
+
|
469 |
+
encoder_ = task_pretrain.build_model(w2v_args.model)
|
470 |
+
|
471 |
+
encoder = HubertEncoderWrapper(encoder_)
|
472 |
+
if state is not None and not cfg.no_pretrained_weights:
|
473 |
+
# set strict=False because we omit some modules
|
474 |
+
del state['model']['mask_emb']
|
475 |
+
encoder.w2v_model.load_state_dict(state["model"], strict=False)
|
476 |
+
|
477 |
+
encoder.w2v_model.remove_pretraining_modules()
|
478 |
+
|
479 |
+
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
|
480 |
+
|
481 |
+
def build_embedding(dictionary, embed_dim):
|
482 |
+
num_embeddings = len(dictionary)
|
483 |
+
padding_idx = dictionary.pad()
|
484 |
+
emb = Embedding(num_embeddings, embed_dim, padding_idx=padding_idx)
|
485 |
+
return emb
|
486 |
+
|
487 |
+
decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)
|
488 |
+
decoder = TransformerDecoder(cfg, tgt_dict, decoder_embed_tokens)
|
489 |
+
|
490 |
+
return AVHubertSeq2Seq(encoder, decoder, tgt_dict, cfg)
|
491 |
+
|
492 |
+
|
493 |
+
def forward(self, **kwargs):
|
494 |
+
ft = self.freeze_finetune_updates <= self.num_updates
|
495 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
496 |
+
output = self.encoder(**kwargs)
|
497 |
+
decoder_out = self.decoder(prev_output_tokens=kwargs['prev_output_tokens'], encoder_out=output)
|
498 |
+
return decoder_out
|
499 |
+
|
500 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
501 |
+
super().upgrade_state_dict_named(state_dict, name)
|
502 |
+
return state_dict
|
503 |
+
|
504 |
+
def set_num_updates(self, num_updates):
|
505 |
+
"""Set the number of parameters updates."""
|
506 |
+
super().set_num_updates(num_updates)
|
507 |
+
self.num_updates = num_updates
|
508 |
+
|
509 |
+
def Embedding(num_embeddings, embedding_dim, padding_idx):
|
510 |
+
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
511 |
+
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
|
512 |
+
nn.init.constant_(m.weight[padding_idx], 0)
|
513 |
+
return m
|
514 |
+
|
515 |
+
|
516 |
+
def Linear(in_features, out_features, bias=True):
|
517 |
+
m = nn.Linear(in_features, out_features, bias)
|
518 |
+
nn.init.xavier_uniform_(m.weight)
|
519 |
+
if bias:
|
520 |
+
nn.init.constant_(m.bias, 0.0)
|
521 |
+
return m
|
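The encoder-freezing idiom used in both forward passes above (torch.no_grad() until freeze_finetune_updates steps have elapsed, then a no-op contextlib.ExitStack()) can be checked in isolation. A minimal sketch, independent of fairseq; the toy nn.Linear module and the update counts are illustrative assumptions, not part of this commit:

import contextlib
import torch
import torch.nn as nn

encoder = nn.Linear(4, 4)      # stand-in for the pretrained w2v_model (assumption)
freeze_finetune_updates = 10   # freeze the encoder for the first 10 updates

for num_updates in (5, 15):
    ft = freeze_finetune_updates <= num_updates
    x = torch.randn(2, 4, requires_grad=True)
    # torch.no_grad() suppresses graph construction while frozen;
    # contextlib.ExitStack() is an empty context manager, so once enough
    # updates have passed the encoder is fine-tuned normally.
    with torch.no_grad() if not ft else contextlib.ExitStack():
        y = encoder(x)
    print(num_updates, y.requires_grad)  # 5 -> False (frozen), 15 -> True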
av_hubert/avhubert/hubert_criterion.py
ADDED
@@ -0,0 +1,169 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class AVHubertCriterionConfig(FairseqDataclass):
    pred_masked_weight: float = field(
        default=1.0,
        metadata={"help": "weight for predictive loss for masked frames"},
    )
    pred_nomask_weight: float = field(
        default=0.0,
        metadata={"help": "weight for predictive loss for unmasked frames"},
    )
    loss_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "weights for additional loss terms (not first one)"},
    )
    log_keys: List[str] = field(
        default_factory=lambda: [],
        metadata={"help": "output keys to log"},
    )


@register_criterion("av_hubert", dataclass=AVHubertCriterionConfig)
class AVHubertCriterion(FairseqCriterion):
    def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
        super().__init__(task)
        self.pred_masked_weight = pred_masked_weight
        self.pred_nomask_weight = pred_nomask_weight
        self.loss_weights = loss_weights
        self.log_keys = [] if log_keys is None else log_keys

    def forward(self, model, sample, reduce=True, log_pred=False):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(target_list=sample["target_list"], **sample["net_input"])
        loss = 0.
        sample_size = 0
        logging_output = {}
        reduction = "sum" if reduce else "none"

        loss_m_list = []
        logp_m_list, targ_m_list = net_output['logit_m_list'], net_output['target_m_list']
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
            loss_m_list.append(loss_m)
            logging_output[f"loss_m_{i}"] = loss_m.detach().item()
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += targ_m_list[0].numel()

        loss_u_list = []
        logp_u_list, targ_u_list = net_output['logit_u_list'], net_output['target_u_list']
        for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
            loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
            loss_u_list.append(loss_u)
            logging_output[f"loss_u_{i}"] = loss_u.detach().item()
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += targ_u_list[0].numel()

        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses, names = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = [names]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

        logging_output = {
            "loss": loss.item() if reduce else loss,
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
            **logging_output,
        }

        for lk in self.log_keys:
            if lk in net_output:
                logging_output[lk] = float(net_output[lk])

        with torch.no_grad():
            for i, logp_m in enumerate(logp_m_list):
                if logp_m.numel() == 0:
                    corr_m, count_m = 0, 0
                else:
                    corr_m, count_m = (logp_m.argmax(dim=-1) == targ_m_list[i]).sum().item(), len(targ_m_list[i])
                logging_output[f"correct_m_{i}"] = corr_m
                logging_output[f"count_m_{i}"] = count_m

            for i, logp_u in enumerate(logp_u_list):
                if logp_u.numel() == 0:
                    corr_u, count_u = 0, 0
                else:
                    corr_u, count_u = (logp_u.argmax(dim=-1) == targ_u_list[i]).sum().item(), len(targ_u_list[i])
                logging_output[f"correct_u_{i}"] = corr_u
                logging_output[f"count_u_{i}"] = count_u

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
        if sample_size != ntokens:
            metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
        else:
            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))

        counts = {}
        for lk in logging_outputs[0].keys():
            if lk.startswith("count_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val)
                counts[lk] = val

        for lk in logging_outputs[0].keys():
            if lk.startswith("loss_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
            elif lk.startswith("correct_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError()

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
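The forward pass above sums per-frame cross-entropy separately over masked and unmasked positions and scales each sum by its config weight; with the defaults (pred_masked_weight=1.0, pred_nomask_weight=0.0) only masked frames contribute, and the masked-frame count becomes the gradient denominator. A minimal numeric sketch of that weighting; the logits, targets, and the 100-cluster vocabulary are made-up illustrations:

import torch
import torch.nn.functional as F

pred_masked_weight, pred_nomask_weight = 1.0, 0.0  # the config defaults above

logit_m = torch.randn(6, 100)      # logits at 6 masked frames, 100 clusters (assumed)
targ_m = torch.randint(100, (6,))  # cluster targets for those frames
logit_u = torch.randn(4, 100)      # logits at 4 unmasked frames
targ_u = torch.randint(100, (4,))

loss, sample_size = 0.0, 0
if pred_masked_weight > 0:
    loss = loss + pred_masked_weight * F.cross_entropy(logit_m, targ_m, reduction="sum")
    sample_size += targ_m.numel()
if pred_nomask_weight > 0:  # 0.0 by default, so unmasked frames add nothing
    loss = loss + pred_nomask_weight * F.cross_entropy(logit_u, targ_u, reduction="sum")
    sample_size += targ_u.numel()
print(loss.item(), sample_size)  # sample_size (6 here) is the gradient denominator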
av_hubert/avhubert/hubert_dataset.py
ADDED
@@ -0,0 +1,529 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os
import sys
import time
from typing import Any, List, Optional, Union

import numpy as np

import torch
import torch.nn.functional as F
from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
from python_speech_features import logfbank
from scipy.io import wavfile

DBG = True if len(sys.argv) == 1 else False

if DBG:
    import utils as custom_utils
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "DEBUG").upper(),
        stream=sys.stdout,
    )
else:
    from . import utils as custom_utils

logger = logging.getLogger(__name__)


def load_audio_visual(manifest_path, max_keep, min_keep, frame_rate, label_paths, label_rates, tol=0.1):
    def is_audio_label_aligned(audio_dur, label_durs):
        return all([abs(audio_dur - label_dur) < tol for label_dur in label_durs])

    n_long, n_short, n_unaligned = 0, 0, 0
    names, inds, sizes = [], [], []
    dur_from_label_list = []
    is_seq_label = any([x == -1 for x in label_rates])
    for label_path, label_rate in zip(label_paths, label_rates):
        label_lengths = [len(line.rstrip().split()) / label_rate for line in open(label_path).readlines()]
        dur_from_label_list.append(label_lengths)
    dur_from_label_list = list(zip(*dur_from_label_list))

    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            sz = int(items[-2])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            elif (not is_seq_label) and (not is_audio_label_aligned(sz / frame_rate, dur_from_label_list[ind])):
                n_unaligned += 1
            else:
                video_path = items[1]
                audio_path = items[2]
                audio_id = items[0]
                names.append((video_path, audio_path + ':' + audio_id))
                inds.append(ind)
                sizes.append(sz)
    tot = ind + 1
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(names)}, skipped {n_short} short and {n_long} long and {n_unaligned} unaligned, "
            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
        )
    )
    return root, names, inds, tot, sizes


def load_label(label_path, inds, tot):
    with open(label_path) as f:
        labels = [line.rstrip() for line in f]
        assert (
            len(labels) == tot
        ), f"number of labels does not match ({len(labels)} != {tot})"
        labels = [labels[i] for i in inds]
    return labels


def load_label_offset(label_path, inds, tot):
    with open(label_path) as f:
        code_lengths = [len(line.encode("utf-8")) for line in f]
        assert (
            len(code_lengths) == tot
        ), f"number of labels does not match ({len(code_lengths)} != {tot})"
        offsets = list(itertools.accumulate([0] + code_lengths))
        offsets = [(offsets[i], offsets[i + 1]) for i in inds]
    return offsets


def verify_label_lengths(
    audio_sizes,
    audio_rate,
    label_path,
    label_rate,
    inds,
    tot,
    tol=0.1,  # tolerance in seconds
):
    if label_rate < 0:
        logger.info(f"{label_path} is sequence label. skipped")
        return

    with open(label_path) as f:
        lengths = [len(line.rstrip().split()) for line in f]
        assert len(lengths) == tot
        lengths = [lengths[i] for i in inds]
    num_invalid = 0
    for i, ind in enumerate(inds):
        dur_from_audio = audio_sizes[i] / audio_rate
        dur_from_label = lengths[i] / label_rate
        if abs(dur_from_audio - dur_from_label) > tol:
            logger.warning(
                (
                    f"audio and label duration differ too much "
                    f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
                    f"in line {ind+1} of {label_path}. Check if `label_rate` "
                    f"is correctly set (currently {label_rate}). "
                    f"num. of samples = {audio_sizes[i]}; "
                    f"label length = {lengths[i]}"
                )
            )
            num_invalid += 1
    if num_invalid > 0:
        logger.warning(
            f"total {num_invalid} (audio, label) pairs with mismatched lengths"
        )


class AVHubertDataset(FairseqDataset):
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        random_crop: bool = False,
        single_target: bool = False,
        stack_order_audio: int = 1,
        skip_verify: bool = False,
        image_mean: float = 0,
        image_std: float = 1,
        image_crop_size: int = 88,
        image_aug: bool = False,
        modalities: Optional[List[str]] = None,
        is_s2s=False,
        noise_fn=None,
        noise_prob=0,
        noise_snr=0,
        noise_num=1
    ):
        self.label_rates = (
            [label_rates for _ in range(len(label_paths))]
            if isinstance(label_rates, int)
            else label_rates
        )
        self.modalities = set(modalities)
        self.audio_root, self.names, inds, tot, self.sizes = load_audio_visual(manifest_path, max_keep_sample_size, min_keep_sample_size, frame_rate=sample_rate, label_paths=label_paths, label_rates=self.label_rates)
        self.sample_rate = sample_rate
        self.stack_order_audio = stack_order_audio
        self.shuffle = shuffle
        self.random_crop = random_crop

        self.num_labels = len(label_paths)
        self.pad_list = pad_list
        self.eos_list = eos_list
        self.label_processors = label_processors
        self.single_target = single_target
        self.store_labels = store_labels
        self.is_s2s = is_s2s
        self.noise_wav, self.noise_prob, self.noise_snr, self.noise_num = [ln.strip() for ln in open(noise_fn).readlines()] if noise_fn is not None else [], noise_prob, noise_snr, noise_num

        assert self.single_target == (self.label_rates[0] == -1), "single target should be equivalent to sequence label (label_rate==-1)"
        if store_labels:
            self.label_list = [load_label(p, inds, tot) for p in label_paths]
        else:
            self.label_paths = label_paths
            self.label_offsets_list = [
                load_label_offset(p, inds, tot) for p in label_paths
            ]
        assert (
            label_processors is None
            or len(label_processors) == self.num_labels
        )
        if not skip_verify:
            for label_path, label_rate in zip(label_paths, self.label_rates):
                verify_label_lengths(self.sizes, self.sample_rate, label_path, label_rate, inds, tot)
        else:
            logger.info("Skipping label alignment verification")

        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.pad_audio = pad_audio
        self.normalize = normalize
        if image_aug:
            self.transform = custom_utils.Compose([
                custom_utils.Normalize(0.0, 255.0),
                custom_utils.RandomCrop((image_crop_size, image_crop_size)),
                custom_utils.HorizontalFlip(0.5),
                custom_utils.Normalize(image_mean, image_std)])
        else:
            self.transform = custom_utils.Compose([
                custom_utils.Normalize(0.0, 255.0),
                custom_utils.CenterCrop((image_crop_size, image_crop_size)),
                custom_utils.Normalize(image_mean, image_std)])
        logger.info(f"image transform: {self.transform}")

        logger.info(
            f"pad_audio={pad_audio}, random_crop={random_crop}, "
            f"normalize={normalize}, max_sample_size={self.max_sample_size}, "
            f"seq2seq data={self.is_s2s},")
        logger.info(
            f"Noise wav: {noise_fn}->{len(self.noise_wav)} wav, Prob: {self.noise_prob}, SNR: {self.noise_snr}, Number of mixture: {self.noise_num}"
        )

    def get_label(self, index, label_idx):
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            with open(self.label_paths[label_idx]) as f:
                offset_s, offset_e = self.label_offsets_list[label_idx][index]
                f.seek(offset_s)
                label = f.read(offset_e - offset_s)

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def get_labels(self, index):
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def load_feature(self, mix_name):
        """
        Load image and audio feature
        Returns:
        video_feats: numpy.ndarray of shape [T, H, W, 1], audio_feats: numpy.ndarray of shape [T, F]
        """
        def stacker(feats, stack_order):
            """
            Concatenating consecutive audio frames
            Args:
            feats - numpy.ndarray of shape [T, F]
            stack_order - int (number of neighboring frames to concatenate)
            Returns:
            feats - numpy.ndarray of shape [T', F']
            """
            feat_dim = feats.shape[1]
            if len(feats) % stack_order != 0:
                res = stack_order - len(feats) % stack_order
                res = np.zeros([res, feat_dim]).astype(feats.dtype)
                feats = np.concatenate([feats, res], axis=0)
            feats = feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order * feat_dim)
            return feats
        video_fn, audio_fn = mix_name
        if 'video' in self.modalities:
            video_feats = self.load_video(video_fn)  # [T, H, W, 1]
        else:
            video_feats = None
        if 'audio' in self.modalities:
            audio_fn = audio_fn.split(':')[0]
            sample_rate, wav_data = wavfile.read(audio_fn)
            assert sample_rate == 16_000 and len(wav_data.shape) == 1
            if np.random.rand() < self.noise_prob:
                wav_data = self.add_noise(wav_data)
            audio_feats = logfbank(wav_data, samplerate=sample_rate).astype(np.float32)  # [T, F]
            audio_feats = stacker(audio_feats, self.stack_order_audio)  # [T/stack_order_audio, F*stack_order_audio]
        else:
            audio_feats = None
        if audio_feats is not None and video_feats is not None:
            diff = len(audio_feats) - len(video_feats)
            if diff < 0:
                audio_feats = np.concatenate([audio_feats, np.zeros([-diff, audio_feats.shape[-1]], dtype=audio_feats.dtype)])
            elif diff > 0:
                audio_feats = audio_feats[:-diff]
        return video_feats, audio_feats

    def load_video(self, audio_name):
        feats = custom_utils.load_video(os.path.join(self.audio_root, audio_name))
        feats = self.transform(feats)
        feats = np.expand_dims(feats, axis=-1)
        return feats

    def select_noise(self):
        rand_indexes = np.random.randint(0, len(self.noise_wav), size=self.noise_num)
        noise_wav = []
        for x in rand_indexes:
            noise_wav.append(wavfile.read(self.noise_wav[x])[1].astype(np.float32))
        if self.noise_num == 1:
            return noise_wav[0]
        else:
            min_len = min([len(x) for x in noise_wav])
            noise_wav = [x[:min_len] for x in noise_wav]
            noise_wav = np.floor(np.stack(noise_wav).mean(axis=0))
            return noise_wav

    def add_noise(self, clean_wav):
        clean_wav = clean_wav.astype(np.float32)
        noise_wav = self.select_noise()
        if type(self.noise_snr) == int or type(self.noise_snr) == float:
            snr = self.noise_snr
        elif type(self.noise_snr) == tuple:
            snr = np.random.randint(self.noise_snr[0], self.noise_snr[1] + 1)
        clean_rms = np.sqrt(np.mean(np.square(clean_wav), axis=-1))
        if len(clean_wav) > len(noise_wav):
            ratio = int(np.ceil(len(clean_wav) / len(noise_wav)))
            noise_wav = np.concatenate([noise_wav for _ in range(ratio)])
        if len(clean_wav) < len(noise_wav):
            start = 0
            noise_wav = noise_wav[start: start + len(clean_wav)]
        noise_rms = np.sqrt(np.mean(np.square(noise_wav), axis=-1))
        adjusted_noise_rms = clean_rms / (10 ** (snr / 20))
        adjusted_noise_wav = noise_wav * (adjusted_noise_rms / noise_rms)
        mixed = clean_wav + adjusted_noise_wav

        # Avoid clipping noise
        max_int16 = np.iinfo(np.int16).max
        min_int16 = np.iinfo(np.int16).min
        if mixed.max(axis=0) > max_int16 or mixed.min(axis=0) < min_int16:
            if mixed.max(axis=0) >= abs(mixed.min(axis=0)):
                reduction_rate = max_int16 / mixed.max(axis=0)
            else:
                reduction_rate = min_int16 / mixed.min(axis=0)
            mixed = mixed * reduction_rate
        mixed = mixed.astype(np.int16)
        return mixed

    def __getitem__(self, index):
        video_feats, audio_feats = self.load_feature(self.names[index])
        audio_feats, video_feats = torch.from_numpy(audio_feats.astype(np.float32)) if audio_feats is not None else None, torch.from_numpy(video_feats.astype(np.float32)) if video_feats is not None else None
        if self.normalize and 'audio' in self.modalities:
            with torch.no_grad():
                audio_feats = F.layer_norm(audio_feats, audio_feats.shape[1:])
        labels = self.get_labels(index)
        fid = self.names[index][1].split(':')[1]
        return {"id": index, 'fid': fid, "video_source": video_feats, 'audio_source': audio_feats, "label_list": labels}

    def __len__(self):
        return len(self.sizes)

    def crop_to_max_size(self, wav, target_size, start=None):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav, 0
        # longer utterances
        if start is None:
            start, end = 0, target_size
            if self.random_crop:
                start = np.random.randint(0, diff + 1)
                end = size - diff + start
        else:
            end = start + target_size
        return wav[start:end], start

    def collater(self, samples):
        samples = [s for s in samples if s["id"] is not None]
        if len(samples) == 0:
            return {}

        audio_source, video_source = [s["audio_source"] for s in samples], [s["video_source"] for s in samples]
        if audio_source[0] is None:
            audio_source = None
        if video_source[0] is None:
            video_source = None
        if audio_source is not None:
            audio_sizes = [len(s) for s in audio_source]
        else:
            audio_sizes = [len(s) for s in video_source]
        if self.pad_audio:
            audio_size = min(max(audio_sizes), self.max_sample_size)
        else:
            audio_size = min(min(audio_sizes), self.max_sample_size)
        if audio_source is not None:
            collated_audios, padding_mask, audio_starts = self.collater_audio(audio_source, audio_size)
        else:
            collated_audios, audio_starts = None, None
        if video_source is not None:
            collated_videos, padding_mask, audio_starts = self.collater_audio(video_source, audio_size, audio_starts)
        else:
            collated_videos = None
        targets_by_label = [
            [s["label_list"][i] for s in samples]
            for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label, audio_size, audio_starts
        )
        source = {"audio": collated_audios, "video": collated_videos}
        net_input = {"source": source, "padding_mask": padding_mask}
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
            "utt_id": [s['fid'] for s in samples]
        }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            if self.is_s2s:
                batch['target'], net_input['prev_output_tokens'] = targets_list[0][0], targets_list[0][1]
            else:
                batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
        return batch

    def collater_audio(self, audios, audio_size, audio_starts=None):
        audio_feat_shape = list(audios[0].shape[1:])
        collated_audios = audios[0].new_zeros([len(audios), audio_size] + audio_feat_shape)
        padding_mask = (
            torch.BoolTensor(len(audios), audio_size).fill_(False)
        )
        start_known = audio_starts is not None
        audio_starts = [0 for _ in audios] if not start_known else audio_starts
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                assert self.pad_audio
                collated_audios[i] = torch.cat(
                    [audio, audio.new_full([-diff] + audio_feat_shape, 0.0)]
                )
                padding_mask[i, diff:] = True
            else:
                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
                    audio, audio_size, audio_starts[i] if start_known else None
                )
        if len(audios[0].shape) == 2:
            collated_audios = collated_audios.transpose(1, 2)  # [B, T, F] -> [B, F, T]
        else:
            collated_audios = collated_audios.permute((0, 4, 1, 2, 3)).contiguous()  # [B, T, H, W, C] -> [B, C, T, H, W]
        return collated_audios, padding_mask, audio_starts

    def collater_frm_label(
        self, targets, audio_size, audio_starts, label_rate, pad
    ):
        assert label_rate > 0
        s2f = label_rate / self.sample_rate  # num label per sample
        frm_starts = [int(round(s * s2f)) for s in audio_starts]
        frm_size = int(round(audio_size * s2f))
        if not self.pad_audio:
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)]
        logger.debug(f"audio_starts={audio_starts}")
        logger.debug(f"frame_starts={frm_starts}")
        logger.debug(f"frame_size={frm_size}")

        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(
            targets, pad_idx=pad, left_pad=False
        )
        return targets, lengths, ntokens

    def collater_seq_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(
            targets, pad_idx=pad, left_pad=False
        )
        return targets, lengths, ntokens

    def collater_seq_label_s2s(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        pad, eos = self.label_processors[0].dictionary.pad(), self.label_processors[0].dictionary.eos()
        targets_ = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False)
        prev_output_tokens = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False, move_eos_to_beginning=True)
        return (targets_, prev_output_tokens), lengths, ntokens

    def collater_label(self, targets_by_label, audio_size, audio_starts):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            if label_rate == -1:
                if self.is_s2s:
                    targets, lengths, ntokens = self.collater_seq_label_s2s(targets, pad)
                else:
                    targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            else:
                targets, lengths, ntokens = self.collater_frm_label(
                    targets, audio_size, audio_starts, label_rate, pad
                )
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        if self.pad_audio:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        order.append(self.sizes)
        return np.lexsort(order)[::-1]
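The nested stacker above is what aligns the 100 Hz filterbank stream with the video frame rate: with stack_order_audio=4 (the value the AV-HuBERT configs typically use), four consecutive 26-dim logfbank frames are folded into one 104-dim feature per 25 fps video frame. A standalone rerun of the same reshape logic on a small made-up array:

import numpy as np

def stacker(feats, stack_order):
    # pad T up to a multiple of stack_order, then fold neighbours into the feature dim
    feat_dim = feats.shape[1]
    if len(feats) % stack_order != 0:
        res = stack_order - len(feats) % stack_order
        feats = np.concatenate([feats, np.zeros([res, feat_dim], dtype=feats.dtype)], axis=0)
    return feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order * feat_dim)

feats = np.arange(10 * 26, dtype=np.float32).reshape(10, 26)  # 10 frames of 26-dim logfbank (illustrative)
stacked = stacker(feats, stack_order=4)
print(feats.shape, "->", stacked.shape)  # (10, 26) -> (3, 104): zero-padded to 12 frames first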
av_hubert/avhubert/hubert_pretraining.py
ADDED
@@ -0,0 +1,400 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os, glob
import sys
from typing import Dict, List, Optional, Tuple

import numpy as np

from dataclasses import dataclass, field
from fairseq import metrics, search
from fairseq.data import Dictionary, encoders
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask
from omegaconf import MISSING, II
from argparse import Namespace

DBG = True if len(sys.argv) == 1 else False

if DBG:
    from hubert_dataset import AVHubertDataset
    from sequence_generator import SequenceGenerator
else:
    from .hubert_dataset import AVHubertDataset
    from .sequence_generator import SequenceGenerator

logger = logging.getLogger(__name__)


class LabelEncoder(object):
    def __init__(self, dictionary: Dictionary) -> None:
        self.dictionary = dictionary

    def __call__(self, label: str) -> List[str]:
        return self.dictionary.encode_line(
            label, append_eos=False, add_if_not_exist=False,
        )


class LabelEncoderS2SToken(object):
    def __init__(self, dictionary: Dictionary, bpe_tokenizer) -> None:
        self.bpe_tokenizer = bpe_tokenizer
        self.dictionary = dictionary

    def __call__(self, label: str) -> List[str]:
        label = self.bpe_tokenizer.encode(label.lower())
        return self.dictionary.encode_line(
            label, append_eos=True, add_if_not_exist=False,
        ).long()

    def decode(self, tok, symbols_ignore=None):
        tok = self.dictionary.string(tok, extra_symbols_to_ignore=symbols_ignore)
        if self.bpe_tokenizer:
            tok = self.bpe_tokenizer.decode(tok)
        return tok


@dataclass
class AVHubertPretrainingConfig(FairseqDataclass):
    data: str = field(
        default=MISSING, metadata={"help": "path to data directory"}
    )
    labels: List[str] = field(
        default_factory=lambda: ["ltr"],
        metadata={
            "help": (
                "extension of the label files to load, frame-level labels for"
                " pre-training, and sequence-level label for fine-tuning"
            )
        },
    )
    label_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "if set, looks for labels in this directory instead",
        },
    )
    label_rate: int = field(
        default=-1,
        metadata={"help": "label frame rate. -1 for sequence label"},
    )

    sample_rate: int = field(
        default=16_000,
        metadata={
            "help": "target sample rate. audio files will be up/down "
            "sampled to this rate"
        },
    )
    normalize: bool = field(
        default=False,
        metadata={
            "help": "if set, normalizes input to have 0 mean and unit variance"
        },
    )
    enable_padding: bool = field(
        default=False,
        metadata={"help": "pad shorter samples instead of cropping"},
    )
    max_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "max sample size to keep in training"},
    )
    min_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "min sample size to keep in training"},
    )
    max_trim_sample_size: Optional[int] = field(
        default=II("task.max_sample_size"),
        metadata={"help": "max sample size to trim to for batching"},
    )
    single_target: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if set, AddTargetDatasets outputs same keys "
            "as AddTargetDataset"
        },
    )
    random_crop: Optional[bool] = field(
        default=True,
        metadata={"help": "always crop from the beginning if false"},
    )
    pad_audio: Optional[bool] = field(
        default=False,
        metadata={"help": "pad audio to the longest one in the batch if true"},
    )
    pdb: Optional[bool] = field(
        default=False,
        metadata={"help": "pdb"},
    )
    stack_order_audio: int = field(
        default=1,
        metadata={"help": "concatenate n consecutive audio frames for one step"},
    )
    skip_verify: Optional[bool] = field(
        default=False,
        metadata={"help": "skip verifying label-audio alignment"},
    )
    image_aug: bool = field(default=False, metadata={'help': 'image data augmentation'})
    image_crop_size: int = field(
        default=88, metadata={"help": "image ROI size"})
    image_mean: float = field(
        default=0.421, metadata={"help": "image mean"})
    image_std: float = field(
        default=0.165, metadata={"help": "image std"})
    modalities: Optional[List[str]] = field(default_factory=lambda: ["audio", "video"], metadata={'help': 'modalities to load'})
    is_s2s: bool = field(default=False, metadata={'help': 'seq2seq fine-tuning only'})
    tokenizer_bpe_name: Optional[str] = field(default=None, metadata={'help': 'tokenizer model name'})
    tokenizer_bpe_model: Optional[str] = field(default=None, metadata={'help': 'tokenizer model path'})
    noise_wav: Optional[str] = field(default=None, metadata={'help': 'manifest of noise wav files (one wav file path per line)'})
    noise_prob: float = field(default=0, metadata={'help': 'noise probability'})
    noise_snr: Optional[str] = field(default='0', metadata={'help': 'noise SNR in audio'})
    noise_num: int = field(default=1, metadata={'help': 'number of noise wav files to mix'})
    fine_tuning: bool = field(default=False, metadata={"help": "set to true if fine-tuning AV-Hubert"})


@register_task("av_hubert_pretraining", dataclass=AVHubertPretrainingConfig)
class AVHubertPretrainingTask(FairseqTask):

    cfg: AVHubertPretrainingConfig

    def __init__(
        self,
        cfg: AVHubertPretrainingConfig,
    ) -> None:
        super().__init__(cfg)

        logger.info(f"current directory is {os.getcwd()}")
        logger.info(f"AVHubertPretrainingTask Config {cfg}")

        self.fine_tuning = cfg.fine_tuning
        if cfg.fine_tuning:
            self.state.add_factory("target_dictionary", self.load_dictionaries)
            if cfg.is_s2s:
                self.state.add_factory("s2s_tokenizer", self.load_tokenizer)
        else:
            self.state.add_factory("dictionaries", self.load_dictionaries)

        self.blank_symbol = "<s>"

    @property
    def source_dictionary(self) -> Optional[Dictionary]:
        return None  # self._source_dictionary

    @property
    def target_dictionary(self) -> Optional[Dictionary]:
        return self.state.target_dictionary  # self._target_dictionary

    @property
    def dictionaries(self) -> List[Dictionary]:
        return self.state.dictionaries

    def load_dictionaries(self):
        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
        dictionaries = [
            Dictionary.load(f"{label_dir}/dict.{label}.txt")
            for label in self.cfg.labels
        ]
        return dictionaries[0] if self.cfg.fine_tuning else dictionaries

    def load_tokenizer(self):
        bpe_args = Namespace(**{'bpe': self.cfg.tokenizer_bpe_name, f"{self.cfg.tokenizer_bpe_name}_model": self.cfg.tokenizer_bpe_model})
        bpe_tokenizer = encoders.build_bpe(bpe_args)
        return bpe_tokenizer

    @property
    def s2s_tokenizer(self):
        return self.state.s2s_tokenizer

    @classmethod
    def setup_task(
        cls, cfg: AVHubertPretrainingConfig, **kwargs
    ) -> "AVHubertPretrainingTask":
        if cfg.pdb:
            import pdb
            pdb.set_trace()
        return cls(cfg)

    def get_label_dir(self) -> str:
        if self.cfg.label_dir is None:
            return self.cfg.data
        return self.cfg.label_dir

    def load_dataset(self, split: str, **kwargs) -> None:
        manifest = f"{self.cfg.data}/{split}.tsv"
        dictionaries = [self.target_dictionary] if self.fine_tuning else self.dictionaries
        pad_list = [dictionary.pad() for dictionary in dictionaries]
        eos_list = [dictionary.eos() for dictionary in dictionaries]
        if not self.cfg.is_s2s:
            procs = [LabelEncoder(dictionary) for dictionary in dictionaries]
        else:
            logger.info("Using tokenizer")
            bpe_tokenizer = self.s2s_tokenizer
            procs = [LabelEncoderS2SToken(dictionary, bpe_tokenizer) for dictionary in dictionaries]
        paths = [
            f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels
        ]
        image_aug = self.cfg.image_aug if split == 'train' else False
        noise_fn, noise_snr = f"{self.cfg.noise_wav}/{split}.tsv" if self.cfg.noise_wav is not None else None, eval(self.cfg.noise_snr)
        noise_num = self.cfg.noise_num
        self.datasets[split] = AVHubertDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_sample_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_trim_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
            stack_order_audio=self.cfg.stack_order_audio,
            skip_verify=self.cfg.skip_verify,
            image_mean=self.cfg.image_mean,
            image_std=self.cfg.image_std,
            image_crop_size=self.cfg.image_crop_size,
            image_aug=image_aug,
            modalities=self.cfg.modalities,
            is_s2s=self.cfg.is_s2s,
            noise_fn=noise_fn,
            noise_prob=self.cfg.noise_prob,
            noise_snr=noise_snr,
            noise_num=noise_num
        )

    def max_positions(self) -> Tuple[int, int]:
        return (sys.maxsize, sys.maxsize)

    def filter_indices_by_size(
        self, indices: np.array, *args, **kwargs
    ) -> np.array:
        return indices

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
    ):
        """
        Build a :class:`~fairseq.SequenceGenerator` instance for this
        task.
        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            args (fairseq.dataclass.configs.GenerationConfig):
                configuration object (dataclass) for generation
            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
                through to SequenceGenerator
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
                If provided, this function constrains the beam search to
                allowed tokens only at each step. The provided function
                should take 2 arguments: the batch ID (`batch_id: int`)
                and a unidimensional tensor of token ids (`inputs_ids:
                torch.Tensor`). It has to return a `List[int]` with the
                allowed tokens for the next generation step conditioned
                on the previously generated tokens (`inputs_ids`) and
                the batch ID (`batch_id`). This argument is useful for
                constrained generation conditioned on the prefix, as
                described in "Autoregressive Entity Retrieval"
                (https://arxiv.org/abs/2010.00904) and
                https://github.com/facebookresearch/GENRE.
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
        if (
            sum(
                int(cond)
                for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]
            )
            > 1
        ):
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            **extra_gen_cls_kwargs,
        )
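load_dataset above parses the noise_snr config string with eval, which is why AVHubertDataset.add_noise branches on int/float versus tuple: a plain number fixes the SNR, while a pair gives a range sampled per utterance. A small sketch of the two accepted forms; the example strings are illustrative assumptions:

import numpy as np

for noise_snr_cfg in ("0", "(-5, 5)"):   # fixed SNR vs. a uniform range, in dB
    noise_snr = eval(noise_snr_cfg)       # mirrors load_dataset above
    if isinstance(noise_snr, (int, float)):
        snr = noise_snr                   # use the fixed value
    else:                                 # tuple: sample one SNR per utterance
        snr = np.random.randint(noise_snr[0], noise_snr[1] + 1)
    print(noise_snr_cfg, "->", snr, "dB")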