@@ -0,0 +1,128 @@
1 |
# Contributor Covenant Code of Conduct
2 |
3 |
## Our Pledge
4 |
5 |
We as members, contributors, and leaders pledge to make participation in our
6 |
community a harassment-free experience for everyone, regardless of age, body
7 |
size, visible or invisible disability, ethnicity, sex characteristics, gender
8 |
identity and expression, level of experience, education, socio-economic status,
9 |
nationality, personal appearance, race, religion, or sexual identity
10 |
and orientation.
11 |
12 |
We pledge to act and interact in ways that contribute to an open, welcoming,
13 |
diverse, inclusive, and healthy community.
14 |
15 |
## Our Standards
16 |
17 |
Examples of behavior that contributes to a positive environment for our
18 |
community include:
19 |
20 |
* Demonstrating empathy and kindness toward other people
21 |
* Being respectful of differing opinions, viewpoints, and experiences
22 |
* Giving and gracefully accepting constructive feedback
23 |
* Accepting responsibility and apologizing to those affected by our mistakes,
24 |
and learning from the experience
25 |
* Focusing on what is best not just for us as individuals, but for the
26 |
overall community
27 |
28 |
Examples of unacceptable behavior include:
29 |
30 |
* The use of sexualized language or imagery, and sexual attention or
31 |
advances of any kind
32 |
* Trolling, insulting or derogatory comments, and personal or political attacks
33 |
* Public or private harassment
34 |
* Publishing others' private information, such as a physical or email
35 |
address, without their explicit permission
36 |
* Other conduct which could reasonably be considered inappropriate in a
37 |
professional setting
38 |
39 |
## Enforcement Responsibilities
40 |
41 |
Community leaders are responsible for clarifying and enforcing our standards of
42 |
acceptable behavior and will take appropriate and fair corrective action in
43 |
response to any behavior that they deem inappropriate, threatening, offensive,
44 |
or harmful.
45 |
46 |
Community leaders have the right and responsibility to remove, edit, or reject
47 |
comments, commits, code, wiki edits, issues, and other contributions that are
48 |
not aligned to this Code of Conduct, and will communicate reasons for moderation
49 |
decisions when appropriate.
50 |
51 |
## Scope
52 |
53 |
This Code of Conduct applies within all community spaces, and also applies when
54 |
an individual is officially representing the community in public spaces.
55 |
Examples of representing our community include using an official e-mail address,
56 |
posting via an official social media account, or acting as an appointed
57 |
representative at an online or offline event.
58 |
59 |
## Enforcement
60 |
61 |
Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 |
reported to the community leaders responsible for enforcement at
63 | |
64 |
All complaints will be reviewed and investigated promptly and fairly.
65 |
66 |
All community leaders are obligated to respect the privacy and security of the
67 |
reporter of any incident.
68 |
69 |
## Enforcement Guidelines
70 |
71 |
Community leaders will follow these Community Impact Guidelines in determining
72 |
the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 |
### 1. Correction
75 |
76 |
**Community Impact**: Use of inappropriate language or other behavior deemed
77 |
unprofessional or unwelcome in the community.
78 |
79 |
**Consequence**: A private, written warning from community leaders, providing
80 |
clarity around the nature of the violation and an explanation of why the
81 |
behavior was inappropriate. A public apology may be requested.
82 |
83 |
### 2. Warning
84 |
85 |
**Community Impact**: A violation through a single incident or series
86 |
of actions.
87 |
88 |
**Consequence**: A warning with consequences for continued behavior. No
89 |
interaction with the people involved, including unsolicited interaction with
90 |
those enforcing the Code of Conduct, for a specified period of time. This
91 |
includes avoiding interactions in community spaces as well as external channels
92 |
like social media. Violating these terms may lead to a temporary or
93 |
permanent ban.
94 |
95 |
### 3. Temporary Ban
96 |
97 |
**Community Impact**: A serious violation of community standards, including
98 |
sustained inappropriate behavior.
99 |
100 |
**Consequence**: A temporary ban from any sort of interaction or public
101 |
communication with the community for a specified period of time. No public or
102 |
private interaction with the people involved, including unsolicited interaction
103 |
with those enforcing the Code of Conduct, is allowed during this period.
104 |
Violating these terms may lead to a permanent ban.
105 |
106 |
### 4. Permanent Ban
107 |
108 |
**Community Impact**: Demonstrating a pattern of violation of community
109 |
standards, including sustained inappropriate behavior, harassment of an
110 |
individual, or aggression toward or disparagement of classes of individuals.
111 |
112 |
**Consequence**: A permanent ban from any sort of public interaction within
113 |
the community.
114 |
115 |
## Attribution
116 |
117 |
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 |
version 2.0, available at
119 |
120 |
121 |
Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 |
enforcement ladder](
123 |
124 |
125 |
126 |
For answers to common questions about this code of conduct, see the FAQ at
127 |
Translations are available at
128 |
@@ -0,0 +1,7 @@
1 |
[Git Guide](
2 |
3 |
[GitHub Cooperation Guide](
4 |
5 |
- [Code Style](
6 |
- [Unit Test](
7 |
- [Code Review](
@@ -0,0 +1,202 @@
1 |
2 |
Apache License
3 |
Version 2.0, January 2004
4 |
5 |
6 |
7 |
8 |
1. Definitions.
9 |
10 |
"License" shall mean the terms and conditions for use, reproduction,
11 |
and distribution as defined by Sections 1 through 9 of this document.
12 |
13 |
"Licensor" shall mean the copyright owner or entity authorized by
14 |
the copyright owner that is granting the License.
15 |
16 |
"Legal Entity" shall mean the union of the acting entity and all
17 |
other entities that control, are controlled by, or are under common
18 |
control with that entity. For the purposes of this definition,
19 |
"control" means (i) the power, direct or indirect, to cause the
20 |
direction or management of such entity, whether by contract or
21 |
otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 |
outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 |
"You" (or "Your") shall mean an individual or Legal Entity
25 |
exercising permissions granted by this License.
26 |
27 |
"Source" form shall mean the preferred form for making modifications,
28 |
including but not limited to software source code, documentation
29 |
source, and configuration files.
30 |
31 |
"Object" form shall mean any form resulting from mechanical
32 |
transformation or translation of a Source form, including but
33 |
not limited to compiled object code, generated documentation,
34 |
and conversions to other media types.
35 |
36 |
"Work" shall mean the work of authorship, whether in Source or
37 |
Object form, made available under the License, as indicated by a
38 |
copyright notice that is included in or attached to the work
39 |
(an example is provided in the Appendix below).
40 |
41 |
"Derivative Works" shall mean any work, whether in Source or Object
42 |
form, that is based on (or derived from) the Work and for which the
43 |
editorial revisions, annotations, elaborations, or other modifications
44 |
represent, as a whole, an original work of authorship. For the purposes
45 |
of this License, Derivative Works shall not include works that remain
46 |
separable from, or merely link (or bind by name) to the interfaces of,
47 |
the Work and Derivative Works thereof.
48 |
49 |
"Contribution" shall mean any work of authorship, including
50 |
the original version of the Work and any modifications or additions
51 |
to that Work or Derivative Works thereof, that is intentionally
52 |
submitted to Licensor for inclusion in the Work by the copyright owner
53 |
or by an individual or Legal Entity authorized to submit on behalf of
54 |
the copyright owner. For the purposes of this definition, "submitted"
55 |
means any form of electronic, verbal, or written communication sent
56 |
to the Licensor or its representatives, including but not limited to
57 |
communication on electronic mailing lists, source code control systems,
58 |
and issue tracking systems that are managed by, or on behalf of, the
59 |
Licensor for the purpose of discussing and improving the Work, but
60 |
excluding communication that is conspicuously marked or otherwise
61 |
designated in writing by the copyright owner as "Not a Contribution."
62 |
63 |
"Contributor" shall mean Licensor and any individual or Legal Entity
64 |
on behalf of whom a Contribution has been received by Licensor and
65 |
subsequently incorporated within the Work.
66 |
67 |
2. Grant of Copyright License. Subject to the terms and conditions of
68 |
this License, each Contributor hereby grants to You a perpetual,
69 |
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 |
copyright license to reproduce, prepare Derivative Works of,
71 |
publicly display, publicly perform, sublicense, and distribute the
72 |
Work and such Derivative Works in Source or Object form.
73 |
74 |
3. Grant of Patent License. Subject to the terms and conditions of
75 |
this License, each Contributor hereby grants to You a perpetual,
76 |
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 |
(except as stated in this section) patent license to make, have made,
78 |
use, offer to sell, sell, import, and otherwise transfer the Work,
79 |
where such license applies only to those patent claims licensable
80 |
by such Contributor that are necessarily infringed by their
81 |
Contribution(s) alone or by combination of their Contribution(s)
82 |
with the Work to which such Contribution(s) was submitted. If You
83 |
institute patent litigation against any entity (including a
84 |
cross-claim or counterclaim in a lawsuit) alleging that the Work
85 |
or a Contribution incorporated within the Work constitutes direct
86 |
or contributory patent infringement, then any patent licenses
87 |
granted to You under this License for that Work shall terminate
88 |
as of the date such litigation is filed.
89 |
90 |
4. Redistribution. You may reproduce and distribute copies of the
91 |
Work or Derivative Works thereof in any medium, with or without
92 |
modifications, and in Source or Object form, provided that You
93 |
meet the following conditions:
94 |
95 |
(a) You must give any other recipients of the Work or
96 |
Derivative Works a copy of this License; and
97 |
98 |
(b) You must cause any modified files to carry prominent notices
99 |
stating that You changed the files; and
100 |
101 |
(c) You must retain, in the Source form of any Derivative Works
102 |
that You distribute, all copyright, patent, trademark, and
103 |
attribution notices from the Source form of the Work,
104 |
excluding those notices that do not pertain to any part of
105 |
the Derivative Works; and
106 |
107 |
(d) If the Work includes a "NOTICE" text file as part of its
108 |
distribution, then any Derivative Works that You distribute must
109 |
include a readable copy of the attribution notices contained
110 |
within such NOTICE file, excluding those notices that do not
111 |
pertain to any part of the Derivative Works, in at least one
112 |
of the following places: within a NOTICE text file distributed
113 |
as part of the Derivative Works; within the Source form or
114 |
documentation, if provided along with the Derivative Works; or,
115 |
within a display generated by the Derivative Works, if and
116 |
wherever such third-party notices normally appear. The contents
117 |
of the NOTICE file are for informational purposes only and
118 |
do not modify the License. You may add Your own attribution
119 |
notices within Derivative Works that You distribute, alongside
120 |
or as an addendum to the NOTICE text from the Work, provided
121 |
that such additional attribution notices cannot be construed
122 |
as modifying the License.
123 |
124 |
You may add Your own copyright statement to Your modifications and
125 |
may provide additional or different license terms and conditions
126 |
for use, reproduction, or distribution of Your modifications, or
127 |
for any such Derivative Works as a whole, provided Your use,
128 |
reproduction, and distribution of the Work otherwise complies with
129 |
the conditions stated in this License.
130 |
131 |
5. Submission of Contributions. Unless You explicitly state otherwise,
132 |
any Contribution intentionally submitted for inclusion in the Work
133 |
by You to the Licensor shall be under the terms and conditions of
134 |
this License, without any additional terms or conditions.
135 |
Notwithstanding the above, nothing herein shall supersede or modify
136 |
the terms of any separate license agreement you may have executed
137 |
with Licensor regarding such Contributions.
138 |
139 |
6. Trademarks. This License does not grant permission to use the trade
140 |
names, trademarks, service marks, or product names of the Licensor,
141 |
except as required for reasonable and customary use in describing the
142 |
origin of the Work and reproducing the content of the NOTICE file.
143 |
144 |
7. Disclaimer of Warranty. Unless required by applicable law or
145 |
agreed to in writing, Licensor provides the Work (and each
146 |
Contributor provides its Contributions) on an "AS IS" BASIS,
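147 |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 |
implied, including, without limitation, any warranties or conditions
149 |
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A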
150 |
PARTICULAR PURPOSE. You are solely responsible for determining the
151 |
appropriateness of using or redistributing the Work and assume any
152 |
risks associated with Your exercise of permissions under this License.
153 |
154 |
8. Limitation of Liability. In no event and under no legal theory,
155 |
whether in tort (including negligence), contract, or otherwise,
156 |
unless required by applicable law (such as deliberate and grossly
157 |
negligent acts) or agreed to in writing, shall any Contributor be
158 |
liable to You for damages, including any direct, indirect, special,
159 |
incidental, or consequential damages of any character arising as a
160 |
result of this License or out of the use or inability to use the
161 |
Work (including but not limited to damages for loss of goodwill,
162 |
work stoppage, computer failure or malfunction, or any and all
163 |
other commercial damages or losses), even if such Contributor
164 |
has been advised of the possibility of such damages.
165 |
166 |
9. Accepting Warranty or Additional Liability. While redistributing
167 |
the Work or Derivative Works thereof, You may choose to offer,
168 |
and charge a fee for, acceptance of support, warranty, indemnity,
169 |
or other liability obligations and/or rights consistent with this
170 |
License. However, in accepting such obligations, You may act only
171 |
on Your own behalf and on Your sole responsibility, not on behalf
172 |
of any other Contributor, and only if You agree to indemnify,
173 |
defend, and hold each Contributor harmless for any liability
174 |
incurred by, or claims asserted against, such Contributor by reason
175 |
of your accepting any such warranty or additional liability.
176 |
177 |
178 |
179 |
APPENDIX: How to apply the Apache License to your work.
180 |
181 |
To apply the Apache License to your work, attach the following
182 |
boilerplate notice, with the fields enclosed by brackets "[]"
183 |
replaced with your own identifying information. (Don't include
184 |
the brackets!) The text should be enclosed in the appropriate
185 |
comment syntax for the file format. We also recommend that a
186 |
file or class name and description of purpose be included on the
187 |
same "printed page" as the copyright notice for easier
188 |
identification within third-party archives.
189 |
190 |
Copyright 2017 Google Inc.
191 |
192 |
Licensed under the Apache License, Version 2.0 (the "License");
193 |
you may not use this file except in compliance with the License.
194 |
You may obtain a copy of the License at
195 |
196 |
197 |
198 |
Unless required by applicable law or agreed to in writing, software
199 |
distributed under the License is distributed on an "AS IS" BASIS,
200 |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 |
See the License for the specific language governing permissions and
202 |
limitations under the License.
@@ -0,0 +1,71 @@
1 |
CI ?=
2 |
3 |
# Directory variables
4 |
DING_DIR ?= ./ding
5 |
DIZOO_DIR ?= ./dizoo
6 |
7 |
8 |
9 |
10 |
PLATFORM_TEST_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR}/entry/tests/ ${DING_DIR}/entry/tests/
11 |
12 |
# Workers command
13 |
14 |
WORKERS_COMMAND := $(if ${WORKERS},-n ${WORKERS} --dist=loadscope,)
15 |
16 |
# Duration command
17 |
18 |
19 |
20 |
21 |
$(MAKE) -C ${DING_DIR}/docs html
22 |
23 |
24 |
pytest ${TEST_DIR} \
25 |
--cov-report=xml \
26 |
--cov-report term-missing \
27 |
--cov=${COV_DIR} \
28 |
29 |
30 |
-sv -m unittest \
31 |
32 |
33 |
pytest ${TEST_DIR} \
34 |
35 |
-sv -m algotest
36 |
37 |
38 |
pytest ${TEST_DIR} \
39 |
-sv -m cudatest
40 |
41 |
42 |
pytest ${TEST_DIR} \
43 |
-sv -m envpooltest
44 |
45 |
46 |
47 |
48 |
49 |
pytest ${TEST_DIR} \
50 |
--cov-report term-missing \
51 |
--cov=${COV_DIR} \
52 |
53 |
-sv -m platformtest
54 |
55 |
56 |
pytest ${TEST_DIR} \
57 |
--durations=0 \
58 |
-sv -m benchmark
59 |
60 |
test: unittest # just for compatibility, can be changed later
61 |
62 |
cpu_test: unittest algotest benchmark
63 |
64 |
all_test: unittest algotest cudatest benchmark
65 |
66 |
67 |
yapf --in-place --recursive -p --verbose --style .style.yapf ${FORMAT_DIR}
68 |
69 |
bash ${FORMAT_DIR} --test
70 |
71 |
flake8 ${FORMAT_DIR}
@@ -0,0 +1,475 @@
1 |
<div align="center">
2 |
<a href=""><img width="1000px" height="auto" src=""></a>
3 |
4 |
5 |
6 |
7 |
8 |
9 |

10 |

11 |

12 |

13 |
14 |

15 |

16 |
17 |

18 |
19 |
20 |

21 |

22 |

23 |
24 |
25 |
26 |
27 |

28 |
29 |
30 |

31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
Updated on 2023.12.05 DI-engine-v0.5.0
39 |
40 |
41 |
## Introduction to DI-engine
42 |
[Documentation]( | [Chinese documentation]( | [Tutorials]( | [Feature](#feature) | [Task & Middleware]( | [TreeTensor](#general-data-container-treetensor) | [Roadmap](
43 |
44 |
**DI-engine** is a generalized decision intelligence engine for PyTorch and JAX.
45 |
46 |
It provides **python-first** and **asynchronous-native** task and middleware abstractions, and modularly integrates several of the most important decision-making concepts: Env, Policy and Model. Based on the above mechanisms, DI-engine supports **various [deep reinforcement learning]( algorithms** with superior performance, high efficiency, well-organized [documentation]( and [unittest](
47 |
48 |
- Most basic DRL algorithms: such as DQN, Rainbow, PPO, TD3, SAC, R2D2, IMPALA
49 |
- Multi-agent RL algorithms: such as QMIX, WQMIX, MAPPO, HAPPO, ACE
50 |
- Imitation learning algorithms (BC/IRL/GAIL): such as GAIL, SQIL, Guided Cost Learning, Implicit BC
51 |
- Offline RL algorithms: BCQ, CQL, TD3BC, Decision Transformer, EDAC, Diffuser, Decision Diffuser, SO2
52 |
- Model-based RL algorithms: SVG, STEVE, MBPO, DDPPO, DreamerV3, MuZero
53 |
- Exploration algorithms: HER, RND, ICM, NGU
54 |
- LLM + RL Algorithms: PPO-max, DPO, MPDPO
55 |
- Other algorithms: such as PER, PLR, PCGrad
56 |
57 |
**DI-engine** aims to **standardize different Decision Intelligence environments and applications**, supporting both academic research and prototype applications. Various training pipelines and customized decision AI applications are also supported:
58 |
59 |
<details open>
60 |
<summary>(Click to Collapse)</summary>
61 |
62 |
- Traditional academic environments
63 |
- [DI-zoo]( various decision intelligence demonstrations and benchmark environments with DI-engine.
64 |
- Tutorial courses
65 |
- [PPOxFamily]( PPO x Family DRL Tutorial Course
66 |
- Real world decision AI applications
67 |
- [DI-star]( Decision AI in StarCraftII
68 |
- [DI-drive]( Auto-driving platform
69 |
- [DI-sheep]( Decision AI in 3 Tiles Game
70 |
- [DI-smartcross]( Decision AI in Traffic Light Control
71 |
- [DI-bioseq]( Decision AI in Biological Sequence Prediction and Searching
72 |
- [DI-1024]( Deep Reinforcement Learning + 1024 Game
73 |
- Research paper
74 |
- [InterFuser]( [CoRL 2022] Safety-Enhanced Autonomous Driving Using Interpretable Sensor Fusion Transformer
75 |
- [ACE]( [AAAI 2023] ACE: Cooperative Multi-agent Q-learning with Bidirectional Action-Dependency
76 |
- [GoBigger]( [ICLR 2023] Multi-Agent Decision Intelligence Environment
77 |
- [DOS]( [CVPR 2023] ReasonNet: End-to-End Driving with Temporal and Global Reasoning
78 |
- [LightZero]( [NeurIPS 2023 Spotlight] A lightweight and efficient MCTS/AlphaZero/MuZero algorithm toolkit
79 |
- [SO2]( [AAAI 2024] A Perspective of Q-value Estimation on Offline-to-Online Reinforcement Learning
80 |
- [LMDrive]( LMDrive: Closed-Loop End-to-End Driving with Large Language Models
81 |
- Docs and Tutorials
82 |
- [DI-engine-docs]( Tutorials, best practice and the API reference.
83 |
- [awesome-model-based-RL]( A curated list of awesome Model-Based RL resources
84 |
- [awesome-exploration-RL]( A curated list of awesome exploration RL resources
85 |
- [awesome-decision-transformer]( A curated list of Decision Transformer resources
86 |
- [awesome-RLHF]( A curated list of reinforcement learning with human feedback resources
87 |
- [awesome-multi-modal-reinforcement-learning]( A curated list of Multi-Modal Reinforcement Learning resources
88 |
- [awesome-AI-based-protein-design]( a collection of research papers for AI-based protein design
89 |
- [awesome-diffusion-model-in-rl]( A curated list of Diffusion Model in RL resources
90 |
- [awesome-end-to-end-autonomous-driving]( A curated list of awesome End-to-End Autonomous Driving resources
91 |
- [awesome-driving-behavior-prediction]( A collection of research papers for Driving Behavior Prediction
92 |
93 |
94 |
On the low-level end, DI-engine comes with a set of highly re-usable modules, including [RL optimization functions](, [PyTorch utilities]( and [auxiliary tools](
95 |
96 |
**DI-engine** also includes dedicated **system optimization and design** for efficient and robust large-scale RL training:
97 |
98 |
<details close>
99 |
<summary>(Click for Details)</summary>
100 |
101 |
- [treevalue]( Tree-nested data structure
102 |
- [DI-treetensor]( Tree-nested PyTorch tensor Lib
103 |
- [DI-toolkit]( A simple toolkit package for decision intelligence
104 |
- [DI-orchestrator]( RL Kubernetes Custom Resource and Operator Lib
105 |
- [DI-hpc]( RL HPC OP Lib
106 |
- [DI-store]( RL Object Store
107 |
108 |
109 |
Have fun with exploration and exploitation.
110 |
111 |
## Outline
112 |
113 |
- [Introduction to DI-engine](#introduction-to-di-engine)
114 |
- [Outline](#outline)
115 |
- [Installation](#installation)
116 |
- [Quick Start](#quick-start)
117 |
- [Feature](#feature)
118 |
- [Algorithm Versatility](#algorithm-versatility)
119 |
- [Environment Versatility](#environment-versatility)
120 |
- [General Data Container: TreeTensor](#general-data-container-treetensor)
121 |
- [Feedback and Contribution](#feedback-and-contribution)
122 |
- [Supporters](#supporters)
123 |
- [↳ Stargazers](#-stargazers)
124 |
- [↳ Forkers](#-forkers)
125 |
- [Citation](#citation)
126 |
- [License](#license)
127 |
128 |
## Installation
129 |
130 |
You can simply install DI-engine from PyPI with the following command:
131 |
132 |
pip install DI-engine
133 |
134 |
135 |
If you use Anaconda or Miniconda, you can install DI-engine with conda from the `opendilab` channel through the following command:
136 |
137 |
conda install -c opendilab di-engine
138 |
139 |
140 |
For more information about installation, you can refer to [installation](
141 |
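After either install command, it can be useful to confirm that the package is visible to the interpreter you intend to use. The following is a minimal sketch that relies only on the Python standard library; it assumes the distribution is registered under the name `DI-engine` (the name used in the pip command above), which may not hold for every installation method.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Assumption: the distribution name matches the one used with pip.
    version = installed_version("DI-engine")
    if version is None:
        print("DI-engine is not installed in this environment")
    else:
        print(f"DI-engine {version} is installed")
```

Returning `None` instead of raising keeps the check usable in setup scripts that want to decide whether an install step is still needed.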
142 |
Our Docker Hub repository can be found [here](. We provide a `base image` and `env image` with common RL environments.
143 |
144 |
<details close>
145 |
<summary>(Click for Details)</summary>
146 |
147 |
- base: opendilab/ding:nightly
148 |
- rpc: opendilab/ding:nightly-rpc
149 |
- atari: opendilab/ding:nightly-atari
150 |
- mujoco: opendilab/ding:nightly-mujoco
151 |
- dmc: opendilab/ding:nightly-dmc2gym
152 |
- metaworld: opendilab/ding:nightly-metaworld
153 |
- smac: opendilab/ding:nightly-smac
154 |
- grf: opendilab/ding:nightly-grf
155 |
- cityflow: opendilab/ding:nightly-cityflow
156 |
- evogym: opendilab/ding:nightly-evogym
157 |
- d4rl: opendilab/ding:nightly-d4rl
158 |
159 |
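The image list above follows a regular `opendilab/ding:<tag>` pattern, so a small lookup table is enough to turn an environment name into a pull command. The sketch below is illustrative only: the helper names `image_for` and `pull_command` are hypothetical and not part of DI-engine, and the tags are copied from the list above.

```python
from typing import Dict, List

IMAGE_REPO = "opendilab/ding"

# Environment name -> image tag, as listed in this README.
ENV_TAGS: Dict[str, str] = {
    "base": "nightly",
    "rpc": "nightly-rpc",
    "atari": "nightly-atari",
    "mujoco": "nightly-mujoco",
    "dmc": "nightly-dmc2gym",
    "metaworld": "nightly-metaworld",
    "smac": "nightly-smac",
    "grf": "nightly-grf",
    "cityflow": "nightly-cityflow",
    "evogym": "nightly-evogym",
    "d4rl": "nightly-d4rl",
}


def image_for(env: str) -> str:
    """Return the full image reference for a listed environment name."""
    if env not in ENV_TAGS:
        raise ValueError(f"unknown environment: {env!r}")
    return f"{IMAGE_REPO}:{ENV_TAGS[env]}"


def pull_command(env: str) -> List[str]:
    """Build the argument list for `docker pull` (hypothetical helper)."""
    return ["docker", "pull", image_for(env)]


if __name__ == "__main__":
    print(" ".join(pull_command("atari")))
```

Note that the `dmc` entry maps to the `nightly-dmc2gym` tag, so the tag cannot always be derived from the environment name alone.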
160 |
The detailed documentation is hosted on [doc]( | [Chinese documentation](
161 |
162 |
## Quick Start
163 |
164 |
[3 Minutes Kickoff](
165 |
166 |
[3 Minutes Kickoff (colab)](
167 |
168 |
[DI-engine Huggingface Kickoff (colab)](
169 |
170 |
[How to migrate a new **RL Env**]( | [How to migrate a new **RL environment** (Chinese)](
171 |
172 |
[How to customize the neural network model]( | [How to customize the **neural network model** used by a policy (Chinese)](
173 |
174 |
[Examples of testing/deploying an **RL policy** (Chinese)](
175 |
176 |
[Comparison of the new and old pipelines (Chinese)](
177 |
178 |
179 |
## Feature
180 |
### Algorithm Versatility
181 |
182 |
<details open>
183 |
<summary>(Click to Collapse)</summary>
184 |
185 |
 discrete means discrete action space, a label used only for the basic DRL algorithms (1-23)
186 |
187 |
 means continuous action space, a label used only for the basic DRL algorithms (1-23)
188 |
189 |
 means hybrid (discrete + continuous) action space (1-23)
190 |
191 |
 [Distributed Reinforcement Learning](|[Distributed Reinforcement Learning (Chinese)](
192 |
193 |
 [Multi-Agent Reinforcement Learning](|[Multi-Agent Reinforcement Learning (Chinese)](
194 |
195 |
 [Exploration Mechanisms in Reinforcement Learning](|[Exploration Mechanisms in Reinforcement Learning (Chinese)](
196 |
197 |
 [Imitation Learning](|[Imitation Learning (Chinese)](
198 |
199 |
 [Offline Reinforcement Learning](|[Offline Reinforcement Learning (Chinese)](
200 |
201 |
202 |
 [Model-Based Reinforcement Learning](|[Model-Based Reinforcement Learning (Chinese)](
203 |
204 |
 means other sub-direction algorithms, usually used as plug-ins in the whole pipeline
205 |
206 |
P.S.: The `.py` files listed under `Runnable Demo` can be found in `dizoo`.
207 |
208 |
209 |
210 |
| No. | Algorithm | Label | Doc and Implementation | Runnable Demo |
211 |
| :--: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
212 |
| 1 | [DQN]( |  | [DQN doc](<br>[DQN Chinese doc](<br>[policy/dqn]( | python3 -u / ding -m serial -c -s 0 |
213 |
| 2 | [C51]( |  | [C51 doc](<br>[policy/c51]( | ding -m serial -c -s 0 |
214 |
| 3 | [QRDQN]( |  | [QRDQN doc](<br>[policy/qrdqn]( | ding -m serial -c -s 0 |
215 |
| 4 | [IQN]( |  | [IQN doc](<br>[policy/iqn]( | ding -m serial -c -s 0 |
216 |
| 5 | [FQF]( |  | [FQF doc](<br>[policy/fqf]( | ding -m serial -c -s 0 |
217 |
| 6 | [Rainbow]( |  | [Rainbow doc](<br>[policy/rainbow]( | ding -m serial -c -s 0 |
218 |
| 7 | [SQL]( |  | [SQL doc](<br>[policy/sql]( | ding -m serial -c -s 0 |
219 |
| 8 | [R2D2]( |  | [R2D2 doc](<br>[policy/r2d2]( | ding -m serial -c -s 0 |
220 |
| 9 | [PG]( |  | [PG doc](<br>[policy/pg]( | ding -m serial -c -s 0 |
221 |
| 10 | [PromptPG]( |  | [policy/prompt_pg]( | ding -m serial_onpolicy -c -s 0 |
222 |
| 11 | [A2C]( |  | [A2C doc](<br>[policy/a2c]( | ding -m serial -c -s 0 |
223 |
| 12 | [PPO]([MAPPO]( |  | [PPO doc](<br>[policy/ppo]( | python3 -u / ding -m serial_onpolicy -c -s 0 |
224 |
| 13 | [PPG]( |  | [PPG doc](<br>[policy/ppg]( | python3 -u |
225 |
| 14 | [ACER]( |  | [ACER doc](<br>[policy/acer]( | ding -m serial -c -s 0 |
226 |
| 15 | [IMPALA]( |  | [IMPALA doc](<br>[policy/impala]( | ding -m serial -c -s 0 |
227 |
| 16 | [DDPG]([PADDPG]( |  | [DDPG doc](<br>[policy/ddpg]( | ding -m serial -c -s 0 |
228 |
| 17 | [TD3]( |  | [TD3 doc](<br>[policy/td3]( | python3 -u / ding -m serial -c -s 0 |
229 |
| 18 | [D4PG]( |  | [D4PG doc](<br>[policy/d4pg]( | python3 -u |
230 |
| 19 | [SAC]([MASAC] |  | [SAC doc](<br>[policy/sac]( | ding -m serial -c -s 0 |
231 |
| 20 | [PDQN]( |  | [policy/pdqn]( | ding -m serial -c -s 0 |
232 |
| 21 | [MPDQN]( |  | [policy/pdqn]( | ding -m serial -c -s 0 |
233 |
| 22 | [HPPO]( |  | [policy/ppo]( | ding -m serial_onpolicy -c -s 0 |
234 |
| 23 | [BDQ]( |  | [policy/bdq]( | python3 -u |
235 |
| 24 | [MDQN]( |  | [policy/mdqn]( | python3 -u |
236 |
| 25 | [QMIX]( |  | [QMIX doc](<br>[policy/qmix]( | ding -m serial -c -s 0 |
237 |
| 26 | [COMA]( |  | [COMA doc](<br>[policy/coma]( | ding -m serial -c -s 0 |
238 |
| 27 | [QTran]( |  | [policy/qtran]( | ding -m serial -c -s 0 |
239 |
| 28 | [WQMIX]( |  | [WQMIX doc](<br>[policy/wqmix]( | ding -m serial -c -s 0 |
240 |
| 29 | [CollaQ]( |  | [CollaQ doc](<br>[policy/collaq]( | ding -m serial -c -s 0 |
241 |
| 30 | [MADDPG]( |  | [MADDPG doc](<br>[policy/ddpg]( | ding -m serial -c -s 0 |
242 |
| 31 | [GAIL]( |  | [GAIL doc](<br>[reward_model/gail]( | ding -m serial_gail -c -s 0 |
243 |
| 32 | [SQIL]( |  | [SQIL doc](<br>[entry/sqil]( | ding -m serial_sqil -c -s 0 |
244 |
| 33 | [DQFD]( |  | [DQFD doc](<br>[policy/dqfd]( | ding -m serial_dqfd -c -s 0 |
245 |
| 34 | [R2D3]( |  | [R2D3 doc](<br>[R2D3 Chinese doc](<br>[policy/r2d3]( | python3 -u |
246 |
| 35 | [Guided Cost Learning]( |  | [Guided Cost Learning Chinese doc](<br>[reward_model/guided_cost]( | python3 |
247 |
| 36 | [TREX]( |  | [TREX doc](<br>[reward_model/trex]( | python3 |
248 |
| 37 | [Implicit Behavioral Cloning]( (DFO+MCMC) |  | [policy/ibc]( <br> [model/template/ebm]( | python3 -s 0 -c |
249 |
| 38 | [BCO]( |  | [entry/bco]( | python3 -u |
250 |
| 39 | [HER]( |  | [HER doc](<br>[reward_model/her]( | python3 -u |
251 |
| 40 | [RND]( |  | [RND doc](<br>[reward_model/rnd]( | python3 -u |
252 |
| 41 | [ICM]( |  | [ICM doc](<br>[ICM Chinese doc](<br>[reward_model/icm]( | python3 -u |
253 |
| 42 | [CQL]( |  | [CQL doc](<br>[policy/cql]( | python3 -u |
254 |
| 43 | [TD3BC]( |  | [TD3BC doc](<br>[policy/td3_bc]( | python3 -u |
255 |
| 44 | [Decision Transformer]( |  | [policy/dt]( | python3 -u |
256 |
| 45 | [EDAC]( |  | [EDAC doc](<br>[policy/edac]( | python3 -u |
257 |
| 46 | MBSAC([SAC]([MVE]([SVG]( |  | [policy/mbpolicy/mbsac]( | python3 -u \ python3 -u |
258 |
| 47 | STEVESAC([SAC]([STEVE]([SVG]( |  | [policy/mbpolicy/mbsac]( | python3 -u |
259 |
| 48 | [MBPO]( |  | [MBPO doc](<br>[world_model/mbpo]( | python3 -u |
260 |
| 49 | [DDPPO]( |  | [world_model/ddppo]( | python3 -u |
261 |
| 50 | [DreamerV3]( |  | [world_model/dreamerv3]( | python3 -u |
262 |
| 51 | [PER]( |  | [worker/replay_buffer]( | `rainbow demo` |
263 |
| 52 | [GAE]( |  | [rl_utils/gae]( | `ppo demo` |
264 |
| 53 | [ST-DIM]( |  | [torch_utils/loss/contrastive_loss]( | ding -m serial -c -s 0 |
265 |
| 54 | [PLR]( |  | [PLR doc](<br>[data/level_replay/level_sampler]( | python3 -u -s 0 |
266 |
| 55 | [PCGrad]( |  | [torch_utils/optimizer_helper/PCGrad]( | python3 -u -s 0 |
267 |
268 |
269 |
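Many entries in the `Runnable Demo` column share the shape `ding -m <mode> -c <config> -s <seed>`. As a rough sketch, assuming `-m` selects the pipeline mode, `-c` the config file and `-s` the random seed, the command can be assembled as an argument list before being handed to a process runner. The helper `ding_command` and the file name `my_dqn_config.py` are hypothetical placeholders, not names taken from the repository.

```python
import shlex
from typing import List


def ding_command(mode: str, config: str, seed: int = 0) -> List[str]:
    """Build the argument list for a `ding` CLI run (hypothetical helper)."""
    if not mode:
        raise ValueError("mode must be non-empty, e.g. 'serial'")
    if not config:
        raise ValueError("config must be the path to a config file")
    return ["ding", "-m", mode, "-c", config, "-s", str(seed)]


if __name__ == "__main__":
    # Placeholder config name; substitute a real file from `dizoo`.
    cmd = ding_command("serial", "my_dqn_config.py", seed=0)
    print(shlex.join(cmd))
```

Keeping the command as a list avoids shell quoting problems if it is later passed to `subprocess.run`; `shlex.join` is used here only to display it.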
270 |
### Environment Versatility
271 |
<details open>
272 |
<summary>(Click to Collapse)</summary>
273 |
274 |
| No | Environment | Label | Visualization | Code and Doc Links |
| :--: | :--------------------------------------: | :---------------------------------: | :--------------------------------:|:---------------------------------------------------------: |
| 1 | [Atari]( |  |  | [dizoo link]( <br>[env tutorial](<br>[env guide (Chinese)]( |
| 2 | [box2d/bipedalwalker]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 3 | [box2d/lunarlander]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 4 | [classic_control/cartpole]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 5 | [classic_control/pendulum]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 6 | [competitive_rl]( |  |  | [dizoo link](<br>[env guide (Chinese)]( |
| 7 | [gfootball]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 8 | [minigrid]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 9 | [MuJoCo]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 10 | [PettingZoo]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 11 | [overcooked]( |  |  | [dizoo link](<br>[env tutorial]( |
| 12 | [procgen]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 13 | [pybullet]( |  |  | [dizoo link](<br>[env guide (Chinese)]( |
| 14 | [smac]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 15 | [d4rl]( |  |  | [dizoo link](<br>[env guide (Chinese)]( |
| 16 | league_demo |  |  | [dizoo link]( |
| 17 | pomdp atari |  |  | [dizoo link]( |
| 18 | [bsuite]( |  |  | [dizoo link](<br>[env tutorial]( <br> [env guide (Chinese)]( |
| 19 | [ImageNet]( |  |  | [dizoo link](<br>[env guide (Chinese)]( |
| 20 | [slime_volleyball]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 21 | [gym_hybrid]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 22 | [GoBigger]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 23 | [gym_soccer]( |  |  | [dizoo link](<br>[env guide (Chinese)]( |
| 24 | [multiagent_mujoco]( |  |  | [dizoo link](<br>[env guide (Chinese)]( |
| 25 | bitflip |  |  | [dizoo link](<br>[env guide (Chinese)]( |
| 26 | [sokoban]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 27 | [gym_anytrading]( |  |  | [dizoo link]( <br> [env tutorial]( |
| 28 | [mario]( |  |  | [dizoo link]( <br> [env tutorial]( <br>[env guide (Chinese)]( |
| 29 | [dmc2gym]( |  |  | [dizoo link](<br>[env tutorial](<br>[env guide (Chinese)]( |
| 30 | [evogym]( |  |  | [dizoo link]( <br> [env tutorial]( <br> [env guide (Chinese)]( |
| 31 | [gym-pybullet-drones]( |  |  | [dizoo link](<br>env guide (Chinese) |
| 32 | [beergame]( |  |  | [dizoo link](<br>env guide (Chinese) |
| 33 | [classic_control/acrobot]( |  |  | [dizoo link](<br> [env guide (Chinese)]( |
| 34 | [box2d/car_racing]( |  <br>  |  | [dizoo link](<br>env guide (Chinese) |
| 35 | [metadrive]( |  |  | [dizoo link](<br> [env guide (Chinese)]( |
| 36 | [cliffwalking]( |  |  | [dizoo link](<br> env tutorial <br> env guide (Chinese) |
| 37 | [tabmwp]( |  |  | [dizoo link]( <br> env tutorial <br> env guide (Chinese) |
313 |
314 |
 means discrete action space
315 |
316 |
 means continuous action space
317 |
318 |
 means hybrid (discrete + continuous) action space
319 |
320 |
 means multi-agent RL environment
321 |
322 |
 means environment which is related to exploration and sparse reward
323 |
324 |
 means offline RL environment
325 |
326 |
 means Imitation Learning or Supervised Learning Dataset
327 |
328 |
 means environment that allows agent VS agent battle
329 |
330 |
P.S. Some environments in Atari, such as **MontezumaRevenge**, are also of the sparse reward type.
331 |
332 |
333 |
334 |
### General Data Container: TreeTensor
335 |
336 |
DI-engine utilizes [TreeTensor]( as the basic data container in various components. It is easy to use and consistent across different code modules such as environment definition, data processing and DRL optimization. Here are some concrete code examples:
337 |
338 |
- TreeTensor can easily extend all the operations of `torch.Tensor` to nested data:
339 |
<details close>
340 |
<summary>(Click for Details)</summary>
341 |
342 |
343 |
import treetensor.torch as ttorch
344 |
345 |
346 |
# create random tensor
347 |
data = ttorch.randn({'a': (3, 2), 'b': {'c': (3, )}})
348 |
# clone+detach tensor
349 |
data_clone = data.clone().detach()
350 |
# access tree structure like attribute
351 |
a = data.a
352 |
c = data.b.c
353 |
# stack/cat/split
354 |
stacked_data = ttorch.stack([data, data_clone], 0)
355 |
cat_data = ttorch.cat([data, data_clone], 0)
356 |
data, data_clone = ttorch.split(stacked_data, 1)
357 |
# reshape
358 |
data = data.unsqueeze(-1)
359 |
data = data.squeeze(-1)
360 |
flatten_data = data.view(-1)
361 |
# indexing
362 |
data_0 = data[0]
363 |
data_1to2 = data[1:2]
364 |
# execute math calculations
365 |
data = data.sin()
366 |
data.b.c.cos_().clamp_(-1, 1)
367 |
data += data ** 2
368 |
# backward
369 |
370 |
loss = data.arctan().mean()
371 |
372 |
# print shape
373 |
374 |
# result
375 |
# <Size 0x7fbd3346ddc0>
376 |
# ├── 'a' --> torch.Size([1, 3, 2])
377 |
# └── 'b' --> <Size 0x7fbd3346dd00>
378 |
# └── 'c' --> torch.Size([1, 3])
379 |
380 |
381 |
382 |
383 |
- TreeTensor can make it simple yet effective to implement a classic deep reinforcement learning pipeline:
384 |
<details close>
385 |
<summary>(Click for Details)</summary>
386 |
387 |
388 |
import torch
389 |
import treetensor.torch as ttorch
390 |
391 |
B = 4
392 |
393 |
394 |
def get_item():
395 |
return {
396 |
'obs': {
397 |
'scalar': torch.randn(12),
398 |
'image': torch.randn(3, 32, 32),
399 |
400 |
'action': torch.randint(0, 10, size=(1,)),
401 |
'reward': torch.rand(1),
402 |
'done': False,
403 |
404 |
405 |
406 |
data = [get_item() for _ in range(B)]
407 |
408 |
409 |
# execute `stack` op
410 |
- def stack(data, dim):
411 |
- elem = data[0]
412 |
- if isinstance(elem, torch.Tensor):
413 |
- return torch.stack(data, dim)
414 |
- elif isinstance(elem, dict):
415 |
- return {k: stack([item[k] for item in data], dim) for k in elem.keys()}
416 |
- elif isinstance(elem, bool):
417 |
- return torch.BoolTensor(data)
418 |
- else:
419 |
- raise TypeError("not support elem type: {}".format(type(elem)))
420 |
- stacked_data = stack(data, dim=0)
421 |
+ data = [ttorch.tensor(d) for d in data]
422 |
+ stacked_data = ttorch.stack(data, dim=0)
423 |
424 |
# validate
425 |
- assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
426 |
- assert stacked_data['action'].shape == (B, 1)
427 |
- assert stacked_data['reward'].shape == (B, 1)
428 |
- assert stacked_data['done'].shape == (B,)
429 |
- assert stacked_data['done'].dtype == torch.bool
430 |
+ assert stacked_data.obs.image.shape == (B, 3, 32, 32)
431 |
+ assert stacked_data.action.shape == (B, 1)
432 |
+ assert stacked_data.reward.shape == (B, 1)
433 |
+ assert stacked_data.done.shape == (B,)
434 |
+ assert stacked_data.done.dtype == torch.bool
435 |
436 |
437 |
438 |
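The idea behind both examples, applying one operation to every leaf of a nested structure, can be sketched in plain Python. This is only an illustration under stated assumptions: `tree_map` and `tree_stack` are hypothetical helper names, the leaves are plain numbers instead of tensors, and nothing here reflects how TreeTensor is actually implemented.

```python
from typing import Any, Callable, List


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every leaf of a nested dict while keeping its structure."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    return fn(tree)


def tree_stack(trees: List[Any]) -> Any:
    """Merge identically structured trees into one tree whose leaves are lists."""
    first = trees[0]
    if isinstance(first, dict):
        return {key: tree_stack([t[key] for t in trees]) for key in first}
    return list(trees)


# Hypothetical sample with the same nesting style as the examples above.
sample = {'obs': {'scalar': 1.0, 'image': 2.0}, 'reward': 0.5}
doubled = tree_map(lambda x: x * 2, sample)
batch = tree_stack([sample, doubled])
```

With a tree container, the recursion above is hidden behind a single call, which is why the `stack` helper in the second example can be replaced by one line.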
439 |
## Feedback and Contribution
440 |
441 |
- [File an issue]( on GitHub
442 |
- Open or participate in our [forum](
443 |
- Discuss on DI-engine [slack communication channel](
444 |
- Discuss on DI-engine's WeChat group (i.e. add us on WeChat: ding314assist)
445 |
446 |
<img src= width=35% />
447 |
- Contact our email ([email protected])
448 |
- Contribute to our future plan [Roadmap](
449 |
450 |
We appreciate all feedback and contributions that improve DI-engine, in both algorithms and system design. And `` offers some necessary information.
451 |
452 |
## Supporters
453 |
454 |
463 |
## License
475 |
DI-engine is released under the Apache 2.0 license.
@@ -0,0 +1,69 @@
1 |
2 |
3 |
# This script counts the lines of code and comments in all source files
4 |
# and prints the results to the command line. It uses the commandline tool
5 |
# "cloc". You can either pass --loc, --comments or --percentage to show the
6 |
# respective values only.
7 |
# Some parts below need to be adapted to your project!
8 |
9 |
# Get the location of this script.
10 |
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
11 |
12 |
# Run cloc - this counts code lines, blank lines and comment lines
13 |
# for the specified languages. You will need to change this accordingly.
14 |
# For C++, you could use "C++,C/C++ Header" for example.
15 |
# We are only interested in the summary, therefore the tail -1
16 |
SUMMARY="$(cloc "${SCRIPT_DIR}" --include-lang="Python" --md | tail -1)"
17 |
18 |
# The $SUMMARY is one line of a markdown table and looks like this:
19 |
# SUM:|101|3123|2238|10783
20 |
# We use the following command to split it into an array.
21 |
IFS='|' read -r -a TOKENS <<< "$SUMMARY"
22 |
23 |
# Store the individual tokens for better readability.
24 |
25 |
26 |
27 |
28 |
# To make the estimate of commented lines more accurate, we have to
29 |
# subtract any copyright header which is included in each file.
30 |
# For Fly-Pie, this header has the length of five lines.
31 |
# All dumb comments like those /////////// or those // ------------
32 |
# are also subtracted. As cloc does not count inline comments,
33 |
# the overall estimate should be rather conservative.
34 |
# Change the lines below according to your project.
35 |
DUMB_COMMENTS="$(grep -r -E '//////|// -----' "${SCRIPT_DIR}" | wc -l)"
36 |
37 |
38 |
# Print all results if no arguments are given.
39 |
if [[ $# -eq 0 ]] ; then
40 |
awk -v a=$LINES_OF_CODE \
41 |
'BEGIN {printf "Lines of source code: %6.1fk\n", a/1000}'
42 |
awk -v a=$COMMENT_LINES \
43 |
'BEGIN {printf "Lines of comments: %6.1fk\n", a/1000}'
44 |
45 |
'BEGIN {printf "Comment Percentage: %6.1f%\n", 100*a/b}'
46 |
exit 0
47 |
48 |
49 |
# Show lines of code if --loc is given.
50 |
if [[ $* == *--loc* ]]
51 |
52 |
awk -v a=$LINES_OF_CODE \
53 |
'BEGIN {printf "%.1fk\n", a/1000}'
54 |
55 |
56 |
# Show lines of comments if --comments is given.
57 |
if [[ $* == *--comments* ]]
58 |
59 |
awk -v a=$COMMENT_LINES \
60 |
'BEGIN {printf "%.1fk\n", a/1000}'
61 |
62 |
63 |
# Show percentage of comments if --percentage is given.
64 |
if [[ $* == *--percentage* ]]
65 |
66 |
67 |
'BEGIN {printf "%.1f\n", 100*a/b}'
68 |
69 |
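The arithmetic described in the script comments can be sketched in Python for clarity. This is a hedged illustration, not part of the script: the function names are hypothetical, the column order (label, files, blank, comment, code) is an assumption about the `cloc --md` summary row, and the percentage is assumed to be adjusted comment lines divided by code lines.

```python
def parse_cloc_summary(summary: str) -> dict:
    """Split a summary row such as 'SUM:|101|3123|2238|10783' into counts."""
    tokens = summary.strip().split('|')
    # Assumed column order: label, files, blank, comment, code.
    return {
        'files': int(tokens[1]),
        'blank': int(tokens[2]),
        'comment': int(tokens[3]),
        'code': int(tokens[4]),
    }


def comment_percentage(summary: str, header_lines: int = 0, dumb_comments: int = 0) -> float:
    """Percentage of comment lines after removing per-file headers and filler comments."""
    counts = parse_cloc_summary(summary)
    comments = counts['comment'] - header_lines * counts['files'] - dumb_comments
    return 100.0 * comments / counts['code']


stats = parse_cloc_summary('SUM:|101|3123|2238|10783')
pct = comment_percentage('SUM:|101|3123|2238|10783', header_lines=5)
```

Subtracting the header lines once per file mirrors the copyright-header correction mentioned in the comments above.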
@@ -0,0 +1,8 @@
1 |
2 |
3 |
4 |
5 |
# basic
6 |
target: auto
7 |
threshold: 0.5%
8 |
if_ci_failed: success #success, failure, error, ignore
@@ -0,0 +1,2 @@
1 |
2 |
- 3.7
@@ -0,0 +1,35 @@
1 |
{% set data = load_setup_py_data() %}
2 |
3 |
name: di-engine
4 |
version: v0.5.0
5 |
6 |
7 |
path: ..
8 |
9 |
10 |
number: 0
11 |
script: python -m pip install . -vv
12 |
13 |
- ding = ding.entry.cli:cli
14 |
15 |
16 |
17 |
- python
18 |
- setuptools
19 |
20 |
- python
21 |
22 |
23 |
24 |
- ding
25 |
- dizoo
26 |
27 |
28 |
29 |
license: Apache-2.0
30 |
license_file: LICENSE
31 |
summary: DI-engine is a generalized Decision Intelligence engine (
32 |
description: Please refer to
33 |
34 |
35 |
@@ -0,0 +1,12 @@
1 |
import os
2 |
3 |
__TITLE__ = 'DI-engine'
4 |
__VERSION__ = 'v0.5.0'
5 |
__DESCRIPTION__ = 'Decision AI Engine'
6 |
__AUTHOR__ = "OpenDILab Contributors"
7 |
__AUTHOR_EMAIL__ = "[email protected]"
8 |
__version__ = __VERSION__
9 |
10 |
enable_hpc_rl = os.environ.get('ENABLE_DI_HPC', 'false').lower() == 'true'
11 |
enable_linklink = os.environ.get('ENABLE_LINKLINK', 'false').lower() == 'true'
12 |
enable_numba = True
@@ -0,0 +1,132 @@
1 |
import ding.config
2 |
from .a2c import A2CAgent
3 |
from .c51 import C51Agent
4 |
from .ddpg import DDPGAgent
5 |
from .dqn import DQNAgent
6 |
from .pg import PGAgent
7 |
from .ppof import PPOF
8 |
from .ppo_offpolicy import PPOOffPolicyAgent
9 |
from .sac import SACAgent
10 |
from .sql import SQLAgent
11 |
from .td3 import TD3Agent
12 |
13 |
supported_algo = dict(
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
supported_algo_list = list(supported_algo.keys())
27 |
28 |
29 |
def env_supported(algo: str = None) -> list:
30 |
31 |
Return the list of envs that are supported by DI-engine.
32 |
33 |
34 |
if algo is not None:
35 |
if algo.upper() == "A2C":
36 |
return list(ding.config.example.A2C.supported_env.keys())
37 |
elif algo.upper() == "C51":
38 |
return list(ding.config.example.C51.supported_env.keys())
39 |
elif algo.upper() == "DDPG":
40 |
return list(ding.config.example.DDPG.supported_env.keys())
41 |
elif algo.upper() == "DQN":
42 |
return list(ding.config.example.DQN.supported_env.keys())
43 |
elif algo.upper() == "PG":
44 |
return list(ding.config.example.PG.supported_env.keys())
45 |
elif algo.upper() == "PPOF":
46 |
return list(ding.config.example.PPOF.supported_env.keys())
47 |
elif algo.upper() == "PPOOFFPOLICY":
48 |
return list(ding.config.example.PPOOffPolicy.supported_env.keys())
49 |
elif algo.upper() == "SAC":
50 |
return list(ding.config.example.SAC.supported_env.keys())
51 |
elif algo.upper() == "SQL":
52 |
return list(ding.config.example.SQL.supported_env.keys())
53 |
elif algo.upper() == "TD3":
54 |
return list(ding.config.example.TD3.supported_env.keys())
55 |
56 |
raise ValueError("The algo {} is not supported by di-engine.".format(algo))
57 |
58 |
supported_env = set()
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
# return the list of the envs
70 |
return list(supported_env)
71 |
72 |
73 |
supported_env = env_supported()
74 |
75 |
76 |
def algo_supported(env_id: str = None) -> list:
77 |
78 |
Return the list of algos that are supported by DI-engine.
79 |
80 |
if env_id is not None:
81 |
algo = []
82 |
if env_id.upper() in [item.upper() for item in ding.config.example.A2C.supported_env.keys()]:
83 |
84 |
if env_id.upper() in [item.upper() for item in ding.config.example.C51.supported_env.keys()]:
85 |
86 |
if env_id.upper() in [item.upper() for item in ding.config.example.DDPG.supported_env.keys()]:
87 |
88 |
if env_id.upper() in [item.upper() for item in ding.config.example.DQN.supported_env.keys()]:
89 |
90 |
if env_id.upper() in [item.upper() for item in ding.config.example.PG.supported_env.keys()]:
91 |
92 |
if env_id.upper() in [item.upper() for item in ding.config.example.PPOF.supported_env.keys()]:
93 |
94 |
if env_id.upper() in [item.upper() for item in ding.config.example.PPOOffPolicy.supported_env.keys()]:
95 |
96 |
if env_id.upper() in [item.upper() for item in ding.config.example.SAC.supported_env.keys()]:
97 |
98 |
if env_id.upper() in [item.upper() for item in ding.config.example.SQL.supported_env.keys()]:
99 |
100 |
if env_id.upper() in [item.upper() for item in ding.config.example.TD3.supported_env.keys()]:
101 |
102 |
103 |
if len(algo) == 0:
104 |
raise ValueError("The env {} is not supported by di-engine.".format(env_id))
105 |
return algo
106 |
107 |
return supported_algo_list
108 |
109 |
110 |
def is_supported(env_id: str = None, algo: str = None) -> bool:
111 |
112 |
Check if the env-algo pair is supported by di-engine.
113 |
114 |
if env_id is not None and env_id.upper() in [item.upper() for item in supported_env]:
115 |
if algo is not None and algo.upper() in supported_algo_list:
116 |
if env_id.upper() in [item.upper() for item in env_supported(algo)]:
117 |
return True
118 |
119 |
return False
120 |
elif algo is None:
121 |
return True
122 |
123 |
return False
124 |
elif env_id is None:
125 |
if algo is not None and algo.upper() in supported_algo_list:
126 |
return True
127 |
elif algo is None:
128 |
raise ValueError("Please specify the env or algo.")
129 |
130 |
return False
131 |
132 |
return False
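The lookup logic above can also be expressed as a table-driven sketch. This is a minimal illustration under assumptions: the `SUPPORTED` registry and its contents are hypothetical stand-ins for the per-algorithm `supported_env` dictionaries, and the functions below are simplified versions, not the module's actual code.

```python
from typing import Dict, List, Optional

# Hypothetical registry: upper-case algorithm name -> supported environment ids.
SUPPORTED: Dict[str, List[str]] = {
    'A2C': ['LunarLanderContinuous-v2'],
    'DQN': ['LunarLander-v2'],
}


def env_supported(algo: Optional[str] = None) -> List[str]:
    """Return envs for one algorithm, or all known envs when algo is None."""
    if algo is not None:
        key = algo.upper()
        if key not in SUPPORTED:
            raise ValueError("The algo {} is not supported.".format(algo))
        return list(SUPPORTED[key])
    envs: List[str] = []
    for env_list in SUPPORTED.values():
        for env_id in env_list:
            if env_id not in envs:
                envs.append(env_id)
    return envs


def is_supported(env_id: str, algo: str) -> bool:
    """Case-insensitive check that the env-algo pair is registered."""
    try:
        candidates = env_supported(algo)
    except ValueError:
        return False
    return env_id.upper() in [item.upper() for item in candidates]
```

Comparing upper-cased names on both sides keeps the check case-insensitive, which the chained comparisons above also aim for.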
@@ -0,0 +1,460 @@
1 |
from typing import Optional, Union, List
2 |
from ditk import logging
3 |
from easydict import EasyDict
4 |
import os
5 |
import numpy as np
6 |
import torch
7 |
import treetensor.torch as ttorch
8 |
from ding.framework import task, OnlineRLContext
9 |
from ding.framework.middleware import CkptSaver, trainer, \
10 |
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, \
11 |
gae_estimator, final_ctx_saver
12 |
from ding.envs import BaseEnv
13 |
from ding.envs import setup_ding_env_manager
14 |
from ding.policy import A2CPolicy
15 |
from ding.utils import set_pkg_seed
16 |
from ding.utils import get_env_fps, render
17 |
from ding.config import save_config_py, compile_config
18 |
from ding.model import VAC
19 |
from ding.model import model_wrap
20 |
from ding.bonus.common import TrainingReturn, EvalReturn
21 |
from ding.config.example.A2C import supported_env_cfg
22 |
from ding.config.example.A2C import supported_env
23 |
24 |
25 |
class A2CAgent:
26 |
27 |
28 |
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
29 |
Advantage Actor Critic(A2C).
30 |
For more information about the system design of RL agent, please refer to \
31 |
32 |
33 |
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
34 |
35 |
supported_env_list = list(supported_env_cfg.keys())
36 |
37 |
38 |
List of supported envs.
39 |
40 |
>>> from ding.bonus.a2c import A2CAgent
41 |
>>> print(A2CAgent.supported_env_list)
42 |
43 |
44 |
def __init__(
45 |
46 |
env_id: str = None,
47 |
env: BaseEnv = None,
48 |
seed: int = 0,
49 |
exp_name: str = None,
50 |
model: Optional[torch.nn.Module] = None,
51 |
cfg: Optional[Union[EasyDict, dict]] = None,
52 |
policy_state_dict: str = None,
53 |
) -> None:
54 |
55 |
56 |
Initialize agent for A2C algorithm.
57 |
58 |
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
59 |
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
60 |
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
61 |
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
62 |
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
63 |
If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
64 |
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
65 |
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
66 |
- seed (:obj:`int`): The random seed, which is set before running the program. \
67 |
Default to 0.
68 |
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
69 |
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
70 |
- model (:obj:`torch.nn.Module`): The model of A2C algorithm, which should be an instance of class \
71 |
:class:`ding.model.VAC`. \
72 |
If not specified, a default model will be generated according to the configuration.
73 |
- cfg (:obj:`Union[EasyDict, dict]`): The configuration of A2C algorithm, which is a dict. \
74 |
Default to None. If not specified, the default configuration will be used. \
75 |
The default configuration can be found in ``ding/config/example/A2C/``.
76 |
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
77 |
If specified, the policy will be loaded from this file. Default to None.
78 |
79 |
.. note::
80 |
An RL Agent Instance can be initialized in two basic ways. \
81 |
For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
82 |
and we want to train an agent with A2C algorithm with default configuration. \
83 |
Then we can initialize the agent in the following ways:
84 |
>>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
85 |
or, if we want, we can specify the env_id in the configuration:
86 |
>>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
87 |
>>> agent = A2CAgent(cfg=cfg)
88 |
There are also other arguments to specify the agent when initializing.
89 |
For example, if we want to specify the environment instance:
90 |
>>> env = CustomizedEnv('LunarLanderContinuous-v2')
91 |
>>> agent = A2CAgent(cfg=cfg, env=env)
92 |
or, if we want to specify the model:
93 |
>>> model = VAC(**cfg.policy.model)
94 |
>>> agent = A2CAgent(cfg=cfg, model=model)
95 |
or, if we want to reload the policy from a saved policy state dict:
96 |
>>> agent = A2CAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
97 |
Make sure that the configuration is consistent with the saved policy state dict.
98 |
99 |
100 |
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
101 |
102 |
if cfg is not None and not isinstance(cfg, EasyDict):
103 |
cfg = EasyDict(cfg)
104 |
105 |
if env_id is not None:
106 |
assert env_id in A2CAgent.supported_env_list, "Please use supported envs: {}".format(
107 |
108 |
109 |
if cfg is None:
110 |
cfg = supported_env_cfg[env_id]
111 |
112 |
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
113 |
114 |
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
115 |
assert cfg.env.env_id in A2CAgent.supported_env_list, "Please use supported envs: {}".format(
116 |
117 |
118 |
default_policy_config = EasyDict({"policy": A2CPolicy.default_config()})
119 |
120 |
cfg = default_policy_config
121 |
122 |
if exp_name is not None:
123 |
cfg.exp_name = exp_name
124 |
self.cfg = compile_config(cfg, policy=A2CPolicy)
125 |
self.exp_name = self.cfg.exp_name
126 |
if env is None:
127 |
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
128 |
129 |
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
130 |
self.env = env
131 |
132 |
133 |
self.seed = seed
134 |
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
135 |
if not os.path.exists(self.exp_name):
136 |
137 |
save_config_py(self.cfg, os.path.join(self.exp_name, ''))
138 |
if model is None:
139 |
model = VAC(**self.cfg.policy.model)
140 |
self.policy = A2CPolicy(self.cfg.policy, model=model)
141 |
if policy_state_dict is not None:
142 |
143 |
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
144 |
145 |
def train(
146 |
147 |
step: int = int(1e7),
148 |
collector_env_num: int = 4,
149 |
evaluator_env_num: int = 4,
150 |
n_iter_log_show: int = 500,
151 |
n_iter_save_ckpt: int = 1000,
152 |
context: Optional[str] = None,
153 |
debug: bool = False,
154 |
wandb_sweep: bool = False,
155 |
) -> TrainingReturn:
156 |
157 |
158 |
Train the agent with A2C algorithm for ``step`` iterations with ``collector_env_num`` collector \
159 |
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
160 |
recorded and saved by wandb.
161 |
162 |
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
163 |
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
164 |
If not specified, it will be set according to the configuration.
165 |
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
166 |
If not specified, it will be set according to the configuration.
167 |
- n_iter_save_ckpt (:obj:`int`): The interval, in training iterations, between checkpoint saves. \
168 |
Default to 1000.
169 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
170 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
171 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
172 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
173 |
subprocess environment manager will be used.
174 |
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
175 |
which is a hyper-parameter optimization process for seeking the best configurations. \
176 |
Default to False. If True, the wandb sweep id will be used as the experiment name.
177 |
178 |
- (:obj:`TrainingReturn`): The training result, of which the attributes are:
179 |
- wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
180 |
181 |
182 |
if debug:
183 |
184 |
185 |
# define env and policy
186 |
collector_env = self._setup_env_manager(collector_env_num, context, debug, 'collector')
187 |
evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug, 'evaluator')
188 |
189 |
with task.start(ctx=OnlineRLContext()):
190 |
191 |
192 |
193 |
194 |
195 |
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
196 |
197 |
198 |
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
199 |
200 |
201 |
202 |
203 |
204 |
205 |
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
206 |
207 |
208 |
task.use(gae_estimator(self.cfg, self.policy.collect_mode))
209 |
task.use(trainer(self.cfg, self.policy.learn_mode))
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
return TrainingReturn(wandb_url=task.ctx.wandb_url)
224 |
225 |
def deploy(
226 |
227 |
enable_save_replay: bool = False,
228 |
concatenate_all_replay: bool = False,
229 |
replay_save_path: str = None,
230 |
seed: Optional[Union[int, List]] = None,
231 |
debug: bool = False
232 |
) -> EvalReturn:
233 |
234 |
235 |
Deploy the agent with A2C algorithm by interacting with the environment, during which the replay video \
236 |
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
237 |
238 |
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
239 |
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
240 |
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
241 |
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
242 |
the replay video of each episode will be saved separately.
243 |
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
244 |
If not specified, the video will be saved in ``exp_name/videos``.
245 |
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
246 |
Default to None. If not specified, ``self.seed`` will be used. \
247 |
If ``seed`` is an integer, the agent will be deployed once. \
248 |
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
249 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
250 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
251 |
subprocess environment manager will be used.
252 |
253 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
254 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
255 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
256 |
257 |
258 |
if debug:
259 |
260 |
# define env and policy
261 |
env = self.env.clone(caller='evaluator')
262 |
263 |
if seed is not None and isinstance(seed, int):
264 |
seeds = [seed]
265 |
elif seed is not None and isinstance(seed, list):
266 |
seeds = seed
267 |
268 |
seeds = [self.seed]
269 |
270 |
returns = []
271 |
images = []
272 |
if enable_save_replay:
273 |
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
274 |
275 |
276 |
logging.warning('No video would be generated during the deploy.')
277 |
if concatenate_all_replay:
278 |
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
279 |
concatenate_all_replay = False
280 |
281 |
def single_env_forward_wrapper(forward_fn, cuda=True):
282 |
283 |
if self.cfg.policy.action_space == 'continuous':
284 |
forward_fn = model_wrap(forward_fn, wrapper_name='deterministic_sample').forward
285 |
elif self.cfg.policy.action_space == 'discrete':
286 |
forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
287 |
288 |
raise NotImplementedError
289 |
290 |
def _forward(obs):
291 |
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
292 |
obs = ttorch.as_tensor(obs).unsqueeze(0)
293 |
if cuda and torch.cuda.is_available():
294 |
obs = obs.cuda()
295 |
action = forward_fn(obs, mode='compute_actor')["action"]
296 |
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
297 |
action = action.squeeze(0).detach().cpu().numpy()
298 |
return action
299 |
300 |
return _forward
301 |
302 |
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
303 |
304 |
# reset first to make sure the env is in the initial state
305 |
# env will be reset again in the main loop
306 |
307 |
308 |
for seed in seeds:
309 |
env.seed(seed, dynamic_seed=False)
310 |
return_ = 0.
311 |
step = 0
312 |
obs = env.reset()
313 |
images.append(render(env)[None]) if concatenate_all_replay else None
314 |
while True:
315 |
action = forward_fn(obs)
316 |
obs, rew, done, info = env.step(action)
317 |
images.append(render(env)[None]) if concatenate_all_replay else None
318 |
return_ += rew
319 |
step += 1
320 |
if done:
321 |
322 |
logging.info(f'A2C deploy is finished, final episode return with {step} steps is: {return_}')
323 |
324 |
325 |
326 |
327 |
if concatenate_all_replay:
328 |
images = np.concatenate(images, axis=0)
329 |
import imageio
330 |
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
331 |
332 |
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
333 |
334 |
def collect_data(
335 |
336 |
env_num: int = 8,
337 |
save_data_path: Optional[str] = None,
338 |
n_sample: Optional[int] = None,
339 |
n_episode: Optional[int] = None,
340 |
context: Optional[str] = None,
341 |
debug: bool = False
342 |
) -> None:
343 |
344 |
345 |
Collect data with A2C algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
346 |
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
347 |
348 |
349 |
- env_num (:obj:`int`): The number of collector environments. Default to 8.
350 |
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
351 |
If not specified, the data will be saved in ``exp_name/demo_data``.
352 |
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
353 |
If not specified, ``n_episode`` must be specified.
354 |
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
355 |
If not specified, ``n_sample`` must be specified.
356 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
357 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
358 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
359 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
360 |
subprocess environment manager will be used.
361 |
362 |
363 |
if debug:
364 |
365 |
if n_episode is not None:
366 |
raise NotImplementedError
367 |
# define env and policy
368 |
env_num = env_num if env_num else self.cfg.env.collector_env_num
369 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
370 |
371 |
if save_data_path is None:
372 |
save_data_path = os.path.join(self.exp_name, 'demo_data')
373 |
374 |
# main execution task
375 |
with task.start(ctx=OnlineRLContext()):
376 |
377 |
378 |
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
379 |
380 |
381 |
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
382 |
383 |
384 |
f'A2C collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
385 |
386 |
387 |
def batch_evaluate(
388 |
389 |
env_num: int = 4,
390 |
n_evaluator_episode: int = 4,
391 |
context: Optional[str] = None,
392 |
debug: bool = False
393 |
) -> EvalReturn:
394 |
395 |
396 |
Evaluate the agent with A2C algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
397 |
environments. The evaluation result will be returned.
398 |
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
399 |
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
400 |
will only create one evaluator environment to evaluate the agent and save the replay video.
401 |
402 |
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
403 |
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
404 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
405 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
406 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
407 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
408 |
subprocess environment manager will be used.
409 |
410 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
411 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
412 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
413 |
414 |
415 |
if debug:
416 |
417 |
# define env and policy
418 |
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
419 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
420 |
421 |
# reset first to make sure the env is in the initial state
422 |
# env will be reset again in the main loop
423 |
424 |
425 |
426 |
evaluate_cfg = self.cfg
427 |
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
428 |
429 |
# main execution task
430 |
with task.start(ctx=OnlineRLContext()):
431 |
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
432 |
433 |
434 |
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
435 |
436 |
437 |
def best(self) -> 'A2CAgent':
438 |
439 |
440 |
Load the best model from the checkpoint directory, \
441 |
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
442 |
The return value is the agent with the best model.
443 |
444 |
- (:obj:`A2CAgent`): The agent with the best model.
445 |
446 |
>>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
447 |
>>> agent.train()
448 |
>>> agent =
449 |
450 |
.. note::
451 |
The best model is the model with the highest evaluation return. If this method is called, the current \
452 |
model will be replaced by the best model.
453 |
454 |
455 |
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
456 |
# Load best model if it exists
457 |
if os.path.exists(best_model_file_path):
458 |
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
459 |
460 |
return self
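The `best` method above joins `checkpoint_save_dir` with `eval.pth.tar` and loads the file only if it exists. The following is a minimal sketch of that lookup step alone. `resolve_best_checkpoint` and `demo` are hypothetical names introduced for illustration, and the sketch returns the path instead of calling `torch.load`, so it runs without PyTorch.

```python
import os
import tempfile
from typing import Optional, Tuple


def resolve_best_checkpoint(exp_name: str) -> Optional[str]:
    """Return the best-checkpoint path under ``exp_name/ckpt``, or None if absent.

    Hypothetical helper, not part of the library. The file name mirrors the
    ``eval.pth.tar`` used by ``best`` above.
    """
    path = os.path.join(exp_name, "ckpt", "eval.pth.tar")
    return path if os.path.exists(path) else None


def demo() -> Tuple[Optional[str], str]:
    # Work in a throwaway directory so the example leaves no files behind.
    with tempfile.TemporaryDirectory() as exp_name:
        before = resolve_best_checkpoint(exp_name)  # no checkpoint written yet
        os.makedirs(os.path.join(exp_name, "ckpt"))
        with open(os.path.join(exp_name, "ckpt", "eval.pth.tar"), "wb") as f:
            f.write(b"placeholder")  # stand-in for a real saved state dict
        after = resolve_best_checkpoint(exp_name)
        return before, os.path.basename(after)
```

Returning `None` when the file is missing matches the silent fall-through in `best`, which keeps the current model if no best checkpoint has been written.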
@@ -0,0 +1,459 @@
1 |
from typing import Optional, Union, List
2 |
from ditk import logging
3 |
from easydict import EasyDict
4 |
import os
5 |
import numpy as np
6 |
import torch
7 |
import treetensor.torch as ttorch
8 |
from ding.framework import task, OnlineRLContext
9 |
from ding.framework.middleware import CkptSaver, \
10 |
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
11 |
OffPolicyLearner, final_ctx_saver, eps_greedy_handler, nstep_reward_enhancer
12 |
from ding.envs import BaseEnv
13 |
from ding.envs import setup_ding_env_manager
14 |
from ding.policy import C51Policy
15 |
from ding.utils import set_pkg_seed
16 |
from ding.utils import get_env_fps, render
17 |
from ding.config import save_config_py, compile_config
18 |
from ding.model import C51DQN
19 |
from ding.model import model_wrap
20 |
from import DequeBuffer
21 |
from ding.bonus.common import TrainingReturn, EvalReturn
22 |
from ding.config.example.C51 import supported_env_cfg
23 |
from ding.config.example.C51 import supported_env
24 |
25 |
26 |
class C51Agent:
27 |
28 |
29 |
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm C51.
30 |
For more information about the system design of RL agent, please refer to \
31 |
32 |
33 |
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
34 |
35 |
supported_env_list = list(supported_env_cfg.keys())
36 |
37 |
38 |
List of supported envs.
39 |
40 |
>>> from ding.bonus.c51 import C51Agent
41 |
>>> print(C51Agent.supported_env_list)
42 |
43 |
44 |
def __init__(
45 |
46 |
env_id: str = None,
47 |
env: BaseEnv = None,
48 |
seed: int = 0,
49 |
exp_name: str = None,
50 |
model: Optional[torch.nn.Module] = None,
51 |
cfg: Optional[Union[EasyDict, dict]] = None,
52 |
policy_state_dict: str = None,
53 |
) -> None:
54 |
55 |
56 |
Initialize agent for C51 algorithm.
57 |
58 |
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
59 |
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
60 |
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
61 |
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
62 |
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
63 |
If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
64 |
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
65 |
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
66 |
- seed (:obj:`int`): The random seed, which is set before running the program. \
67 |
Default to 0.
68 |
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
69 |
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
70 |
- model (:obj:`torch.nn.Module`): The model of C51 algorithm, which should be an instance of class \
71 |
:class:`ding.model.C51DQN`. \
72 |
If not specified, a default model will be generated according to the configuration.
73 |
- cfg (:obj:`Union[EasyDict, dict]`): The configuration of C51 algorithm, which is a dict. \
74 |
Default to None. If not specified, the default configuration will be used. \
75 |
The default configuration can be found in ``ding/config/example/C51/``.
76 |
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
77 |
If specified, the policy will be loaded from this file. Default to None.
78 |
79 |
.. note::
80 |
An RL Agent Instance can be initialized in two basic ways. \
81 |
For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
82 |
and we want to train an agent with C51 algorithm with default configuration. \
83 |
Then we can initialize the agent in the following ways:
84 |
>>> agent = C51Agent(env_id='LunarLander-v2')
85 |
or, if we want, we can specify the env_id in the configuration:
86 |
>>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
87 |
>>> agent = C51Agent(cfg=cfg)
88 |
There are also other arguments to specify the agent when initializing.
89 |
For example, if we want to specify the environment instance:
90 |
>>> env = CustomizedEnv('LunarLander-v2')
91 |
>>> agent = C51Agent(cfg=cfg, env=env)
92 |
or, if we want to specify the model:
93 |
>>> model = C51DQN(**cfg.policy.model)
94 |
>>> agent = C51Agent(cfg=cfg, model=model)
95 |
or, if we want to reload the policy from a saved policy state dict:
96 |
>>> agent = C51Agent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
97 |
Make sure that the configuration is consistent with the saved policy state dict.
98 |
99 |
100 |
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
101 |
102 |
if cfg is not None and not isinstance(cfg, EasyDict):
103 |
cfg = EasyDict(cfg)
104 |
105 |
if env_id is not None:
106 |
assert env_id in C51Agent.supported_env_list, "Please use supported envs: {}".format(
107 |
108 |
109 |
if cfg is None:
110 |
cfg = supported_env_cfg[env_id]
111 |
112 |
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
113 |
114 |
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
115 |
assert cfg.env.env_id in C51Agent.supported_env_list, "Please use supported envs: {}".format(
116 |
117 |
118 |
default_policy_config = EasyDict({"policy": C51Policy.default_config()})
119 |
120 |
cfg = default_policy_config
121 |
122 |
if exp_name is not None:
123 |
cfg.exp_name = exp_name
124 |
self.cfg = compile_config(cfg, policy=C51Policy)
125 |
self.exp_name = self.cfg.exp_name
126 |
if env is None:
127 |
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
128 |
129 |
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
130 |
self.env = env
131 |
132 |
133 |
self.seed = seed
134 |
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
135 |
if not os.path.exists(self.exp_name):
136 |
137 |
save_config_py(self.cfg, os.path.join(self.exp_name, ''))
138 |
if model is None:
139 |
model = C51DQN(**self.cfg.policy.model)
140 |
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
141 |
self.policy = C51Policy(self.cfg.policy, model=model)
142 |
if policy_state_dict is not None:
143 |
144 |
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
145 |
146 |
def train(
147 |
148 |
step: int = int(1e7),
149 |
collector_env_num: int = None,
150 |
evaluator_env_num: int = None,
151 |
n_iter_save_ckpt: int = 1000,
152 |
context: Optional[str] = None,
153 |
debug: bool = False,
154 |
wandb_sweep: bool = False,
155 |
) -> TrainingReturn:
156 |
157 |
158 |
Train the agent with C51 algorithm for ``step`` iterations with ``collector_env_num`` collector \
159 |
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
160 |
recorded and saved by wandb.
161 |
162 |
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
163 |
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
164 |
If not specified, it will be set according to the configuration.
165 |
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
166 |
If not specified, it will be set according to the configuration.
167 |
- n_iter_save_ckpt (:obj:`int`): The interval, in training iterations, between checkpoint saves. \
168 |
Default to 1000.
169 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
170 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
171 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
172 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
173 |
subprocess environment manager will be used.
174 |
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
175 |
which is a hyper-parameter optimization process for seeking the best configurations. \
176 |
Default to False. If True, the wandb sweep id will be used as the experiment name.
177 |
178 |
- (:obj:`TrainingReturn`): The training result, of which the attributes are:
179 |
- wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
180 |
181 |
182 |
if debug:
183 |
184 |
185 |
# define env and policy
186 |
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
187 |
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
188 |
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
189 |
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
190 |
191 |
with task.start(ctx=OnlineRLContext()):
192 |
193 |
194 |
195 |
196 |
197 |
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
198 |
199 |
200 |
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
209 |
210 |
211 |
212 |
task.use(data_pusher(self.cfg, self.buffer_))
213 |
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
return TrainingReturn(wandb_url=task.ctx.wandb_url)
228 |
229 |
def deploy(
230 |
231 |
enable_save_replay: bool = False,
232 |
concatenate_all_replay: bool = False,
233 |
replay_save_path: str = None,
234 |
seed: Optional[Union[int, List]] = None,
235 |
debug: bool = False
236 |
) -> EvalReturn:
237 |
238 |
239 |
Deploy the agent with C51 algorithm by interacting with the environment, during which the replay video \
240 |
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
241 |
242 |
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
243 |
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
244 |
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
245 |
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
246 |
the replay video of each episode will be saved separately.
247 |
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
248 |
If not specified, the video will be saved in ``exp_name/videos``.
249 |
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
250 |
Default to None. If not specified, ``self.seed`` will be used. \
251 |
If ``seed`` is an integer, the agent will be deployed once. \
252 |
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
253 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
254 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
255 |
subprocess environment manager will be used.
256 |
257 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
258 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
259 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
260 |
261 |
262 |
if debug:
263 |
264 |
# define env and policy
265 |
env = self.env.clone(caller='evaluator')
266 |
267 |
if seed is not None and isinstance(seed, int):
268 |
seeds = [seed]
269 |
elif seed is not None and isinstance(seed, list):
270 |
seeds = seed
271 |
272 |
seeds = [self.seed]
273 |
274 |
returns = []
275 |
images = []
276 |
if enable_save_replay:
277 |
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
278 |
279 |
280 |
logging.warning('No video will be generated during deployment.')
281 |
if concatenate_all_replay:
282 |
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
283 |
concatenate_all_replay = False
284 |
285 |
def single_env_forward_wrapper(forward_fn, cuda=True):
286 |
287 |
forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
288 |
289 |
def _forward(obs):
290 |
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
291 |
obs = ttorch.as_tensor(obs).unsqueeze(0)
292 |
if cuda and torch.cuda.is_available():
293 |
obs = obs.cuda()
294 |
action = forward_fn(obs)["action"]
295 |
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
296 |
action = action.squeeze(0).detach().cpu().numpy()
297 |
return action
298 |
299 |
return _forward
300 |
301 |
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
302 |
303 |
# reset first to make sure the env is in the initial state
304 |
# env will be reset again in the main loop
305 |
306 |
307 |
for seed in seeds:
308 |
env.seed(seed, dynamic_seed=False)
309 |
return_ = 0.
310 |
step = 0
311 |
obs = env.reset()
312 |
images.append(render(env)[None]) if concatenate_all_replay else None
313 |
while True:
314 |
action = forward_fn(obs)
315 |
obs, rew, done, info = env.step(action)
316 |
images.append(render(env)[None]) if concatenate_all_replay else None
317 |
return_ += rew
318 |
step += 1
319 |
if done:
320 |
321 |
logging.info(f'C51 deploy is finished, final episode return with {step} steps is: {return_}')
322 |
323 |
324 |
325 |
326 |
if concatenate_all_replay:
327 |
images = np.concatenate(images, axis=0)
328 |
import imageio
329 |
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
330 |
331 |
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
332 |
333 |
def collect_data(
334 |
335 |
env_num: int = 8,
336 |
save_data_path: Optional[str] = None,
337 |
n_sample: Optional[int] = None,
338 |
n_episode: Optional[int] = None,
339 |
context: Optional[str] = None,
340 |
debug: bool = False
341 |
) -> None:
342 |
343 |
344 |
Collect data with C51 algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
345 |
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
346 |
347 |
348 |
- env_num (:obj:`int`): The number of collector environments. Default to 8.
349 |
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
350 |
If not specified, the data will be saved in ``exp_name/demo_data``.
351 |
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
352 |
If not specified, ``n_episode`` must be specified.
353 |
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
354 |
If not specified, ``n_sample`` must be specified.
355 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
356 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
357 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
358 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
359 |
subprocess environment manager will be used.
360 |
361 |
362 |
if debug:
363 |
364 |
if n_episode is not None:
365 |
raise NotImplementedError
366 |
# define env and policy
367 |
env_num = env_num if env_num else self.cfg.env.collector_env_num
368 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
369 |
370 |
if save_data_path is None:
371 |
save_data_path = os.path.join(self.exp_name, 'demo_data')
372 |
373 |
# main execution task
374 |
with task.start(ctx=OnlineRLContext()):
375 |
376 |
377 |
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
378 |
379 |
380 |
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
381 |
382 |
383 |
f'C51 collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
384 |
385 |
386 |
def batch_evaluate(
387 |
388 |
env_num: int = 4,
389 |
n_evaluator_episode: int = 4,
390 |
context: Optional[str] = None,
391 |
debug: bool = False
392 |
) -> EvalReturn:
393 |
394 |
395 |
Evaluate the agent with C51 algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
396 |
environments. The evaluation result will be returned.
397 |
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
398 |
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
399 |
will only create one evaluator environment to evaluate the agent and save the replay video.
400 |
401 |
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
402 |
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
403 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
404 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
405 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
406 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
407 |
subprocess environment manager will be used.
408 |
409 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
410 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
411 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
412 |
413 |
414 |
if debug:
415 |
416 |
# define env and policy
417 |
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
418 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
419 |
420 |
# reset first to make sure the env is in the initial state
421 |
# env will be reset again in the main loop
422 |
423 |
424 |
425 |
evaluate_cfg = self.cfg
426 |
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
427 |
428 |
# main execution task
429 |
with task.start(ctx=OnlineRLContext()):
430 |
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
431 |
432 |
433 |
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
434 |
435 |
436 |
def best(self) -> 'C51Agent':
437 |
438 |
439 |
Load the best model from the checkpoint directory, \
440 |
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
441 |
The return value is the agent with the best model.
442 |
443 |
- (:obj:`C51Agent`): The agent with the best model.
444 |
445 |
>>> agent = C51Agent(env_id='LunarLander-v2')
446 |
>>> agent.train()
447 |
>>> agent =
448 |
449 |
.. note::
450 |
The best model is the model with the highest evaluation return. If this method is called, the current \
451 |
model will be replaced by the best model.
452 |
453 |
454 |
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
455 |
# Load best model if it exists
456 |
if os.path.exists(best_model_file_path):
457 |
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
458 |
459 |
return self
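The `deploy` method above accepts `seed` as an integer, a list of integers, or `None`, and turns it into a list before running one episode per seed. A minimal sketch of that normalization follows. `normalize_seeds` is a hypothetical helper name, and `default_seed` stands in for `self.seed`.

```python
from typing import List, Optional, Union


def normalize_seeds(seed: Optional[Union[int, List[int]]], default_seed: int) -> List[int]:
    """Turn the ``seed`` argument of ``deploy`` into a list of seeds.

    Hypothetical helper for illustration; it mirrors the if/elif/else
    branches in ``deploy`` above.
    """
    if isinstance(seed, int):
        # A single integer means the agent is deployed once.
        return [seed]
    if isinstance(seed, list):
        # A list means one deployment per seed; copy it so the caller's list is untouched.
        return list(seed)
    # None (or any other type) falls back to the agent's own seed.
    return [default_seed]
```

With this shape, the episode loop can always iterate over a list, and the number of returned episodes equals the number of seeds.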
@@ -0,0 +1,22 @@
1 |
from dataclasses import dataclass
2 |
import numpy as np
3 |
4 |
5 |
6 |
class TrainingReturn:
7 |
8 |
9 |
wandb_url: The Weights & Biases (wandb) project URL of the training experiment.
10 |
11 |
wandb_url: str
12 |
13 |
14 |
15 |
class EvalReturn:
16 |
17 |
18 |
eval_value: The mean of evaluation return.
19 |
eval_value_std: The standard deviation of evaluation return.
20 |
21 |
eval_value: np.float32
22 |
eval_value_std: np.float32
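As a usage illustration, the sketch below builds an `EvalReturn` from a list of per-episode returns, the way `deploy` does with `np.mean` and `np.std`. The class is redeclared here so the example runs standalone. The `@dataclass` decorator is an assumption based on the `dataclass` import in this file, `summarize_returns` is a hypothetical helper, and the explicit cast to `np.float32` is added only to match the annotated field types.

```python
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class EvalReturn:
    # Standalone copy of the fields declared above, for illustration only.
    eval_value: np.float32
    eval_value_std: np.float32


def summarize_returns(returns: List[float]) -> EvalReturn:
    """Aggregate per-episode returns into mean and (population) standard deviation."""
    return EvalReturn(
        eval_value=np.float32(np.mean(returns)),
        eval_value_std=np.float32(np.std(returns)),
    )


result = summarize_returns([1.0, 3.0])
```

Note that `np.std` computes the population standard deviation by default (`ddof=0`), so two episodes with returns 1 and 3 give a standard deviation of 1.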
@@ -0,0 +1,326 @@
1 |
from easydict import EasyDict
2 |
import os
3 |
import gym
4 |
from ding.envs import BaseEnv, DingEnvWrapper
5 |
from ding.envs.env_wrappers import MaxAndSkipWrapper, WarpFrameWrapper, ScaledFloatFrameWrapper, FrameStackWrapper, \
6 |
EvalEpisodeReturnWrapper, TransposeWrapper, TimeLimitWrapper, FlatObsWrapper, GymToGymnasiumWrapper
7 |
from ding.policy import PPOFPolicy
8 |
9 |
10 |
def get_instance_config(env_id: str, algorithm: str) -> EasyDict:
11 |
if algorithm == 'PPOF':
12 |
cfg = PPOFPolicy.default_config()
13 |
if env_id == 'LunarLander-v2':
14 |
cfg.n_sample = 512
15 |
cfg.value_norm = 'popart'
16 |
cfg.entropy_weight = 1e-3
17 |
elif env_id == 'LunarLanderContinuous-v2':
18 |
cfg.action_space = 'continuous'
19 |
cfg.n_sample = 400
20 |
elif env_id == 'BipedalWalker-v3':
21 |
cfg.learning_rate = 1e-3
22 |
cfg.action_space = 'continuous'
23 |
cfg.n_sample = 1024
24 |
elif env_id == 'Pendulum-v1':
25 |
cfg.action_space = 'continuous'
26 |
cfg.n_sample = 400
27 |
elif env_id == 'acrobot':
28 |
cfg.learning_rate = 1e-4
29 |
cfg.n_sample = 400
30 |
elif env_id == 'rocket_landing':
31 |
cfg.n_sample = 2048
32 |
cfg.adv_norm = False
33 |
cfg.model = dict(
34 |
encoder_hidden_size_list=[64, 64, 128],
35 |
36 |
37 |
38 |
elif env_id == 'drone_fly':
39 |
cfg.action_space = 'continuous'
40 |
cfg.adv_norm = False
41 |
cfg.epoch_per_collect = 5
42 |
cfg.learning_rate = 5e-5
43 |
cfg.n_sample = 640
44 |
elif env_id == 'hybrid_moving':
45 |
cfg.action_space = 'hybrid'
46 |
cfg.n_sample = 3200
47 |
cfg.entropy_weight = 0.03
48 |
cfg.batch_size = 320
49 |
cfg.adv_norm = False
50 |
cfg.model = dict(
51 |
encoder_hidden_size_list=[256, 128, 64, 64],
52 |
53 |
54 |
55 |
56 |
elif env_id == 'evogym_carrier':
57 |
cfg.action_space = 'continuous'
58 |
cfg.n_sample = 2048
59 |
cfg.batch_size = 256
60 |
cfg.epoch_per_collect = 10
61 |
cfg.learning_rate = 3e-3
62 |
elif env_id == 'mario':
63 |
cfg.n_sample = 256
64 |
cfg.batch_size = 64
65 |
cfg.epoch_per_collect = 2
66 |
cfg.learning_rate = 1e-3
67 |
cfg.model = dict(
68 |
encoder_hidden_size_list=[64, 64, 128],
69 |
70 |
71 |
72 |
elif env_id == 'di_sheep':
73 |
cfg.n_sample = 3200
74 |
cfg.batch_size = 320
75 |
cfg.epoch_per_collect = 10
76 |
cfg.learning_rate = 3e-4
77 |
cfg.adv_norm = False
78 |
cfg.entropy_weight = 0.001
79 |
elif env_id == 'procgen_bigfish':
80 |
cfg.n_sample = 16384
81 |
cfg.batch_size = 16384
82 |
cfg.epoch_per_collect = 10
83 |
cfg.learning_rate = 5e-4
84 |
cfg.model = dict(
85 |
encoder_hidden_size_list=[64, 128, 256],
86 |
87 |
88 |
89 |
elif env_id in ['KangarooNoFrameskip-v4', 'BowlingNoFrameskip-v4']:
90 |
cfg.n_sample = 1024
91 |
cfg.batch_size = 128
92 |
cfg.epoch_per_collect = 10
93 |
cfg.learning_rate = 0.0001
94 |
cfg.model = dict(
95 |
encoder_hidden_size_list=[32, 64, 64, 128],
96 |
97 |
98 |
99 |
100 |
elif env_id == 'PongNoFrameskip-v4':
101 |
cfg.n_sample = 3200
102 |
cfg.batch_size = 320
103 |
cfg.epoch_per_collect = 10
104 |
cfg.learning_rate = 3e-4
105 |
cfg.model = dict(
106 |
encoder_hidden_size_list=[64, 64, 128],
107 |
108 |
109 |
110 |
elif env_id == 'SpaceInvadersNoFrameskip-v4':
111 |
cfg.n_sample = 320
112 |
cfg.batch_size = 320
113 |
cfg.epoch_per_collect = 1
114 |
cfg.learning_rate = 1e-3
115 |
cfg.entropy_weight = 0.01
116 |
cfg.lr_scheduler = (2000, 0.1)
117 |
cfg.model = dict(
118 |
encoder_hidden_size_list=[64, 64, 128],
119 |
120 |
121 |
122 |
elif env_id == 'QbertNoFrameskip-v4':
123 |
cfg.n_sample = 3200
124 |
cfg.batch_size = 320
125 |
cfg.epoch_per_collect = 10
126 |
cfg.learning_rate = 5e-4
127 |
cfg.lr_scheduler = (1000, 0.1)
128 |
cfg.model = dict(
129 |
encoder_hidden_size_list=[64, 64, 128],
130 |
131 |
132 |
133 |
elif env_id == 'minigrid_fourroom':
134 |
cfg.n_sample = 3200
135 |
cfg.batch_size = 320
136 |
cfg.learning_rate = 3e-4
137 |
cfg.epoch_per_collect = 10
138 |
cfg.entropy_weight = 0.001
139 |
elif env_id == 'metadrive':
140 |
cfg.learning_rate = 3e-4
141 |
cfg.action_space = 'continuous'
142 |
cfg.entropy_weight = 0.001
143 |
cfg.n_sample = 3000
144 |
cfg.epoch_per_collect = 10
145 |
cfg.learning_rate = 0.0001
146 |
cfg.model = dict(
147 |
encoder_hidden_size_list=[32, 64, 64, 128],
148 |
149 |
150 |
151 |
152 |
elif env_id == 'Hopper-v3':
153 |
cfg.action_space = "continuous"
154 |
cfg.n_sample = 3200
155 |
cfg.batch_size = 320
156 |
cfg.epoch_per_collect = 10
157 |
cfg.learning_rate = 3e-4
158 |
elif env_id == 'HalfCheetah-v3':
159 |
cfg.action_space = "continuous"
160 |
cfg.n_sample = 3200
161 |
cfg.batch_size = 320
162 |
cfg.epoch_per_collect = 10
163 |
cfg.learning_rate = 3e-4
164 |
elif env_id == 'Walker2d-v3':
165 |
cfg.action_space = "continuous"
166 |
cfg.n_sample = 3200
167 |
cfg.batch_size = 320
168 |
cfg.epoch_per_collect = 10
169 |
cfg.learning_rate = 3e-4
170 |
171 |
raise KeyError("not supported env type: {}".format(env_id))
172 |
173 |
raise KeyError("not supported algorithm type: {}".format(algorithm))
174 |
175 |
return cfg
176 |
177 |
178 |
def get_instance_env(env_id: str) -> BaseEnv:
179 |
if env_id == 'LunarLander-v2':
180 |
return DingEnvWrapper(gym.make('LunarLander-v2'))
181 |
elif env_id == 'LunarLanderContinuous-v2':
182 |
return DingEnvWrapper(gym.make('LunarLanderContinuous-v2', continuous=True))
183 |
elif env_id == 'BipedalWalker-v3':
184 |
return DingEnvWrapper(gym.make('BipedalWalker-v3'), cfg={'act_scale': True, 'rew_clip': True})
185 |
elif env_id == 'Pendulum-v1':
186 |
return DingEnvWrapper(gym.make('Pendulum-v1'), cfg={'act_scale': True})
187 |
elif env_id == 'acrobot':
188 |
return DingEnvWrapper(gym.make('Acrobot-v1'))
189 |
elif env_id == 'rocket_landing':
190 |
from dizoo.rocket.envs import RocketEnv
191 |
cfg = EasyDict({
192 |
'task': 'landing',
193 |
'max_steps': 800,
194 |
195 |
return RocketEnv(cfg)
196 |
elif env_id == 'drone_fly':
197 |
from dizoo.gym_pybullet_drones.envs import GymPybulletDronesEnv
198 |
cfg = EasyDict({
199 |
'env_id': 'flythrugate-aviary-v0',
200 |
'action_type': 'VEL',
201 |
202 |
return GymPybulletDronesEnv(cfg)
203 |
elif env_id == 'hybrid_moving':
204 |
import gym_hybrid
205 |
return DingEnvWrapper(gym.make('Moving-v0'))
206 |
elif env_id == 'evogym_carrier':
207 |
import evogym.envs
208 |
from evogym import sample_robot, WorldObject
209 |
path = os.path.join(os.path.dirname(__file__), '../../dizoo/evogym/envs/world_data/carry_bot.json')
210 |
robot_object = WorldObject.from_json(path)
211 |
body = robot_object.get_structure()
212 |
return DingEnvWrapper(
213 |
gym.make('Carrier-v0', body=body),
214 |
215 |
'env_wrapper': [
216 |
lambda env: TimeLimitWrapper(env, max_limit=300),
217 |
lambda env: EvalEpisodeReturnWrapper(env),
218 |
219 |
220 |
221 |
elif env_id == 'mario':
222 |
import gym_super_mario_bros
223 |
from nes_py.wrappers import JoypadSpace
224 |
return DingEnvWrapper(
225 |
JoypadSpace(gym_super_mario_bros.make("SuperMarioBros-1-1-v1"), [["right"], ["right", "A"]]),
226 |
227 |
'env_wrapper': [
228 |
lambda env: MaxAndSkipWrapper(env, skip=4),
229 |
lambda env: WarpFrameWrapper(env, size=84),
230 |
lambda env: ScaledFloatFrameWrapper(env),
231 |
lambda env: FrameStackWrapper(env, n_frames=4),
232 |
lambda env: TimeLimitWrapper(env, max_limit=200),
233 |
lambda env: EvalEpisodeReturnWrapper(env),
234 |
235 |
236 |
237 |
elif env_id == 'di_sheep':
238 |
from sheep_env import SheepEnv
239 |
return DingEnvWrapper(SheepEnv(level=9))
240 |
elif env_id == 'procgen_bigfish':
241 |
return DingEnvWrapper(
242 |
gym.make('procgen:procgen-bigfish-v0', start_level=0, num_levels=1),
243 |
244 |
'env_wrapper': [
245 |
lambda env: TransposeWrapper(env),
246 |
lambda env: ScaledFloatFrameWrapper(env),
247 |
lambda env: EvalEpisodeReturnWrapper(env),
248 |
249 |
250 |
251 |
252 |
elif env_id == 'Hopper-v3':
253 |
cfg = EasyDict(
254 |
255 |
256 |
257 |
258 |
259 |
return DingEnvWrapper(gym.make('Hopper-v3'), cfg=cfg)
260 |
elif env_id == 'HalfCheetah-v3':
261 |
cfg = EasyDict(
262 |
263 |
264 |
265 |
266 |
267 |
return DingEnvWrapper(gym.make('HalfCheetah-v3'), cfg=cfg)
268 |
elif env_id == 'Walker2d-v3':
269 |
cfg = EasyDict(
270 |
271 |
272 |
273 |
274 |
275 |
return DingEnvWrapper(gym.make('Walker2d-v3'), cfg=cfg)
276 |
277 |
elif env_id in [
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
cfg = EasyDict({
288 |
'env_id': env_id,
289 |
'env_wrapper': 'atari_default',
290 |
291 |
ding_env_atari = DingEnvWrapper(gym.make(env_id), cfg=cfg)
292 |
return ding_env_atari
293 |
elif env_id == 'minigrid_fourroom':
294 |
import gymnasium
295 |
return DingEnvWrapper(
296 |
297 |
298 |
'env_wrapper': [
299 |
lambda env: GymToGymnasiumWrapper(env),
300 |
lambda env: FlatObsWrapper(env),
301 |
lambda env: TimeLimitWrapper(env, max_limit=300),
302 |
lambda env: EvalEpisodeReturnWrapper(env),
303 |
304 |
305 |
306 |
elif env_id == 'metadrive':
307 |
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
308 |
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
309 |
cfg = dict(
310 |
311 |
312 |
313 |
314 |
315 |
316 |
cfg = EasyDict(cfg)
317 |
return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
318 |
319 |
raise KeyError("not supported env type: {}".format(env_id))
320 |
321 |
322 |
def get_hybrid_shape(action_space) -> EasyDict:
323 |
return EasyDict({
324 |
'action_type_shape': action_space[0].n,
325 |
'action_args_shape': action_space[1].shape,
326 |
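To show what `get_hybrid_shape` reads from a hybrid action space, here is a minimal sketch under stated assumptions: the action space is a pair whose first element exposes `.n` (discrete action type) and whose second exposes `.shape` (continuous arguments). `FakeDiscrete`, `FakeBox`, and `hybrid_shape_sketch` are hypothetical stand-ins for the gym space classes and the function above, and a plain `dict` replaces `EasyDict` so the example has no third-party dependencies.

```python
from typing import Any, Dict, NamedTuple, Tuple


class FakeDiscrete(NamedTuple):
    # Stand-in for a discrete space: only the number of choices is needed.
    n: int


class FakeBox(NamedTuple):
    # Stand-in for a continuous space: only the shape is needed.
    shape: Tuple[int, ...]


def hybrid_shape_sketch(action_space: Any) -> Dict[str, Any]:
    """Extract the discrete type count and the continuous argument shape."""
    return {
        'action_type_shape': action_space[0].n,
        'action_args_shape': action_space[1].shape,
    }


space = (FakeDiscrete(n=3), FakeBox(shape=(2, )))
shape = hybrid_shape_sketch(space)
```

Here `shape['action_type_shape']` is an integer while `shape['action_args_shape']` is a tuple, which is worth keeping in mind when the values are passed on to a model configuration.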
@@ -0,0 +1,456 @@
1 |
from typing import Optional, Union, List
2 |
from ditk import logging
3 |
from easydict import EasyDict
4 |
import os
5 |
import numpy as np
6 |
import torch
7 |
import treetensor.torch as ttorch
8 |
from ding.framework import task, OnlineRLContext
9 |
from ding.framework.middleware import CkptSaver, \
10 |
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
11 |
OffPolicyLearner, final_ctx_saver
12 |
from ding.envs import BaseEnv
13 |
from ding.envs import setup_ding_env_manager
14 |
from ding.policy import DDPGPolicy
15 |
from ding.utils import set_pkg_seed
16 |
from ding.utils import get_env_fps, render
17 |
from ding.config import save_config_py, compile_config
18 |
from ding.model import ContinuousQAC
19 |
from import DequeBuffer
20 |
from ding.bonus.common import TrainingReturn, EvalReturn
21 |
from ding.config.example.DDPG import supported_env_cfg
22 |
from ding.config.example.DDPG import supported_env
23 |
24 |
25 |
class DDPGAgent:
26 |
27 |
28 |
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
29 |
Deep Deterministic Policy Gradient (DDPG).
30 |
For more information about the system design of RL agent, please refer to \
31 |
32 |
33 |
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
34 |
35 |
supported_env_list = list(supported_env_cfg.keys())
36 |
37 |
38 |
List of supported envs.
39 |
40 |
>>> from ding.bonus.ddpg import DDPGAgent
41 |
>>> print(DDPGAgent.supported_env_list)
42 |
43 |
44 |
def __init__(
45 |
46 |
env_id: str = None,
47 |
env: BaseEnv = None,
48 |
seed: int = 0,
49 |
exp_name: str = None,
50 |
model: Optional[torch.nn.Module] = None,
51 |
cfg: Optional[Union[EasyDict, dict]] = None,
52 |
policy_state_dict: str = None,
53 |
) -> None:
54 |
55 |
56 |
Initialize agent for DDPG algorithm.
57 |
58 |
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
59 |
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
60 |
If both ``env_id`` and ``cfg`` are specified, ``env_id`` in ``cfg.env`` must be the same as ``env_id``. \
61 |
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
62 |
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
63 |
If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
64 |
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
65 |
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
66 |
- seed (:obj:`int`): The random seed, which is set before running the program. \
67 |
Default to 0.
68 |
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
69 |
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
70 |
- model (:obj:`torch.nn.Module`): The model of DDPG algorithm, which should be an instance of class \
71 |
:class:`ding.model.ContinuousQAC`. \
72 |
If not specified, a default model will be generated according to the configuration.
73 |
- cfg (:obj:`Union[EasyDict, dict]`): The configuration of DDPG algorithm, which is a dict. \
74 |
Default to None. If not specified, the default configuration will be used. \
75 |
The default configuration can be found in ``ding/config/example/DDPG/``.
76 |
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
77 |
If specified, the policy will be loaded from this file. Default to None.
78 |
79 |
.. note::
80 |
An RL Agent Instance can be initialized in two basic ways. \
81 |
For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
82 |
and we want to train an agent with DDPG algorithm with default configuration. \
83 |
Then we can initialize the agent in the following ways:
84 |
>>> agent = DDPGAgent(env_id='LunarLanderContinuous-v2')
85 |
or, if we want, we can specify the env_id in the configuration:
86 |
>>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
87 |
>>> agent = DDPGAgent(cfg=cfg)
88 |
There are also other arguments to specify the agent when initializing.
89 |
For example, if we want to specify the environment instance:
90 |
>>> env = CustomizedEnv('LunarLanderContinuous-v2')
91 |
>>> agent = DDPGAgent(cfg=cfg, env=env)
92 |
or, if we want to specify the model:
93 |
>>> model = ContinuousQAC(**cfg.policy.model)
94 |
>>> agent = DDPGAgent(cfg=cfg, model=model)
95 |
or, if we want to reload the policy from a saved policy state dict:
96 |
>>> agent = DDPGAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
97 |
Make sure that the configuration is consistent with the saved policy state dict.
98 |
99 |
100 |
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
101 |
102 |
if cfg is not None and not isinstance(cfg, EasyDict):
103 |
cfg = EasyDict(cfg)
104 |
105 |
if env_id is not None:
106 |
assert env_id in DDPGAgent.supported_env_list, "Please use supported envs: {}".format(
107 |
108 |
109 |
if cfg is None:
110 |
cfg = supported_env_cfg[env_id]
111 |
112 |
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
113 |
114 |
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
115 |
assert cfg.env.env_id in DDPGAgent.supported_env_list, "Please use supported envs: {}".format(
116 |
117 |
118 |
default_policy_config = EasyDict({"policy": DDPGPolicy.default_config()})
119 |
120 |
cfg = default_policy_config
121 |
122 |
if exp_name is not None:
123 |
cfg.exp_name = exp_name
124 |
self.cfg = compile_config(cfg, policy=DDPGPolicy)
125 |
self.exp_name = self.cfg.exp_name
126 |
if env is None:
127 |
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
128 |
129 |
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
130 |
self.env = env
131 |
132 |
133 |
self.seed = seed
134 |
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
135 |
if not os.path.exists(self.exp_name):
136 |
137 |
save_config_py(self.cfg, os.path.join(self.exp_name, ''))
138 |
if model is None:
139 |
model = ContinuousQAC(**self.cfg.policy.model)
140 |
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
141 |
self.policy = DDPGPolicy(self.cfg.policy, model=model)
142 |
if policy_state_dict is not None:
143 |
144 |
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
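
The way ``__init__`` reconciles ``env_id`` and ``cfg`` can be sketched in isolation. This is a minimal illustration using plain dicts, not the library's code: ``SUPPORTED_ENV_CFG`` and ``resolve_cfg`` are hypothetical names standing in for ``supported_env_cfg`` and the constructor logic, and exceptions replace the ``assert`` statements.

```python
from typing import Optional

# Hypothetical stand-in for ``supported_env_cfg``; the real table lives in the library.
SUPPORTED_ENV_CFG = {
    "LunarLanderContinuous-v2": {"env": {"env_id": "LunarLanderContinuous-v2"}},
}


def resolve_cfg(env_id: Optional[str] = None, cfg: Optional[dict] = None) -> dict:
    """Return the configuration to use, mirroring the checks done at construction time."""
    if env_id is None and cfg is None:
        raise ValueError("Please specify env_id or cfg.")
    if env_id is not None:
        if env_id not in SUPPORTED_ENV_CFG:
            raise KeyError("Please use supported envs: {}".format(list(SUPPORTED_ENV_CFG)))
        if cfg is None:
            # Only env_id given: fall back to the default configuration for that env.
            return SUPPORTED_ENV_CFG[env_id]
        # Both given: they must agree.
        if cfg.get("env", {}).get("env_id") != env_id:
            raise ValueError("env_id in cfg should be the same as env_id in args.")
        return cfg
    # Only cfg given: it must name a supported env.
    cfg_env_id = cfg.get("env", {}).get("env_id")
    if cfg_env_id is None:
        raise ValueError("Please specify env_id in cfg.")
    if cfg_env_id not in SUPPORTED_ENV_CFG:
        raise KeyError("Please use supported envs: {}".format(list(SUPPORTED_ENV_CFG)))
    return cfg
```

Note that when both arguments are passed, a mismatch is rejected rather than silently resolved in favour of one of them.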
145 |
146 |
def train(
147 |
148 |
step: int = int(1e7),
149 |
collector_env_num: int = None,
150 |
evaluator_env_num: int = None,
151 |
n_iter_log_show: int = 500,
152 |
n_iter_save_ckpt: int = 1000,
153 |
context: Optional[str] = None,
154 |
debug: bool = False,
155 |
wandb_sweep: bool = False,
156 |
) -> TrainingReturn:
157 |
158 |
159 |
Train the agent with DDPG algorithm for ``step`` environment steps with ``collector_env_num`` collector \
160 |
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
161 |
recorded and saved by wandb.
162 |
163 |
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
164 |
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
165 |
If not specified, it will be set according to the configuration.
166 |
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
167 |
If not specified, it will be set according to the configuration.
168 |
- n_iter_save_ckpt (:obj:`int`): The checkpoint saving frequency, in training iterations. \
169 |
Default to 1000.
170 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
171 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
172 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
173 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
174 |
subprocess environment manager will be used.
175 |
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
176 |
which is a hyper-parameter optimization process for seeking the best configurations. \
177 |
Default to False. If True, the wandb sweep id will be used as the experiment name.
178 |
179 |
- (:obj:`TrainingReturn`): The training result, whose attributes are:
180 |
- wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
181 |
182 |
183 |
if debug:
184 |
185 |
186 |
# define env and policy
187 |
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
188 |
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
189 |
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
190 |
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
191 |
192 |
with task.start(ctx=OnlineRLContext()):
193 |
194 |
195 |
196 |
197 |
198 |
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
199 |
200 |
201 |
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
202 |
203 |
204 |
205 |
206 |
207 |
208 |
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
209 |
210 |
211 |
task.use(data_pusher(self.cfg, self.buffer_))
212 |
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
return TrainingReturn(wandb_url=task.ctx.wandb_url)
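
One detail of the ``collector_env_num if collector_env_num else ...`` pattern used in ``train`` is that it tests truthiness, not ``is None``. A small sketch of that behaviour follows; ``pick_env_num`` is a hypothetical helper name used only for illustration.

```python
from typing import Optional


def pick_env_num(requested: Optional[int], configured: int) -> int:
    # Same truthiness test as ``x if x else default``:
    # both None and 0 fall back to the configured value.
    return requested if requested else configured
```

Under this pattern, passing ``0`` behaves the same as leaving the argument unspecified.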
227 |
228 |
def deploy(
229 |
230 |
enable_save_replay: bool = False,
231 |
concatenate_all_replay: bool = False,
232 |
replay_save_path: str = None,
233 |
seed: Optional[Union[int, List]] = None,
234 |
debug: bool = False
235 |
) -> EvalReturn:
236 |
237 |
238 |
Deploy the agent with DDPG algorithm by interacting with the environment, during which the replay video \
239 |
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
240 |
241 |
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
242 |
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
243 |
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
244 |
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
245 |
the replay video of each episode will be saved separately.
246 |
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
247 |
If not specified, the video will be saved in ``exp_name/videos``.
248 |
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
249 |
Default to None. If not specified, ``self.seed`` will be used. \
250 |
If ``seed`` is an integer, the agent will be deployed once. \
251 |
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
252 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
253 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
254 |
subprocess environment manager will be used.
255 |
256 |
- (:obj:`EvalReturn`): The evaluation result, whose attributes are:
257 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
258 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
259 |
260 |
261 |
if debug:
262 |
263 |
# define env and policy
264 |
env = self.env.clone(caller='evaluator')
265 |
266 |
if seed is not None and isinstance(seed, int):
267 |
seeds = [seed]
268 |
elif seed is not None and isinstance(seed, list):
269 |
seeds = seed
270 |
271 |
seeds = [self.seed]
272 |
273 |
returns = []
274 |
images = []
275 |
if enable_save_replay:
276 |
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
277 |
278 |
279 |
logging.warning('No video would be generated during the deploy.')
280 |
if concatenate_all_replay:
281 |
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
282 |
concatenate_all_replay = False
283 |
284 |
def single_env_forward_wrapper(forward_fn, cuda=True):
285 |
286 |
def _forward(obs):
287 |
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
288 |
obs = ttorch.as_tensor(obs).unsqueeze(0)
289 |
if cuda and torch.cuda.is_available():
290 |
obs = obs.cuda()
291 |
action = forward_fn(obs, mode='compute_actor')["action"]
292 |
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
293 |
action = action.squeeze(0).detach().cpu().numpy()
294 |
return action
295 |
296 |
return _forward
297 |
298 |
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
299 |
300 |
# reset first to make sure the env is in the initial state
301 |
# env will be reset again in the main loop
302 |
303 |
304 |
for seed in seeds:
305 |
env.seed(seed, dynamic_seed=False)
306 |
return_ = 0.
307 |
step = 0
308 |
obs = env.reset()
309 |
images.append(render(env)[None]) if concatenate_all_replay else None
310 |
while True:
311 |
action = forward_fn(obs)
312 |
obs, rew, done, info = env.step(action)
313 |
images.append(render(env)[None]) if concatenate_all_replay else None
314 |
return_ += rew
315 |
step += 1
316 |
if done:
317 |
318 |
logging.info(f'DDPG deploy is finished, final episode return with {step} steps is: {return_}')
319 |
320 |
321 |
322 |
323 |
if concatenate_all_replay:
324 |
images = np.concatenate(images, axis=0)
325 |
import imageio
326 |
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
327 |
328 |
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
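
The seed handling and the final summary in ``deploy`` can be sketched without any dependencies. This is an illustration under stated assumptions, not the library's implementation: ``normalize_seeds``, ``summarize`` and ``EvalSummary`` are hypothetical names, and ``statistics.pstdev`` is used because ``np.std`` computes the population standard deviation by default.

```python
import statistics
from typing import List, NamedTuple, Optional, Union


class EvalSummary(NamedTuple):
    """Hypothetical stand-in for ``EvalReturn``."""
    eval_value: float
    eval_value_std: float


def normalize_seeds(seed: Optional[Union[int, List[int]]], default_seed: int) -> List[int]:
    """Turn the ``seed`` argument into a list with one entry per deployment run."""
    if isinstance(seed, int):
        return [seed]
    if isinstance(seed, list):
        return seed
    # None (or any other type) falls back to the agent's own seed.
    return [default_seed]


def summarize(returns: List[float]) -> EvalSummary:
    """Mean and population standard deviation of the episode returns."""
    return EvalSummary(statistics.fmean(returns), statistics.pstdev(returns))
```

With a single seed the standard deviation is always zero, so passing a list of seeds is what makes ``eval_value_std`` informative.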
329 |
330 |
def collect_data(
331 |
332 |
env_num: int = 8,
333 |
save_data_path: Optional[str] = None,
334 |
n_sample: Optional[int] = None,
335 |
n_episode: Optional[int] = None,
336 |
context: Optional[str] = None,
337 |
debug: bool = False
338 |
) -> None:
339 |
340 |
341 |
Collect data with DDPG algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
342 |
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
343 |
344 |
345 |
- env_num (:obj:`int`): The number of collector environments. Default to 8.
346 |
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
347 |
If not specified, the data will be saved in ``exp_name/demo_data``.
348 |
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
349 |
If not specified, ``n_episode`` must be specified.
350 |
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
351 |
If not specified, ``n_sample`` must be specified.
352 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
353 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
354 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
355 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
356 |
subprocess environment manager will be used.
357 |
358 |
359 |
if debug:
360 |
361 |
if n_episode is not None:
362 |
raise NotImplementedError
363 |
# define env and policy
364 |
env_num = env_num if env_num else self.cfg.env.collector_env_num
365 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
366 |
367 |
if save_data_path is None:
368 |
save_data_path = os.path.join(self.exp_name, 'demo_data')
369 |
370 |
# main execution task
371 |
with task.start(ctx=OnlineRLContext()):
372 |
373 |
374 |
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
375 |
376 |
377 |
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
378 |
379 |
380 |
f'DDPG collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
381 |
382 |
383 |
def batch_evaluate(
384 |
385 |
env_num: int = 4,
386 |
n_evaluator_episode: int = 4,
387 |
context: Optional[str] = None,
388 |
debug: bool = False
389 |
) -> EvalReturn:
390 |
391 |
392 |
Evaluate the agent with DDPG algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
393 |
environments. The evaluation result will be returned.
394 |
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
395 |
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
396 |
will only create one evaluator environment to evaluate the agent and save the replay video.
397 |
398 |
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
399 |
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
400 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
401 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
402 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
403 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
404 |
subprocess environment manager will be used.
405 |
406 |
- (:obj:`EvalReturn`): The evaluation result, whose attributes are:
407 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
408 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
409 |
410 |
411 |
if debug:
412 |
413 |
# define env and policy
414 |
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
415 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
416 |
417 |
# reset first to make sure the env is in the initial state
418 |
# env will be reset again in the main loop
419 |
420 |
421 |
422 |
evaluate_cfg = self.cfg
423 |
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
424 |
425 |
# main execution task
426 |
with task.start(ctx=OnlineRLContext()):
427 |
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
428 |
429 |
430 |
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
431 |
432 |
433 |
def best(self) -> 'DDPGAgent':
434 |
435 |
436 |
Load the best model from the checkpoint directory, \
437 |
which is by default the file ``exp_name/ckpt/eval.pth.tar``. \
438 |
The return value is the agent with the best model.
439 |
440 |
- (:obj:`DDPGAgent`): The agent with the best model.
441 |
442 |
>>> agent = DDPGAgent(env_id='LunarLanderContinuous-v2')
443 |
>>> agent.train()
444 |
>>> agent =
445 |
446 |
.. note::
447 |
The best model is the model with the highest evaluation return. If this method is called, the current \
448 |
model will be replaced by the best model.
449 |
450 |
451 |
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
452 |
# Load best model if it exists
453 |
if os.path.exists(best_model_file_path):
454 |
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
455 |
456 |
return self
@@ -0,0 +1,460 @@
1 |
from typing import Optional, Union, List
2 |
from ditk import logging
3 |
from easydict import EasyDict
4 |
import os
5 |
import numpy as np
6 |
import torch
7 |
import treetensor.torch as ttorch
8 |
from ding.framework import task, OnlineRLContext
9 |
from ding.framework.middleware import CkptSaver, \
10 |
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
11 |
OffPolicyLearner, final_ctx_saver, nstep_reward_enhancer, eps_greedy_handler
12 |
from ding.envs import BaseEnv
13 |
from ding.envs import setup_ding_env_manager
14 |
from ding.policy import DQNPolicy
15 |
from ding.utils import set_pkg_seed
16 |
from ding.utils import get_env_fps, render
17 |
from ding.config import save_config_py, compile_config
18 |
from ding.model import DQN
19 |
from ding.model import model_wrap
20 |
from import DequeBuffer
21 |
from ding.bonus.common import TrainingReturn, EvalReturn
22 |
from ding.config.example.DQN import supported_env_cfg
23 |
from ding.config.example.DQN import supported_env
24 |
25 |
26 |
class DQNAgent:
27 |
28 |
29 |
Class of agent for training, evaluation and deployment of the reinforcement learning algorithm Deep Q-Learning (DQN).
30 |
For more information about the system design of RL agent, please refer to \
31 |
32 |
33 |
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
34 |
35 |
supported_env_list = list(supported_env_cfg.keys())
36 |
37 |
38 |
List of supported envs.
39 |
40 |
>>> from ding.bonus.dqn import DQNAgent
41 |
>>> print(DQNAgent.supported_env_list)
42 |
43 |
44 |
def __init__(
45 |
46 |
env_id: str = None,
47 |
env: BaseEnv = None,
48 |
seed: int = 0,
49 |
exp_name: str = None,
50 |
model: Optional[torch.nn.Module] = None,
51 |
cfg: Optional[Union[EasyDict, dict]] = None,
52 |
policy_state_dict: str = None,
53 |
) -> None:
54 |
55 |
56 |
Initialize agent for DQN algorithm.
57 |
58 |
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
59 |
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
60 |
If both ``env_id`` and ``cfg`` are specified, ``env_id`` in ``cfg.env`` must be the same as ``env_id``. \
61 |
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
62 |
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
63 |
If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
64 |
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
65 |
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
66 |
- seed (:obj:`int`): The random seed, which is set before running the program. \
67 |
Default to 0.
68 |
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
69 |
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
70 |
- model (:obj:`torch.nn.Module`): The model of DQN algorithm, which should be an instance of class \
71 |
:class:`ding.model.DQN`. \
72 |
If not specified, a default model will be generated according to the configuration.
73 |
- cfg (:obj:`Union[EasyDict, dict]`): The configuration of DQN algorithm, which is a dict. \
74 |
Default to None. If not specified, the default configuration will be used. \
75 |
The default configuration can be found in ``ding/config/example/DQN/``.
76 |
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
77 |
If specified, the policy will be loaded from this file. Default to None.
78 |
79 |
.. note::
80 |
An RL Agent Instance can be initialized in two basic ways. \
81 |
For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
82 |
and we want to train an agent with DQN algorithm with default configuration. \
83 |
Then we can initialize the agent in the following ways:
84 |
>>> agent = DQNAgent(env_id='LunarLander-v2')
85 |
or, if we want, we can specify the env_id in the configuration:
86 |
>>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
87 |
>>> agent = DQNAgent(cfg=cfg)
88 |
There are also other arguments to specify the agent when initializing.
89 |
For example, if we want to specify the environment instance:
90 |
>>> env = CustomizedEnv('LunarLander-v2')
91 |
>>> agent = DQNAgent(cfg=cfg, env=env)
92 |
or, if we want to specify the model:
93 |
>>> model = DQN(**cfg.policy.model)
94 |
>>> agent = DQNAgent(cfg=cfg, model=model)
95 |
or, if we want to reload the policy from a saved policy state dict:
96 |
>>> agent = DQNAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
97 |
Make sure that the configuration is consistent with the saved policy state dict.
98 |
99 |
100 |
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
101 |
102 |
if cfg is not None and not isinstance(cfg, EasyDict):
103 |
cfg = EasyDict(cfg)
104 |
105 |
if env_id is not None:
106 |
assert env_id in DQNAgent.supported_env_list, "Please use supported envs: {}".format(
107 |
108 |
109 |
if cfg is None:
110 |
cfg = supported_env_cfg[env_id]
111 |
112 |
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
113 |
114 |
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
115 |
assert cfg.env.env_id in DQNAgent.supported_env_list, "Please use supported envs: {}".format(
116 |
117 |
118 |
default_policy_config = EasyDict({"policy": DQNPolicy.default_config()})
119 |
120 |
cfg = default_policy_config
121 |
122 |
if exp_name is not None:
123 |
cfg.exp_name = exp_name
124 |
self.cfg = compile_config(cfg, policy=DQNPolicy)
125 |
self.exp_name = self.cfg.exp_name
126 |
if env is None:
127 |
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
128 |
129 |
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
130 |
self.env = env
131 |
132 |
133 |
self.seed = seed
134 |
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
135 |
if not os.path.exists(self.exp_name):
136 |
137 |
save_config_py(self.cfg, os.path.join(self.exp_name, ''))
138 |
if model is None:
139 |
model = DQN(**self.cfg.policy.model)
140 |
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
141 |
self.policy = DQNPolicy(self.cfg.policy, model=model)
142 |
if policy_state_dict is not None:
143 |
144 |
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
145 |
146 |
def train(
147 |
148 |
step: int = int(1e7),
149 |
collector_env_num: int = None,
150 |
evaluator_env_num: int = None,
151 |
n_iter_save_ckpt: int = 1000,
152 |
context: Optional[str] = None,
153 |
debug: bool = False,
154 |
wandb_sweep: bool = False,
155 |
) -> TrainingReturn:
156 |
157 |
158 |
Train the agent with DQN algorithm for ``step`` environment steps with ``collector_env_num`` collector \
159 |
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
160 |
recorded and saved by wandb.
161 |
162 |
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
163 |
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
164 |
If not specified, it will be set according to the configuration.
165 |
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
166 |
If not specified, it will be set according to the configuration.
167 |
- n_iter_save_ckpt (:obj:`int`): The checkpoint saving frequency, in training iterations. \
168 |
Default to 1000.
169 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
170 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
171 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
172 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
173 |
subprocess environment manager will be used.
174 |
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
175 |
which is a hyper-parameter optimization process for seeking the best configurations. \
176 |
Default to False. If True, the wandb sweep id will be used as the experiment name.
177 |
178 |
- (:obj:`TrainingReturn`): The training result, whose attributes are:
179 |
- wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
180 |
181 |
182 |
if debug:
183 |
184 |
185 |
# define env and policy
186 |
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
187 |
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
188 |
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
189 |
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
190 |
191 |
with task.start(ctx=OnlineRLContext()):
192 |
193 |
194 |
195 |
196 |
197 |
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
198 |
199 |
200 |
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
209 |
210 |
211 |
if "nstep" in self.cfg.policy and self.cfg.policy.nstep > 1:
212 |
213 |
task.use(data_pusher(self.cfg, self.buffer_))
214 |
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
return TrainingReturn(wandb_url=task.ctx.wandb_url)
229 |
230 |
def deploy(
231 |
232 |
enable_save_replay: bool = False,
233 |
concatenate_all_replay: bool = False,
234 |
replay_save_path: str = None,
235 |
seed: Optional[Union[int, List]] = None,
236 |
debug: bool = False
237 |
) -> EvalReturn:
238 |
239 |
240 |
Deploy the agent with DQN algorithm by interacting with the environment, during which the replay video \
241 |
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
242 |
243 |
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
244 |
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
245 |
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
246 |
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
247 |
the replay video of each episode will be saved separately.
248 |
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
249 |
If not specified, the video will be saved in ``exp_name/videos``.
250 |
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
251 |
Default to None. If not specified, ``self.seed`` will be used. \
252 |
If ``seed`` is an integer, the agent will be deployed once. \
253 |
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
254 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
255 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
256 |
subprocess environment manager will be used.
257 |
258 |
- (:obj:`EvalReturn`): The evaluation result, whose attributes are:
259 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
260 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
261 |
262 |
263 |
if debug:
264 |
265 |
# define env and policy
266 |
env = self.env.clone(caller='evaluator')
267 |
268 |
if seed is not None and isinstance(seed, int):
269 |
seeds = [seed]
270 |
elif seed is not None and isinstance(seed, list):
271 |
seeds = seed
272 |
273 |
seeds = [self.seed]
274 |
275 |
returns = []
276 |
images = []
277 |
if enable_save_replay:
278 |
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
279 |
280 |
281 |
logging.warning('No video would be generated during the deploy.')
282 |
if concatenate_all_replay:
283 |
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
284 |
concatenate_all_replay = False
285 |
286 |
def single_env_forward_wrapper(forward_fn, cuda=True):
287 |
288 |
forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
289 |
290 |
def _forward(obs):
291 |
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
292 |
obs = ttorch.as_tensor(obs).unsqueeze(0)
293 |
if cuda and torch.cuda.is_available():
294 |
obs = obs.cuda()
295 |
action = forward_fn(obs)["action"]
296 |
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
297 |
action = action.squeeze(0).detach().cpu().numpy()
298 |
return action
299 |
300 |
return _forward
301 |
302 |
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
303 |
304 |
# reset first to make sure the env is in the initial state
305 |
# env will be reset again in the main loop
306 |
307 |
308 |
for seed in seeds:
309 |
env.seed(seed, dynamic_seed=False)
310 |
return_ = 0.
311 |
step = 0
312 |
obs = env.reset()
313 |
images.append(render(env)[None]) if concatenate_all_replay else None
314 |
while True:
315 |
action = forward_fn(obs)
316 |
obs, rew, done, info = env.step(action)
317 |
images.append(render(env)[None]) if concatenate_all_replay else None
318 |
return_ += rew
319 |
step += 1
320 |
if done:
321 |
322 |
logging.info(f'DQN deploy is finished, final episode return with {step} steps is: {return_}')
323 |
324 |
325 |
326 |
327 |
if concatenate_all_replay:
328 |
images = np.concatenate(images, axis=0)
329 |
import imageio
330 |
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
331 |
332 |
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
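
Unlike the DDPG agent, the DQN ``deploy`` wraps the model with ``argmax_sample`` before acting. Assuming that wrapper selects the action with the largest Q-value, the selection step reduces to the following sketch; ``greedy_action`` is a hypothetical helper working on a plain list rather than a tensor.

```python
from typing import Sequence


def greedy_action(q_values: Sequence[float]) -> int:
    """Return the index of the largest Q-value (the first index wins ties)."""
    return max(range(len(q_values)), key=lambda i: q_values[i])
```

This deterministic choice is why deployment with a fixed seed yields reproducible episodes, whereas training typically adds exploration on top of it.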
333 |
334 |
def collect_data(
335 |
336 |
env_num: int = 8,
337 |
save_data_path: Optional[str] = None,
338 |
n_sample: Optional[int] = None,
339 |
n_episode: Optional[int] = None,
340 |
context: Optional[str] = None,
341 |
debug: bool = False
342 |
) -> None:
343 |
344 |
345 |
Collect data with DQN algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
346 |
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
347 |
348 |
349 |
- env_num (:obj:`int`): The number of collector environments. Default to 8.
350 |
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
351 |
If not specified, the data will be saved in ``exp_name/demo_data``.
352 |
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
353 |
If not specified, ``n_episode`` must be specified.
354 |
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
355 |
If not specified, ``n_sample`` must be specified.
356 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
357 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
358 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
359 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
360 |
subprocess environment manager will be used.
361 |
362 |
363 |
if debug:
364 |
365 |
if n_episode is not None:
366 |
raise NotImplementedError
367 |
# define env and policy
368 |
env_num = env_num if env_num else self.cfg.env.collector_env_num
369 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
370 |
371 |
if save_data_path is None:
372 |
save_data_path = os.path.join(self.exp_name, 'demo_data')
373 |
374 |
# main execution task
375 |
with task.start(ctx=OnlineRLContext()):
376 |
377 |
378 |
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
379 |
380 |
381 |
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
382 |
383 |
384 |
f'DQN collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
385 |
386 |
387 |
def batch_evaluate(
388 |
389 |
env_num: int = 4,
390 |
n_evaluator_episode: int = 4,
391 |
context: Optional[str] = None,
392 |
debug: bool = False
393 |
) -> EvalReturn:
394 |
395 |
396 |
Evaluate the agent with DQN algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
397 |
environments. The evaluation result will be returned.
398 |
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
399 |
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
400 |
will only create one evaluator environment to evaluate the agent and save the replay video.
401 |
402 |
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
403 |
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
404 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
405 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
406 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
407 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
408 |
subprocess environment manager will be used.
409 |
410 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
411 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
412 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
413 |
414 |
415 |
if debug:
416 |
417 |
# define env and policy
418 |
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
419 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
420 |
421 |
# reset first to make sure the env is in the initial state
422 |
# env will be reset again in the main loop
423 |
424 |
425 |
426 |
evaluate_cfg = self.cfg
427 |
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
428 |
429 |
# main execution task
430 |
with task.start(ctx=OnlineRLContext()):
431 |
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
432 |
433 |
434 |
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
435 |
436 |
437 |
def best(self) -> 'DQNAgent':
438 |
439 |
440 |
Load the best model from the checkpoint directory, \
441 |
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
442 |
The return value is the agent with the best model.
443 |
444 |
- (:obj:`DQNAgent`): The agent with the best model.
445 |
446 |
>>> agent = DQNAgent(env_id='LunarLander-v2')
447 |
>>> agent.train()
448 |
>>> agent = agent.best()
449 |
450 |
.. note::
451 |
The best model is the model with the highest evaluation return. If this method is called, the current \
452 |
model will be replaced by the best model.
453 |
454 |
455 |
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
456 |
# Load best model if it exists
457 |
if os.path.exists(best_model_file_path):
458 |
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
459 |
460 |
return self
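
The `deploy` method above runs one episode per seed, accumulates the per-step reward into an episode return, and reports the mean and standard deviation of those returns. A minimal sketch of that loop follows. It assumes a hypothetical `ToyEnv` with a gym-style `seed`/`reset`/`step` interface and a hypothetical `greedy_policy` standing in for the wrapped model; neither is part of DI-engine, and `statistics.pstdev` plays the role of `np.std`.

```python
import random
import statistics
from typing import Callable, List, Tuple


class ToyEnv:
    """Hypothetical stand-in for an evaluator env: walk along a line until position 3."""

    def __init__(self) -> None:
        self._rng = random.Random(0)
        self._pos = 0
        self._t = 0

    def seed(self, seed: int) -> None:
        self._rng = random.Random(seed)

    def reset(self) -> int:
        self._pos = self._rng.choice([-1, 0, 1])
        self._t = 0
        return self._pos

    def step(self, action: int) -> Tuple[int, float, bool, dict]:
        # Action 1 moves right, action 0 moves left; every step costs 1 reward.
        self._pos += 1 if action == 1 else -1
        self._t += 1
        done = self._pos >= 3 or self._t >= 20
        return self._pos, -1.0, done, {}


def greedy_policy(obs: int) -> int:
    # Hypothetical policy: always move right, toward the goal.
    return 1


def deploy(env: ToyEnv, policy: Callable[[int], int], seeds: List[int]) -> Tuple[float, float]:
    returns = []
    for seed in seeds:
        env.seed(seed)
        return_, obs = 0.0, env.reset()
        while True:
            obs, rew, done, _ = env.step(policy(obs))
            return_ += rew
            if done:
                break
        returns.append(return_)
    # Population standard deviation, matching np.std's default (ddof=0).
    return statistics.mean(returns), statistics.pstdev(returns)
```

With a single seed the standard deviation is zero by construction, so passing a list of seeds is what makes the reported spread meaningful.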
@@ -0,0 +1,245 @@
1 |
from typing import Union, Optional
2 |
from easydict import EasyDict
3 |
import torch
4 |
import torch.nn as nn
5 |
import treetensor.torch as ttorch
6 |
from copy import deepcopy
7 |
from ding.utils import SequenceType, squeeze
8 |
from ding.model.common import ReparameterizationHead, RegressionHead, MultiHead, \
9 |
FCEncoder, ConvEncoder, IMPALAConvEncoder, PopArtVHead
10 |
from ding.torch_utils import MLP, fc_block
11 |
12 |
13 |
class DiscretePolicyHead(nn.Module):
14 |
15 |
def __init__(
16 |
17 |
hidden_size: int,
18 |
output_size: int,
19 |
layer_num: int = 1,
20 |
activation: Optional[nn.Module] = nn.ReLU(),
21 |
norm_type: Optional[str] = None,
22 |
) -> None:
23 |
super(DiscretePolicyHead, self).__init__()
24 |
self.main = nn.Sequential(
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
), fc_block(hidden_size, output_size)
34 |
35 |
36 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
37 |
return self.main(x)
38 |
39 |
40 |
class PPOFModel(nn.Module):
41 |
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
42 |
43 |
def __init__(
44 |
45 |
obs_shape: Union[int, SequenceType],
46 |
action_shape: Union[int, SequenceType, EasyDict],
47 |
action_space: str = 'discrete',
48 |
share_encoder: bool = True,
49 |
encoder_hidden_size_list: SequenceType = [128, 128, 64],
50 |
actor_head_hidden_size: int = 64,
51 |
actor_head_layer_num: int = 1,
52 |
critic_head_hidden_size: int = 64,
53 |
critic_head_layer_num: int = 1,
54 |
activation: Optional[nn.Module] = nn.ReLU(),
55 |
norm_type: Optional[str] = None,
56 |
sigma_type: Optional[str] = 'independent',
57 |
fixed_sigma_value: Optional[float] = 0.3,
58 |
bound_type: Optional[str] = None,
59 |
encoder: Optional[torch.nn.Module] = None,
60 |
61 |
) -> None:
62 |
super(PPOFModel, self).__init__()
63 |
obs_shape = squeeze(obs_shape)
64 |
action_shape = squeeze(action_shape)
65 |
self.obs_shape, self.action_shape = obs_shape, action_shape
66 |
self.share_encoder = share_encoder
67 |
68 |
# Encoder Type
69 |
def new_encoder(outsize):
70 |
if isinstance(obs_shape, int) or len(obs_shape) == 1:
71 |
return FCEncoder(
72 |
73 |
74 |
75 |
76 |
77 |
elif len(obs_shape) == 3:
78 |
return ConvEncoder(
79 |
80 |
81 |
82 |
83 |
84 |
85 |
raise RuntimeError(
86 |
"not support obs_shape for pre-defined encoder: {}, please customize your own encoder".
87 |
88 |
89 |
90 |
if self.share_encoder:
91 |
assert actor_head_hidden_size == critic_head_hidden_size, \
92 |
"actor and critic network head should have same size."
93 |
if encoder:
94 |
if isinstance(encoder, torch.nn.Module):
95 |
self.encoder = encoder
96 |
97 |
raise ValueError("illegal encoder instance.")
98 |
99 |
self.encoder = new_encoder(actor_head_hidden_size)
100 |
101 |
if encoder:
102 |
if isinstance(encoder, torch.nn.Module):
103 |
self.actor_encoder = encoder
104 |
self.critic_encoder = deepcopy(encoder)
105 |
106 |
raise ValueError("illegal encoder instance.")
107 |
108 |
self.actor_encoder = new_encoder(actor_head_hidden_size)
109 |
self.critic_encoder = new_encoder(critic_head_hidden_size)
110 |
111 |
# Head Type
112 |
if not popart_head:
113 |
self.critic_head = RegressionHead(
114 |
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
115 |
116 |
117 |
self.critic_head = PopArtVHead(
118 |
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
119 |
120 |
121 |
self.action_space = action_space
122 |
assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
123 |
if self.action_space == 'continuous':
124 |
self.multi_head = False
125 |
self.actor_head = ReparameterizationHead(
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
elif self.action_space == 'discrete':
135 |
actor_head_cls = DiscretePolicyHead
136 |
multi_head = not isinstance(action_shape, int)
137 |
self.multi_head = multi_head
138 |
if multi_head:
139 |
self.actor_head = MultiHead(
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
self.actor_head = actor_head_cls(
149 |
150 |
151 |
152 |
153 |
154 |
155 |
elif self.action_space == 'hybrid': # HPPO
156 |
# hybrid action space: action_type(discrete) + action_args(continuous),
157 |
# such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
158 |
action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
159 |
action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
160 |
actor_action_args = ReparameterizationHead(
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
actor_action_type = DiscretePolicyHead(
171 |
172 |
173 |
174 |
175 |
176 |
177 |
self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])
178 |
179 |
# must use list, not nn.ModuleList
180 |
if self.share_encoder:
181 |
self.actor = [self.encoder, self.actor_head]
182 |
self.critic = [self.encoder, self.critic_head]
183 |
184 |
self.actor = [self.actor_encoder, self.actor_head]
185 |
self.critic = [self.critic_encoder, self.critic_head]
186 |
# Convenient for calling some apis (e.g. self.critic.parameters()),
187 |
# but may cause misunderstanding when `print(self)`
188 |
self.actor = nn.ModuleList(self.actor)
189 |
self.critic = nn.ModuleList(self.critic)
190 |
191 |
def forward(self, inputs: ttorch.Tensor, mode: str) -> ttorch.Tensor:
192 |
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
193 |
return getattr(self, mode)(inputs)
194 |
195 |
def compute_actor(self, x: ttorch.Tensor) -> ttorch.Tensor:
196 |
if self.share_encoder:
197 |
x = self.encoder(x)
198 |
199 |
x = self.actor_encoder(x)
200 |
201 |
if self.action_space == 'discrete':
202 |
return self.actor_head(x)
203 |
elif self.action_space == 'continuous':
204 |
x = self.actor_head(x) # mu, sigma
205 |
return ttorch.as_tensor(x)
206 |
elif self.action_space == 'hybrid':
207 |
action_type = self.actor_head[0](x)
208 |
action_args = self.actor_head[1](x)
209 |
return ttorch.as_tensor({'action_type': action_type, 'action_args': action_args})
210 |
211 |
def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
212 |
if self.share_encoder:
213 |
x = self.encoder(x)
214 |
215 |
x = self.critic_encoder(x)
216 |
x = self.critic_head(x)
217 |
return x
218 |
219 |
def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
220 |
if self.share_encoder:
221 |
actor_embedding = critic_embedding = self.encoder(x)
222 |
223 |
actor_embedding = self.actor_encoder(x)
224 |
critic_embedding = self.critic_encoder(x)
225 |
226 |
value = self.critic_head(critic_embedding)
227 |
228 |
if self.action_space == 'discrete':
229 |
logit = self.actor_head(actor_embedding)
230 |
return ttorch.as_tensor({'logit': logit, 'value': value['pred']})
231 |
elif self.action_space == 'continuous':
232 |
x = self.actor_head(actor_embedding)
233 |
return ttorch.as_tensor({'logit': x, 'value': value['pred']})
234 |
elif self.action_space == 'hybrid':
235 |
action_type = self.actor_head[0](actor_embedding)
236 |
action_args = self.actor_head[1](actor_embedding)
237 |
return ttorch.as_tensor(
238 |
239 |
'logit': {
240 |
'action_type': action_type,
241 |
'action_args': action_args
242 |
243 |
'value': value['pred']
244 |
245 |
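
`PPOFModel.forward` above dispatches on a `mode` string: it checks the string against the class-level `mode` list and then calls the method of the same name through `getattr`. A minimal sketch of that pattern, using a hypothetical `TinyModel` whose methods return plain Python values instead of tensors:

```python
from typing import Any, Dict, List


class TinyModel:
    # Names listed here must match method names defined on the class.
    mode: List[str] = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def forward(self, inputs: float, mode: str) -> Any:
        assert mode in self.mode, "unsupported forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, x: float) -> List[float]:
        # Hypothetical "logit" for two actions.
        return [x, -x]

    def compute_critic(self, x: float) -> float:
        # Hypothetical state value.
        return 2.0 * x

    def compute_actor_critic(self, x: float) -> Dict[str, Any]:
        return {'logit': self.compute_actor(x), 'value': self.compute_critic(x)}
```

The assertion guards the `getattr` call, so a typo in `mode` fails with a readable message instead of reaching an unrelated attribute.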
@@ -0,0 +1,453 @@
1 |
from typing import Optional, Union, List
2 |
from ditk import logging
3 |
from easydict import EasyDict
4 |
import os
5 |
import numpy as np
6 |
import torch
7 |
import treetensor.torch as ttorch
8 |
from ding.framework import task, OnlineRLContext
9 |
from ding.framework.middleware import CkptSaver, trainer, \
10 |
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, \
11 |
montecarlo_return_estimator, final_ctx_saver, EpisodeCollector
12 |
from ding.envs import BaseEnv
13 |
from ding.envs import setup_ding_env_manager
14 |
from ding.policy import PGPolicy
15 |
from ding.utils import set_pkg_seed
16 |
from ding.utils import get_env_fps, render
17 |
from ding.config import save_config_py, compile_config
18 |
from ding.model import PG
19 |
from ding.bonus.common import TrainingReturn, EvalReturn
20 |
from ding.config.example.PG import supported_env_cfg
21 |
from ding.config.example.PG import supported_env
22 |
23 |
24 |
class PGAgent:
25 |
26 |
27 |
Class of agent for training, evaluation and deployment of the reinforcement learning algorithm Policy Gradient (PG).
28 |
For more information about the system design of RL agent, please refer to \
29 |
30 |
31 |
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
32 |
33 |
supported_env_list = list(supported_env_cfg.keys())
34 |
35 |
36 |
List of supported envs.
37 |
38 |
>>> from ding.bonus.pg import PGAgent
39 |
>>> print(PGAgent.supported_env_list)
40 |
41 |
42 |
def __init__(
43 |
44 |
env_id: str = None,
45 |
env: BaseEnv = None,
46 |
seed: int = 0,
47 |
exp_name: str = None,
48 |
model: Optional[torch.nn.Module] = None,
49 |
cfg: Optional[Union[EasyDict, dict]] = None,
50 |
policy_state_dict: str = None,
51 |
) -> None:
52 |
53 |
54 |
Initialize agent for PG algorithm.
55 |
56 |
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
57 |
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
58 |
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
59 |
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
60 |
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
61 |
If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
62 |
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
63 |
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
64 |
- seed (:obj:`int`): The random seed, which is set before running the program. \
65 |
Default to 0.
66 |
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
67 |
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
68 |
- model (:obj:`torch.nn.Module`): The model of PG algorithm, which should be an instance of class \
69 |
:class:`ding.model.PG`. \
70 |
If not specified, a default model will be generated according to the configuration.
71 |
- cfg (:obj:`Union[EasyDict, dict]`): The configuration of PG algorithm, which is a dict. \
72 |
Default to None. If not specified, the default configuration will be used. \
73 |
The default configuration can be found in ``ding/config/example/PG/``.
74 |
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
75 |
If specified, the policy will be loaded from this file. Default to None.
76 |
77 |
.. note::
78 |
An RL Agent Instance can be initialized in two basic ways. \
79 |
For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
80 |
and we want to train an agent with PG algorithm with default configuration. \
81 |
Then we can initialize the agent in the following ways:
82 |
>>> agent = PGAgent(env_id='LunarLanderContinuous-v2')
83 |
or, if we want, we can specify the env_id in the configuration:
84 |
>>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
85 |
>>> agent = PGAgent(cfg=cfg)
86 |
There are also other arguments to specify the agent when initializing.
87 |
For example, if we want to specify the environment instance:
88 |
>>> env = CustomizedEnv('LunarLanderContinuous-v2')
89 |
>>> agent = PGAgent(cfg=cfg, env=env)
90 |
or, if we want to specify the model:
91 |
>>> model = PG(**cfg.policy.model)
92 |
>>> agent = PGAgent(cfg=cfg, model=model)
93 |
or, if we want to reload the policy from a saved policy state dict:
94 |
>>> agent = PGAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
95 |
Make sure that the configuration is consistent with the saved policy state dict.
96 |
97 |
98 |
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
99 |
100 |
if cfg is not None and not isinstance(cfg, EasyDict):
101 |
cfg = EasyDict(cfg)
102 |
103 |
if env_id is not None:
104 |
assert env_id in PGAgent.supported_env_list, "Please use supported envs: {}".format(
105 |
106 |
107 |
if cfg is None:
108 |
cfg = supported_env_cfg[env_id]
109 |
110 |
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
111 |
112 |
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
113 |
assert cfg.env.env_id in PGAgent.supported_env_list, "Please use supported envs: {}".format(
114 |
115 |
116 |
default_policy_config = EasyDict({"policy": PGPolicy.default_config()})
117 |
118 |
cfg = default_policy_config
119 |
120 |
if exp_name is not None:
121 |
cfg.exp_name = exp_name
122 |
self.cfg = compile_config(cfg, policy=PGPolicy)
123 |
self.exp_name = self.cfg.exp_name
124 |
if env is None:
125 |
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
126 |
127 |
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
128 |
self.env = env
129 |
130 |
131 |
self.seed = seed
132 |
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
133 |
if not os.path.exists(self.exp_name):
134 |
135 |
save_config_py(self.cfg, os.path.join(self.exp_name, ''))
136 |
if model is None:
137 |
model = PG(**self.cfg.policy.model)
138 |
self.policy = PGPolicy(self.cfg.policy, model=model)
139 |
if policy_state_dict is not None:
140 |
141 |
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
142 |
143 |
def train(
144 |
145 |
step: int = int(1e7),
146 |
collector_env_num: int = None,
147 |
evaluator_env_num: int = None,
148 |
n_iter_save_ckpt: int = 1000,
149 |
context: Optional[str] = None,
150 |
debug: bool = False,
151 |
wandb_sweep: bool = False,
152 |
) -> TrainingReturn:
153 |
154 |
155 |
Train the agent with PG algorithm for ``step`` iterations with ``collector_env_num`` collector \
156 |
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
157 |
recorded and saved by wandb.
158 |
159 |
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
160 |
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
161 |
If not specified, it will be set according to the configuration.
162 |
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
163 |
If not specified, it will be set according to the configuration.
164 |
- n_iter_save_ckpt (:obj:`int`): The checkpoint-saving interval, in training iterations. \
165 |
Default to 1000.
166 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
167 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
168 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
169 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
170 |
subprocess environment manager will be used.
171 |
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
172 |
which is a hyper-parameter optimization process for seeking the best configurations. \
173 |
Default to False. If True, the wandb sweep id will be used as the experiment name.
174 |
175 |
- (:obj:`TrainingReturn`): The training result, of which the attributes are:
176 |
- wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
177 |
178 |
179 |
if debug:
180 |
181 |
182 |
# define env and policy
183 |
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
184 |
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
185 |
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
186 |
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
187 |
188 |
with task.start(ctx=OnlineRLContext()):
189 |
190 |
191 |
192 |
193 |
194 |
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
195 |
196 |
197 |
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
198 |
task.use(EpisodeCollector(self.cfg, self.policy.collect_mode, collector_env))
199 |
200 |
task.use(trainer(self.cfg, self.policy.learn_mode))
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
return TrainingReturn(wandb_url=task.ctx.wandb_url)
215 |
216 |
def deploy(
217 |
218 |
enable_save_replay: bool = False,
219 |
concatenate_all_replay: bool = False,
220 |
replay_save_path: str = None,
221 |
seed: Optional[Union[int, List]] = None,
222 |
debug: bool = False
223 |
) -> EvalReturn:
224 |
225 |
226 |
Deploy the agent with PG algorithm by interacting with the environment, during which the replay video \
227 |
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
228 |
229 |
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
230 |
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
231 |
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
232 |
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
233 |
the replay video of each episode will be saved separately.
234 |
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
235 |
If not specified, the video will be saved in ``exp_name/videos``.
236 |
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
237 |
Default to None. If not specified, ``self.seed`` will be used. \
238 |
If ``seed`` is an integer, the agent will be deployed once. \
239 |
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
240 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
241 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
242 |
subprocess environment manager will be used.
243 |
244 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
245 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
246 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
247 |
248 |
249 |
if debug:
250 |
251 |
# define env and policy
252 |
env = self.env.clone(caller='evaluator')
253 |
254 |
if seed is not None and isinstance(seed, int):
255 |
seeds = [seed]
256 |
elif seed is not None and isinstance(seed, list):
257 |
seeds = seed
258 |
259 |
seeds = [self.seed]
260 |
261 |
returns = []
262 |
images = []
263 |
if enable_save_replay:
264 |
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
265 |
266 |
267 |
logging.warning('No video would be generated during the deploy.')
268 |
if concatenate_all_replay:
269 |
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
270 |
concatenate_all_replay = False
271 |
272 |
def single_env_forward_wrapper(forward_fn, cuda=True):
273 |
274 |
def _forward(obs):
275 |
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
276 |
obs = ttorch.as_tensor(obs).unsqueeze(0)
277 |
if cuda and torch.cuda.is_available():
278 |
obs = obs.cuda()
279 |
output = forward_fn(obs)
280 |
if self.policy._cfg.deterministic_eval:
281 |
if self.policy._cfg.action_space == 'discrete':
282 |
output['action'] = output['logit'].argmax(dim=-1)
283 |
elif self.policy._cfg.action_space == 'continuous':
284 |
output['action'] = output['logit']['mu']
285 |
286 |
raise KeyError("invalid action_space: {}".format(self.policy._cfg.action_space))
287 |
288 |
output['action'] = output['dist'].sample()
289 |
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
290 |
action = output['action'].squeeze(0).detach().cpu().numpy()
291 |
return action
292 |
293 |
return _forward
294 |
295 |
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
296 |
297 |
# reset first to make sure the env is in the initial state
298 |
# env will be reset again in the main loop
299 |
300 |
301 |
for seed in seeds:
302 |
env.seed(seed, dynamic_seed=False)
303 |
return_ = 0.
304 |
step = 0
305 |
obs = env.reset()
306 |
images.append(render(env)[None]) if concatenate_all_replay else None
307 |
while True:
308 |
action = forward_fn(obs)
309 |
obs, rew, done, info = env.step(action)
310 |
images.append(render(env)[None]) if concatenate_all_replay else None
311 |
return_ += rew
312 |
step += 1
313 |
if done:
314 |
315 |
logging.info(f'PG deploy is finished, final episode return with {step} steps is: {return_}')
316 |
317 |
318 |
319 |
320 |
if concatenate_all_replay:
321 |
images = np.concatenate(images, axis=0)
322 |
import imageio
323 |
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
324 |
325 |
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
326 |
327 |
def collect_data(
328 |
329 |
env_num: int = 8,
330 |
save_data_path: Optional[str] = None,
331 |
n_sample: Optional[int] = None,
332 |
n_episode: Optional[int] = None,
333 |
context: Optional[str] = None,
334 |
debug: bool = False
335 |
) -> None:
336 |
337 |
338 |
Collect data with PG algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
339 |
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
340 |
341 |
342 |
- env_num (:obj:`int`): The number of collector environments. Default to 8.
343 |
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
344 |
If not specified, the data will be saved in ``exp_name/demo_data``.
345 |
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
346 |
If not specified, ``n_episode`` must be specified.
347 |
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
348 |
If not specified, ``n_sample`` must be specified.
349 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
350 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
351 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
352 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
353 |
subprocess environment manager will be used.
354 |
355 |
356 |
if debug:
357 |
358 |
if n_episode is not None:
359 |
raise NotImplementedError
360 |
# define env and policy
361 |
env_num = env_num if env_num else self.cfg.env.collector_env_num
362 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
363 |
364 |
if save_data_path is None:
365 |
save_data_path = os.path.join(self.exp_name, 'demo_data')
366 |
367 |
# main execution task
368 |
with task.start(ctx=OnlineRLContext()):
369 |
370 |
371 |
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
372 |
373 |
374 |
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
375 |
376 |
377 |
f'PG collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
378 |
379 |
380 |
def batch_evaluate(
381 |
382 |
env_num: int = 4,
383 |
n_evaluator_episode: int = 4,
384 |
context: Optional[str] = None,
385 |
debug: bool = False
386 |
) -> EvalReturn:
387 |
388 |
389 |
Evaluate the agent with PG algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
390 |
environments. The evaluation result will be returned.
391 |
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
392 |
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
393 |
will only create one evaluator environment to evaluate the agent and save the replay video.
394 |
395 |
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
396 |
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
397 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
398 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
399 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
400 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
401 |
subprocess environment manager will be used.
402 |
403 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
404 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
405 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
406 |
407 |
408 |
if debug:
409 |
410 |
# define env and policy
411 |
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
412 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
413 |
414 |
# reset first to make sure the env is in the initial state
415 |
# env will be reset again in the main loop
416 |
417 |
418 |
419 |
evaluate_cfg = self.cfg
420 |
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
421 |
422 |
# main execution task
423 |
with task.start(ctx=OnlineRLContext()):
424 |
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
425 |
426 |
427 |
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
428 |
429 |
430 |
def best(self) -> 'PGAgent':
431 |
432 |
433 |
Load the best model from the checkpoint directory, \
434 |
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
435 |
The return value is the agent with the best model.
436 |
437 |
- (:obj:`PGAgent`): The agent with the best model.
438 |
439 |
>>> agent = PGAgent(env_id='LunarLanderContinuous-v2')
440 |
>>> agent.train()
441 |
>>> agent = agent.best()
442 |
443 |
.. note::
444 |
The best model is the model with the highest evaluation return. If this method is called, the current \
445 |
model will be replaced by the best model.
446 |
447 |
448 |
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
449 |
# Load best model if it exists
450 |
if os.path.exists(best_model_file_path):
451 |
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
452 |
453 |
return self
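
The `_forward` wrapper in `PGAgent.deploy` above picks actions in one of two ways: greedily when `deterministic_eval` is set (argmax over the logits for a discrete action space), or by sampling from the policy distribution otherwise. A minimal pure-Python sketch of the discrete case follows. It assumes a hypothetical `select_action` helper and a softmax over raw logits; the real code works on batched tensors and reads the distribution from the model output.

```python
import math
import random
from typing import List, Optional


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(logits: List[float], deterministic: bool, rng: Optional[random.Random] = None) -> int:
    if deterministic:
        # Greedy choice: index of the largest logit.
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    # Stochastic choice: draw index i with probability softmax(logits)[i].
    return rng.choices(range(len(logits)), weights=softmax(logits), k=1)[0]
```

Passing an explicit `random.Random` instance keeps the sampled branch reproducible, in the same spirit as the fixed seeds that `deploy` hands to the environment.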
@@ -0,0 +1,471 @@
1 |
from typing import Optional, Union, List
2 |
from ditk import logging
3 |
from easydict import EasyDict
4 |
import os
5 |
import numpy as np
6 |
import torch
7 |
import treetensor.torch as ttorch
8 |
from ding.framework import task, OnlineRLContext
9 |
from ding.framework.middleware import CkptSaver, final_ctx_saver, OffPolicyLearner, StepCollector, \
10 |
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, gae_estimator
11 |
from ding.envs import BaseEnv
12 |
from ding.envs import setup_ding_env_manager
13 |
from ding.policy import PPOOffPolicy
14 |
from ding.utils import set_pkg_seed
15 |
from ding.utils import get_env_fps, render
16 |
from ding.config import save_config_py, compile_config
17 |
from ding.model import VAC
18 |
from ding.model import model_wrap
19 |
from ding.data import DequeBuffer
20 |
from ding.bonus.common import TrainingReturn, EvalReturn
21 |
from ding.config.example.PPOOffPolicy import supported_env_cfg
22 |
from ding.config.example.PPOOffPolicy import supported_env
23 |
24 |
25 |
class PPOOffPolicyAgent:
26 |
27 |
28 |
Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
29 |
Proximal Policy Optimization (PPO) in an off-policy style.
30 |
For more information about the system design of RL agent, please refer to \
31 |
32 |
33 |
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
34 |
35 |
supported_env_list = list(supported_env_cfg.keys())
36 |
37 |
38 |
List of supported envs.
39 |
40 |
>>> from ding.bonus.ppo_offpolicy import PPOOffPolicyAgent
41 |
>>> print(PPOOffPolicyAgent.supported_env_list)
42 |
43 |
44 |
def __init__(
45 |
46 |
env_id: str = None,
47 |
env: BaseEnv = None,
48 |
seed: int = 0,
49 |
exp_name: str = None,
50 |
model: Optional[torch.nn.Module] = None,
51 |
cfg: Optional[Union[EasyDict, dict]] = None,
52 |
policy_state_dict: str = None,
53 |
) -> None:
54 |
55 |
56 |
Initialize agent for PPO (offpolicy) algorithm.
57 |
58 |
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
59 |
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
60 |
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
61 |
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
62 |
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
63 |
If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
64 |
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
65 |
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
66 |
- seed (:obj:`int`): The random seed, which is set before running the program. \
67 |
Default to 0.
68 |
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
69 |
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
70 |
- model (:obj:`torch.nn.Module`): The model of PPO (offpolicy) algorithm, \
71 |
which should be an instance of class :class:`ding.model.VAC`. \
72 |
If not specified, a default model will be generated according to the configuration.
73 |
- cfg (:obj:`Union[EasyDict, dict]`): The configuration of PPO (offpolicy) algorithm, which is a dict. \
74 |
Default to None. If not specified, the default configuration will be used. \
75 |
The default configuration can be found in ``ding/config/example/PPO (offpolicy)/``.
76 |
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
77 |
If specified, the policy will be loaded from this file. Default to None.
78 |
79 |
.. note::
80 |
An RL Agent Instance can be initialized in two basic ways. \
81 |
For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
82 |
and we want to train an agent with PPO (offpolicy) algorithm with default configuration. \
83 |
Then we can initialize the agent in the following ways:
84 |
>>> agent = PPOOffPolicyAgent(env_id='LunarLander-v2')
85 |
or, if we want to specify the env_id in the configuration:
86 |
>>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
87 |
>>> agent = PPOOffPolicyAgent(cfg=cfg)
88 |
There are also other arguments to specify the agent when initializing.
89 |
For example, if we want to specify the environment instance:
90 |
>>> env = CustomizedEnv('LunarLander-v2')
91 |
>>> agent = PPOOffPolicyAgent(cfg=cfg, env=env)
92 |
or, if we want to specify the model:
93 |
>>> model = VAC(**cfg.policy.model)
94 |
>>> agent = PPOOffPolicyAgent(cfg=cfg, model=model)
95 |
or, if we want to reload the policy from a saved policy state dict:
96 |
>>> agent = PPOOffPolicyAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
97 |
Make sure that the configuration is consistent with the saved policy state dict.
98 |
99 |
100 |
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
101 |
102 |
if cfg is not None and not isinstance(cfg, EasyDict):
103 |
cfg = EasyDict(cfg)
104 |
105 |
if env_id is not None:
106 |
assert env_id in PPOOffPolicyAgent.supported_env_list, "Please use supported envs: {}".format(
107 |
108 |
109 |
if cfg is None:
110 |
cfg = supported_env_cfg[env_id]
111 |
112 |
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
113 |
114 |
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
115 |
assert cfg.env.env_id in PPOOffPolicyAgent.supported_env_list, "Please use supported envs: {}".format(
116 |
117 |
118 |
default_policy_config = EasyDict({"policy": PPOOffPolicy.default_config()})
119 |
120 |
cfg = default_policy_config
121 |
122 |
if exp_name is not None:
123 |
cfg.exp_name = exp_name
124 |
self.cfg = compile_config(cfg, policy=PPOOffPolicy)
125 |
self.exp_name = self.cfg.exp_name
126 |
if env is None:
127 |
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
128 |
129 |
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
130 |
self.env = env
131 |
132 |
133 |
self.seed = seed
134 |
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
135 |
if not os.path.exists(self.exp_name):
136 |
137 |
save_config_py(self.cfg, os.path.join(self.exp_name, ''))
138 |
if model is None:
139 |
model = VAC(**self.cfg.policy.model)
140 |
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
141 |
self.policy = PPOOffPolicy(self.cfg.policy, model=model)
142 |
if policy_state_dict is not None:
143 |
144 |
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
145 |
146 |
def train(
147 |
148 |
step: int = int(1e7),
149 |
collector_env_num: int = None,
150 |
evaluator_env_num: int = None,
151 |
n_iter_save_ckpt: int = 1000,
152 |
context: Optional[str] = None,
153 |
debug: bool = False,
154 |
wandb_sweep: bool = False,
155 |
) -> TrainingReturn:
156 |
157 |
158 |
Train the agent with PPO (offpolicy) algorithm for ``step`` environment steps with ``collector_env_num`` \
159 |
collector environments and ``evaluator_env_num`` evaluator environments. \
160 |
Information during training will be recorded and saved by wandb.
161 |
162 |
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
163 |
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
164 |
If not specified, it will be set according to the configuration.
165 |
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
166 |
If not specified, it will be set according to the configuration.
167 |
- n_iter_save_ckpt (:obj:`int`): The checkpoint saving frequency, in training iterations. \
168 |
Default to 1000.
169 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
170 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
171 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
172 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
173 |
subprocess environment manager will be used.
174 |
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
175 |
which is a hyper-parameter optimization process for seeking the best configurations. \
176 |
Default to False. If True, the wandb sweep id will be used as the experiment name.
177 |
178 |
- (:obj:`TrainingReturn`): The training result, of which the attributes are:
179 |
- wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
180 |
181 |
182 |
if debug:
183 |
184 |
185 |
# define env and policy
186 |
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
187 |
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
188 |
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
189 |
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
190 |
191 |
with task.start(ctx=OnlineRLContext()):
192 |
193 |
194 |
195 |
196 |
197 |
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
198 |
199 |
200 |
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
201 |
202 |
203 |
204 |
205 |
206 |
207 |
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
208 |
209 |
210 |
task.use(gae_estimator(self.cfg, self.policy.collect_mode, self.buffer_))
211 |
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
return TrainingReturn(wandb_url=task.ctx.wandb_url)
228 |
229 |
def deploy(
230 |
231 |
enable_save_replay: bool = False,
232 |
concatenate_all_replay: bool = False,
233 |
replay_save_path: str = None,
234 |
seed: Optional[Union[int, List]] = None,
235 |
debug: bool = False
236 |
) -> EvalReturn:
237 |
238 |
239 |
Deploy the agent with PPO (offpolicy) algorithm by interacting with the environment, \
240 |
during which the replay video can be saved if ``enable_save_replay`` is True. \
241 |
The evaluation result will be returned.
242 |
243 |
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
244 |
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
245 |
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
246 |
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
247 |
the replay video of each episode will be saved separately.
248 |
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
249 |
If not specified, the video will be saved in ``exp_name/videos``.
250 |
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
251 |
Default to None. If not specified, ``self.seed`` will be used. \
252 |
If ``seed`` is an integer, the agent will be deployed once. \
253 |
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
254 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
255 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
256 |
subprocess environment manager will be used.
257 |
258 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
259 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
260 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
261 |
262 |
263 |
if debug:
264 |
265 |
# define env and policy
266 |
env = self.env.clone(caller='evaluator')
267 |
268 |
if seed is not None and isinstance(seed, int):
269 |
seeds = [seed]
270 |
elif seed is not None and isinstance(seed, list):
271 |
seeds = seed
272 |
273 |
seeds = [self.seed]
274 |
275 |
returns = []
276 |
images = []
277 |
if enable_save_replay:
278 |
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
279 |
280 |
281 |
logging.warning('No video would be generated during the deploy.')
282 |
if concatenate_all_replay:
283 |
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
284 |
concatenate_all_replay = False
285 |
286 |
def single_env_forward_wrapper(forward_fn, cuda=True):
287 |
288 |
if self.cfg.policy.action_space == 'discrete':
289 |
forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
290 |
elif self.cfg.policy.action_space == 'continuous':
291 |
forward_fn = model_wrap(forward_fn, wrapper_name='deterministic_sample').forward
292 |
elif self.cfg.policy.action_space == 'hybrid':
293 |
forward_fn = model_wrap(forward_fn, wrapper_name='hybrid_deterministic_argmax_sample').forward
294 |
elif self.cfg.policy.action_space == 'general':
295 |
forward_fn = model_wrap(forward_fn, wrapper_name='base').forward
296 |
297 |
raise NotImplementedError
298 |
299 |
def _forward(obs):
300 |
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
301 |
obs = ttorch.as_tensor(obs).unsqueeze(0)
302 |
if cuda and torch.cuda.is_available():
303 |
obs = obs.cuda()
304 |
action = forward_fn(obs, mode='compute_actor')["action"]
305 |
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
306 |
action = action.squeeze(0).detach().cpu().numpy()
307 |
return action
308 |
309 |
return _forward
310 |
311 |
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
312 |
313 |
# reset first to make sure the env is in the initial state
314 |
# env will be reset again in the main loop
315 |
316 |
317 |
for seed in seeds:
318 |
env.seed(seed, dynamic_seed=False)
319 |
return_ = 0.
320 |
step = 0
321 |
obs = env.reset()
322 |
images.append(render(env)[None]) if concatenate_all_replay else None
323 |
while True:
324 |
action = forward_fn(obs)
325 |
obs, rew, done, info = env.step(action)
326 |
images.append(render(env)[None]) if concatenate_all_replay else None
327 |
return_ += rew
328 |
step += 1
329 |
if done:
330 |
331 |
f'PPO (offpolicy) deploy is finished, final episode return with {step} steps is: {return_}')
332 |
333 |
334 |
335 |
336 |
if concatenate_all_replay:
337 |
images = np.concatenate(images, axis=0)
338 |
import imageio
339 |
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
340 |
341 |
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
342 |
343 |
def collect_data(
344 |
345 |
env_num: int = 8,
346 |
save_data_path: Optional[str] = None,
347 |
n_sample: Optional[int] = None,
348 |
n_episode: Optional[int] = None,
349 |
context: Optional[str] = None,
350 |
debug: bool = False
351 |
) -> None:
352 |
353 |
354 |
Collect data with PPO (offpolicy) algorithm for ``n_episode`` episodes \
355 |
with ``env_num`` collector environments. \
356 |
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
357 |
358 |
359 |
- env_num (:obj:`int`): The number of collector environments. Default to 8.
360 |
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
361 |
If not specified, the data will be saved in ``exp_name/demo_data``.
362 |
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
363 |
If not specified, ``n_episode`` must be specified.
364 |
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
365 |
If not specified, ``n_sample`` must be specified.
366 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
367 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
368 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
369 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
370 |
subprocess environment manager will be used.
371 |
372 |
373 |
if debug:
374 |
375 |
if n_episode is not None:
376 |
raise NotImplementedError
377 |
# define env and policy
378 |
env_num = env_num if env_num else self.cfg.env.collector_env_num
379 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
380 |
381 |
if save_data_path is None:
382 |
save_data_path = os.path.join(self.exp_name, 'demo_data')
383 |
384 |
# main execution task
385 |
with task.start(ctx=OnlineRLContext()):
386 |
387 |
388 |
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
389 |
390 |
391 |
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
392 |
393 |
394 |
f'PPOOffPolicy collecting is finished, more than {n_sample} \
395 |
samples are collected and saved in `{save_data_path}`'
396 |
397 |
398 |
def batch_evaluate(
399 |
400 |
env_num: int = 4,
401 |
n_evaluator_episode: int = 4,
402 |
context: Optional[str] = None,
403 |
debug: bool = False
404 |
) -> EvalReturn:
405 |
406 |
407 |
Evaluate the agent with PPO (offpolicy) algorithm for ``n_evaluator_episode`` episodes \
408 |
with ``env_num`` evaluator environments. The evaluation result will be returned.
409 |
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
410 |
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
411 |
will only create one evaluator environment to evaluate the agent and save the replay video.
412 |
413 |
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
414 |
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
415 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
416 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
417 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
418 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
419 |
subprocess environment manager will be used.
420 |
421 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
422 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
423 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
424 |
425 |
426 |
if debug:
427 |
428 |
# define env and policy
429 |
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
430 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
431 |
432 |
# reset first to make sure the env is in the initial state
433 |
# env will be reset again in the main loop
434 |
435 |
436 |
437 |
evaluate_cfg = self.cfg
438 |
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
439 |
440 |
# main execution task
441 |
with task.start(ctx=OnlineRLContext()):
442 |
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
443 |
444 |
445 |
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
446 |
447 |
448 |
def best(self) -> 'PPOOffPolicyAgent':
449 |
450 |
451 |
Load the best model from the checkpoint directory, \
452 |
which is by default the file ``exp_name/ckpt/eval.pth.tar``. \
453 |
The return value is the agent with the best model.
454 |
455 |
- (:obj:`PPOOffPolicyAgent`): The agent with the best model.
456 |
457 |
>>> agent = PPOOffPolicyAgent(env_id='LunarLander-v2')
458 |
>>> agent.train()
459 |
460 |
461 |
.. note::
462 |
The best model is the model with the highest evaluation return. If this method is called, the current \
463 |
model will be replaced by the best model.
464 |
465 |
466 |
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
467 |
# Load best model if it exists
468 |
if os.path.exists(best_model_file_path):
469 |
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
470 |
471 |
return self
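The `deploy` method above accepts `seed` as an integer, a list of integers, or `None`, runs one episode per seed, and reports the mean and standard deviation of the episode returns. The sketch below isolates that logic in plain Python. The helper names `normalize_seeds` and `summarize` are hypothetical and not part of DI-engine; `statistics.pstdev` is used as a stand-in for `np.std`, on the assumption that the population standard deviation (NumPy's default, `ddof=0`) is what the original computes.

```python
from statistics import fmean, pstdev
from typing import List, Optional, Tuple, Union


def normalize_seeds(seed: Optional[Union[int, List[int]]], default_seed: int) -> List[int]:
    """Mirror the int / list / fallback branches used when deploying."""
    if isinstance(seed, int):
        return [seed]  # a single seed means a single deployment episode
    if isinstance(seed, list):
        return seed  # one deployment episode per listed seed
    return [default_seed]  # None (or any other type) falls back to the agent's seed


def summarize(returns: List[float]) -> Tuple[float, float]:
    """Return (mean, population standard deviation) of the episode returns."""
    return fmean(returns), pstdev(returns)
```

For example, `normalize_seeds([1, 2], default_seed=0)` yields two episodes, and `summarize` then condenses their returns into the two values carried by `EvalReturn`.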
@@ -0,0 +1,509 @@
1 |
from typing import Optional, Union, List
2 |
from ditk import logging
3 |
from easydict import EasyDict
4 |
from functools import partial
5 |
import os
6 |
import gym
7 |
import gymnasium
8 |
import numpy as np
9 |
import torch
10 |
from ding.framework import task, OnlineRLContext
11 |
from ding.framework.middleware import interaction_evaluator_ttorch, PPOFStepCollector, multistep_trainer, CkptSaver, \
12 |
wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator
13 |
from ding.envs import BaseEnv, BaseEnvManagerV2, SubprocessEnvManagerV2
14 |
from ding.policy import PPOFPolicy, single_env_forward_wrapper_ttorch
15 |
from ding.utils import set_pkg_seed
16 |
from ding.utils import get_env_fps, render
17 |
from ding.config import save_config_py
18 |
from .model import PPOFModel
19 |
from .config import get_instance_config, get_instance_env, get_hybrid_shape
20 |
from ding.bonus.common import TrainingReturn, EvalReturn
21 |
22 |
23 |
class PPOF:
24 |
25 |
26 |
Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
27 |
Proximal Policy Optimization (PPO).
28 |
For more information about the system design of RL agent, please refer to \
29 |
30 |
31 |
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
32 |
33 |
34 |
supported_env_list = [
35 |
# common
36 |
37 |
38 |
39 |
40 |
41 |
# ch2: action
42 |
43 |
44 |
45 |
# ch3: obs
46 |
47 |
48 |
49 |
50 |
# ch4: reward
51 |
52 |
53 |
# atari
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
# mujoco
62 |
63 |
64 |
65 |
66 |
67 |
68 |
List of supported envs.
69 |
70 |
>>> from ding.bonus.ppof import PPOF
71 |
>>> print(PPOF.supported_env_list)
72 |
73 |
74 |
def __init__(
75 |
76 |
env_id: str = None,
77 |
env: BaseEnv = None,
78 |
seed: int = 0,
79 |
exp_name: str = None,
80 |
model: Optional[torch.nn.Module] = None,
81 |
cfg: Optional[Union[EasyDict, dict]] = None,
82 |
policy_state_dict: str = None
83 |
) -> None:
84 |
85 |
86 |
Initialize agent for PPO algorithm.
87 |
88 |
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
89 |
If ``env_id`` is not specified, ``env_id`` in ``cfg`` must be specified. \
90 |
If ``env_id`` is specified, ``env_id`` in ``cfg`` will be ignored. \
91 |
``env_id`` should be one of the supported envs, which can be found in ``PPOF.supported_env_list``.
92 |
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
93 |
If ``env`` is not specified, ``env_id`` or ``cfg.env_id`` must be specified. \
94 |
``env_id`` or ``cfg.env_id`` will be used to create environment instance. \
95 |
If ``env`` is specified, ``env_id`` and ``cfg.env_id`` will be ignored.
96 |
- seed (:obj:`int`): The random seed, which is set before running the program. \
97 |
Default to 0.
98 |
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
99 |
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
100 |
- model (:obj:`torch.nn.Module`): The model of PPO algorithm, which should be an instance of class \
101 |
``ding.model.PPOFModel``. \
102 |
If not specified, a default model will be generated according to the configuration.
103 |
- cfg (:obj:`Union[EasyDict, dict]`): The configuration of PPO algorithm, which is a dict. \
104 |
Default to None. If not specified, the default configuration will be used.
105 |
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
106 |
If specified, the policy will be loaded from this file. Default to None.
107 |
108 |
.. note::
109 |
An RL Agent Instance can be initialized in two basic ways. \
110 |
For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
111 |
and we want to train an agent with PPO algorithm with default configuration. \
112 |
Then we can initialize the agent in the following ways:
113 |
>>> agent = PPOF(env_id='LunarLander-v2')
114 |
or, if we want to specify the env_id in the configuration:
115 |
>>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
116 |
>>> agent = PPOF(cfg=cfg)
117 |
There are also other arguments to specify the agent when initializing.
118 |
For example, if we want to specify the environment instance:
119 |
>>> env = CustomizedEnv('LunarLander-v2')
120 |
>>> agent = PPOF(cfg=cfg, env=env)
121 |
or, if we want to specify the model:
122 |
>>> model = VAC(**cfg.policy.model)
123 |
>>> agent = PPOF(cfg=cfg, model=model)
124 |
or, if we want to reload the policy from a saved policy state dict:
125 |
>>> agent = PPOF(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
126 |
Make sure that the configuration is consistent with the saved policy state dict.
127 |
128 |
129 |
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
130 |
131 |
if cfg is not None and not isinstance(cfg, EasyDict):
132 |
cfg = EasyDict(cfg)
133 |
134 |
if env_id is not None:
135 |
assert env_id in PPOF.supported_env_list, "Please use supported envs: {}".format(PPOF.supported_env_list)
136 |
if cfg is None:
137 |
cfg = get_instance_config(env_id, algorithm="PPOF")
138 |
139 |
if not hasattr(cfg, "env_id"):
140 |
cfg.env_id = env_id
141 |
assert cfg.env_id == env_id, "env_id in cfg should be the same as env_id in args."
142 |
143 |
assert hasattr(cfg, "env_id"), "Please specify env_id in cfg."
144 |
assert cfg.env_id in PPOF.supported_env_list, "Please use supported envs: {}".format(
145 |
146 |
147 |
148 |
if exp_name is not None:
149 |
cfg.exp_name = exp_name
150 |
elif not hasattr(cfg, "exp_name"):
151 |
cfg.exp_name = "{}-{}".format(cfg.env_id, "PPO")
152 |
self.cfg = cfg
153 |
self.exp_name = self.cfg.exp_name
154 |
155 |
if env is None:
156 |
self.env = get_instance_env(self.cfg.env_id)
157 |
158 |
self.env = env
159 |
160 |
161 |
self.seed = seed
162 |
set_pkg_seed(self.seed, use_cuda=self.cfg.cuda)
163 |
164 |
if not os.path.exists(self.exp_name):
165 |
166 |
save_config_py(self.cfg, os.path.join(self.exp_name, ''))
167 |
168 |
action_space = self.env.action_space
169 |
if isinstance(action_space, (gym.spaces.Discrete, gymnasium.spaces.Discrete)):
170 |
action_shape = int(action_space.n)
171 |
elif isinstance(action_space, (gym.spaces.Tuple, gymnasium.spaces.Tuple)):
172 |
action_shape = get_hybrid_shape(action_space)
173 |
174 |
action_shape = action_space.shape
175 |
176 |
# The following value normalization options are currently supported
177 |
assert self.cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline']
178 |
if model is None:
179 |
if self.cfg.value_norm != 'popart':
180 |
model = PPOFModel(
181 |
182 |
183 |
184 |
185 |
186 |
187 |
model = PPOFModel(
188 |
189 |
190 |
191 |
192 |
193 |
194 |
self.policy = PPOFPolicy(self.cfg, model=model)
195 |
if policy_state_dict is not None:
196 |
197 |
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
198 |
199 |
def train(
200 |
201 |
step: int = int(1e7),
202 |
collector_env_num: int = 4,
203 |
evaluator_env_num: int = 4,
204 |
n_iter_log_show: int = 500,
205 |
n_iter_save_ckpt: int = 1000,
206 |
context: Optional[str] = None,
207 |
reward_model: Optional[str] = None,
208 |
debug: bool = False,
209 |
wandb_sweep: bool = False,
210 |
) -> TrainingReturn:
211 |
212 |
213 |
Train the agent with PPO algorithm for ``step`` environment steps with ``collector_env_num`` collector \
214 |
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
215 |
recorded and saved by wandb.
216 |
217 |
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
218 |
- collector_env_num (:obj:`int`): The number of collector environments. Default to 4.
219 |
- evaluator_env_num (:obj:`int`): The number of evaluator environments. Default to 4.
220 |
- n_iter_log_show (:obj:`int`): The logging frequency, in training iterations. Default to 500.
221 |
- n_iter_save_ckpt (:obj:`int`): The checkpoint saving frequency, in training iterations. \
222 |
Default to 1000.
223 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
224 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
225 |
- reward_model (:obj:`str`): The reward model name. Default to None. This argument is not supported yet.
226 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
227 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
228 |
subprocess environment manager will be used.
229 |
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
230 |
which is a hyper-parameter optimization process for seeking the best configurations. \
231 |
Default to False. If True, the wandb sweep id will be used as the experiment name.
232 |
233 |
- (:obj:`TrainingReturn`): The training result, of which the attributes are:
234 |
- wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
235 |
236 |
237 |
if debug:
238 |
239 |
240 |
# define env and policy
241 |
collector_env = self._setup_env_manager(collector_env_num, context, debug, 'collector')
242 |
evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug, 'evaluator')
243 |
244 |
if reward_model is not None:
245 |
# self.reward_model = create_reward_model(reward_model, self.cfg.reward_model)
246 |
247 |
248 |
with task.start(ctx=OnlineRLContext()):
249 |
task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
250 |
task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
251 |
task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
252 |
253 |
task.use(multistep_trainer(self.policy, log_freq=n_iter_log_show))
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
return TrainingReturn(wandb_url=task.ctx.wandb_url)
267 |
268 |
def deploy(
269 |
270 |
enable_save_replay: bool = False,
271 |
concatenate_all_replay: bool = False,
272 |
replay_save_path: str = None,
273 |
seed: Optional[Union[int, List]] = None,
274 |
debug: bool = False
275 |
) -> EvalReturn:
276 |
277 |
278 |
Deploy the agent with PPO algorithm by interacting with the environment, during which the replay video \
279 |
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
280 |
281 |
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
282 |
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
283 |
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
284 |
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
285 |
the replay video of each episode will be saved separately.
286 |
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
287 |
If not specified, the video will be saved in ``exp_name/videos``.
288 |
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
289 |
Default to None. If not specified, ``self.seed`` will be used. \
290 |
If ``seed`` is an integer, the agent will be deployed once. \
291 |
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
292 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
293 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
294 |
subprocess environment manager will be used.
295 |
296 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
297 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
298 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
299 |
300 |
301 |
if debug:
302 |
303 |
# define env and policy
304 |
env = self.env.clone(caller='evaluator')
305 |
306 |
if seed is not None and isinstance(seed, int):
307 |
seeds = [seed]
308 |
elif seed is not None and isinstance(seed, list):
309 |
seeds = seed
310 |
311 |
seeds = [self.seed]
312 |
313 |
returns = []
314 |
images = []
315 |
if enable_save_replay:
316 |
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
317 |
318 |
319 |
logging.warning('No video would be generated during the deploy.')
320 |
if concatenate_all_replay:
321 |
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
322 |
concatenate_all_replay = False
323 |
324 |
forward_fn = single_env_forward_wrapper_ttorch(self.policy.eval, self.cfg.cuda)
325 |
326 |
# reset first to make sure the env is in the initial state
327 |
# env will be reset again in the main loop
328 |
329 |
330 |
for seed in seeds:
331 |
env.seed(seed, dynamic_seed=False)
332 |
return_ = 0.
333 |
step = 0
334 |
obs = env.reset()
335 |
images.append(render(env)[None]) if concatenate_all_replay else None
336 |
while True:
337 |
action = forward_fn(obs)
338 |
obs, rew, done, info = env.step(action)
339 |
images.append(render(env)[None]) if concatenate_all_replay else None
340 |
return_ += rew
341 |
step += 1
342 |
if done:
343 |
344 |
logging.info(f'PPOF deploy is finished, final episode return with {step} steps is: {return_}')
345 |
346 |
347 |
348 |
349 |
if concatenate_all_replay:
350 |
images = np.concatenate(images, axis=0)
351 |
import imageio
352 |
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
353 |
354 |
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
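A minimal sketch of the seed handling and result aggregation used by ``deploy`` above, written with the standard library only. The helper names ``normalize_seeds`` and ``summarize_returns`` are hypothetical and not part of the agent class. ``pstdev`` is used because ``np.std`` computes the population standard deviation by default.

```python
from statistics import fmean, pstdev
from typing import List, Optional, Sequence, Tuple, Union


def normalize_seeds(seed: Optional[Union[int, List[int]]], default_seed: int) -> List[int]:
    """Return the list of seeds to run; one episode is played per seed."""
    if isinstance(seed, int):
        return [seed]
    if isinstance(seed, list):
        return list(seed)
    # seed is None (or an unsupported type): fall back to the agent's own seed
    return [default_seed]


def summarize_returns(returns: Sequence[float]) -> Tuple[float, float]:
    """Mean and population standard deviation of the episode returns."""
    return fmean(returns), pstdev(returns)


if __name__ == "__main__":
    print(normalize_seeds(None, default_seed=0))
    print(summarize_returns([1.0, 3.0]))
```

Passing a list of seeds therefore yields one return per seed, and the reported ``eval_value_std`` is zero whenever only a single seed is used.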
355 |
356 |
def collect_data(
357 |
358 |
env_num: int = 8,
359 |
save_data_path: Optional[str] = None,
360 |
n_sample: Optional[int] = None,
361 |
n_episode: Optional[int] = None,
362 |
context: Optional[str] = None,
363 |
debug: bool = False
364 |
) -> None:
365 |
366 |
367 |
Collect ``n_sample`` samples with PPO algorithm using ``env_num`` collector environments \
(collecting by ``n_episode`` is not implemented yet). \
368 |
The collected data will be saved in ``save_data_path`` if specified; otherwise it will be saved in \
``exp_name/demo_data``.
369 |
370 |
371 |
- env_num (:obj:`int`): The number of collector environments. Default to 8.
372 |
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
373 |
If not specified, the data will be saved in ``exp_name/demo_data``.
374 |
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
375 |
If not specified, ``n_episode`` must be specified.
376 |
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
377 |
If not specified, ``n_sample`` must be specified.
378 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
379 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
380 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
381 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
382 |
subprocess environment manager will be used.
383 |
384 |
385 |
if debug:
386 |
387 |
if n_episode is not None:
388 |
raise NotImplementedError
389 |
# define env and policy
390 |
env = self._setup_env_manager(env_num, context, debug, 'collector')
391 |
if save_data_path is None:
392 |
save_data_path = os.path.join(self.exp_name, 'demo_data')
393 |
394 |
# main execution task
395 |
with task.start(ctx=OnlineRLContext()):
396 |
task.use(PPOFStepCollector(self.seed, self.policy, env, n_sample))
397 |
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
398 |
399 |
400 |
f'PPOF collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
401 |
402 |
403 |
def batch_evaluate(
404 |
405 |
env_num: int = 4,
406 |
n_evaluator_episode: int = 4,
407 |
context: Optional[str] = None,
408 |
debug: bool = False,
409 |
) -> EvalReturn:
410 |
411 |
412 |
Evaluate the agent with PPO algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
413 |
environments. The evaluation result will be returned.
414 |
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
415 |
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
416 |
will only create one evaluator environment to evaluate the agent and save the replay video.
417 |
418 |
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
419 |
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
420 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
421 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
422 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
423 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
424 |
subprocess environment manager will be used.
425 |
426 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
427 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
428 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
429 |
430 |
431 |
if debug:
432 |
433 |
# define env and policy
434 |
env = self._setup_env_manager(env_num, context, debug, 'evaluator')
435 |
436 |
# reset first to make sure the env is in the initial state
437 |
# env will be reset again in the main loop
438 |
439 |
440 |
441 |
# main execution task
442 |
with task.start(ctx=OnlineRLContext()):
443 |
444 |
445 |
446 |
447 |
448 |
449 |
450 |
451 |
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
452 |
453 |
def _setup_env_manager(
454 |
455 |
env_num: int,
456 |
context: Optional[str] = None,
457 |
debug: bool = False,
458 |
caller: str = 'collector'
459 |
) -> BaseEnvManagerV2:
460 |
461 |
462 |
Set up the environment manager, which is used to manage multiple environments.
463 |
464 |
- env_num (:obj:`int`): The number of environments.
465 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
466 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
467 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
468 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
469 |
subprocess environment manager will be used.
470 |
- caller (:obj:`str`): The caller of the environment manager. Default to 'collector'.
471 |
472 |
- (:obj:`BaseEnvManagerV2`): The environment manager.
473 |
474 |
assert caller in ['evaluator', 'collector']
475 |
if debug:
476 |
env_cls = BaseEnvManagerV2
477 |
manager_cfg = env_cls.default_config()
478 |
479 |
env_cls = SubprocessEnvManagerV2
480 |
manager_cfg = env_cls.default_config()
481 |
if context is not None:
482 |
manager_cfg.context = context
483 |
return env_cls([partial(self.env.clone, caller) for _ in range(env_num)], manager_cfg)
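The selection logic of ``_setup_env_manager`` can be sketched as follows. The two manager classes and ``make_env`` are hypothetical stand-ins, not the real ``BaseEnvManagerV2`` / ``SubprocessEnvManagerV2``, and the sketch assumes that ``context`` only applies to the subprocess manager (the indentation of the original is not visible here).

```python
from functools import partial
from typing import Callable, Dict, List, Optional


class FakeBaseManager:
    """Hypothetical in-process manager, used when debug=True."""

    def __init__(self, env_fns: List[Callable[[], str]], cfg: Dict[str, str]) -> None:
        self.env_fns = env_fns
        self.cfg = cfg


class FakeSubprocessManager(FakeBaseManager):
    """Hypothetical subprocess-based manager, used when debug=False."""


def make_env(caller: str) -> str:
    # Stand-in for self.env.clone(caller); returns a label instead of an env.
    return f"env-for-{caller}"


def setup_env_manager(
        env_num: int, context: Optional[str] = None, debug: bool = False, caller: str = 'collector'
) -> FakeBaseManager:
    assert caller in ('evaluator', 'collector')
    manager_cls = FakeBaseManager if debug else FakeSubprocessManager
    cfg: Dict[str, str] = {}
    if not debug and context is not None:
        cfg['context'] = context  # e.g. 'spawn', 'fork' or 'forkserver'
    # One lazily evaluated constructor per environment.
    env_fns = [partial(make_env, caller) for _ in range(env_num)]
    return manager_cls(env_fns, cfg)
```

Each entry of ``env_fns`` is a zero-argument callable, so the manager decides when and in which process each environment is actually created.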
484 |
485 |
486 |
def best(self) -> 'PPOF':
487 |
488 |
489 |
Load the best model from the checkpoint directory, \
490 |
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
491 |
The return value is the agent with the best model.
492 |
493 |
- (:obj:`PPOF`): The agent with the best model.
494 |
495 |
>>> agent = PPOF(env_id='LunarLander-v2')
496 |
>>> agent.train()
497 |
>>> agent = agent.best()
498 |
499 |
.. note::
500 |
The best model is the model with the highest evaluation return. If this method is called, the current \
501 |
model will be replaced by the best model.
502 |
503 |
504 |
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
505 |
# Load best model if it exists
506 |
if os.path.exists(best_model_file_path):
507 |
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
508 |
509 |
return self
@@ -0,0 +1,457 @@
1 |
from typing import Optional, Union, List
2 |
from ditk import logging
3 |
from easydict import EasyDict
4 |
import os
5 |
import numpy as np
6 |
import torch
7 |
import treetensor.torch as ttorch
8 |
from ding.framework import task, OnlineRLContext
9 |
from ding.framework.middleware import CkptSaver, \
10 |
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
11 |
OffPolicyLearner, final_ctx_saver
12 |
from ding.envs import BaseEnv
13 |
from ding.envs import setup_ding_env_manager
14 |
from ding.policy import SACPolicy
15 |
from ding.utils import set_pkg_seed
16 |
from ding.utils import get_env_fps, render
17 |
from ding.config import save_config_py, compile_config
18 |
from ding.model import ContinuousQAC
19 |
from ding.model import model_wrap
20 |
from ding.data import DequeBuffer
21 |
from ding.bonus.common import TrainingReturn, EvalReturn
22 |
from ding.config.example.SAC import supported_env_cfg
23 |
from ding.config.example.SAC import supported_env
24 |
25 |
26 |
class SACAgent:
27 |
28 |
29 |
Agent class for training, evaluation and deployment of the reinforcement learning algorithm \
30 |
Soft Actor-Critic (SAC).
31 |
For more information about the system design of RL agent, please refer to \
32 |
33 |
34 |
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
35 |
36 |
supported_env_list = list(supported_env_cfg.keys())
37 |
38 |
39 |
List of supported envs.
40 |
41 |
>>> from ding.bonus.sac import SACAgent
42 |
>>> print(SACAgent.supported_env_list)
43 |
44 |
45 |
def __init__(
46 |
47 |
env_id: str = None,
48 |
env: BaseEnv = None,
49 |
seed: int = 0,
50 |
exp_name: str = None,
51 |
model: Optional[torch.nn.Module] = None,
52 |
cfg: Optional[Union[EasyDict, dict]] = None,
53 |
policy_state_dict: str = None,
54 |
) -> None:
55 |
56 |
57 |
Initialize agent for SAC algorithm.
58 |
59 |
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
60 |
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
61 |
If both ``env_id`` and ``cfg`` are specified, ``env_id`` in ``cfg.env`` must be the same as ``env_id``. \
62 |
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
63 |
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
64 |
If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
65 |
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
66 |
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
67 |
- seed (:obj:`int`): The random seed, which is set before running the program. \
68 |
Default to 0.
69 |
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
70 |
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
71 |
- model (:obj:`torch.nn.Module`): The model of SAC algorithm, which should be an instance of class \
72 |
:class:`ding.model.ContinuousQAC`. \
73 |
If not specified, a default model will be generated according to the configuration.
74 |
- cfg (:obj:`Union[EasyDict, dict]`): The configuration of SAC algorithm, which is a dict. \
75 |
Default to None. If not specified, the default configuration will be used. \
76 |
The default configuration can be found in ``ding/config/example/SAC/``.
77 |
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
78 |
If specified, the policy will be loaded from this file. Default to None.
79 |
80 |
.. note::
81 |
An RL Agent Instance can be initialized in two basic ways. \
82 |
For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
83 |
and we want to train an agent with SAC algorithm with default configuration. \
84 |
Then we can initialize the agent in the following ways:
85 |
>>> agent = SACAgent(env_id='LunarLanderContinuous-v2')
86 |
or, if we want, we can specify the env_id in the configuration:
87 |
>>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
88 |
>>> agent = SACAgent(cfg=cfg)
89 |
There are also other arguments to specify the agent when initializing.
90 |
For example, if we want to specify the environment instance:
91 |
>>> env = CustomizedEnv('LunarLanderContinuous-v2')
92 |
>>> agent = SACAgent(cfg=cfg, env=env)
93 |
or, if we want to specify the model:
94 |
>>> model = ContinuousQAC(**cfg.policy.model)
95 |
>>> agent = SACAgent(cfg=cfg, model=model)
96 |
or, if we want to reload the policy from a saved policy state dict:
97 |
>>> agent = SACAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
98 |
Make sure that the configuration is consistent with the saved policy state dict.
99 |
100 |
101 |
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
102 |
103 |
if cfg is not None and not isinstance(cfg, EasyDict):
104 |
cfg = EasyDict(cfg)
105 |
106 |
if env_id is not None:
107 |
assert env_id in SACAgent.supported_env_list, "Please use supported envs: {}".format(
108 |
109 |
110 |
if cfg is None:
111 |
cfg = supported_env_cfg[env_id]
112 |
113 |
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
114 |
115 |
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
116 |
assert cfg.env.env_id in SACAgent.supported_env_list, "Please use supported envs: {}".format(
117 |
118 |
119 |
default_policy_config = EasyDict({"policy": SACPolicy.default_config()})
120 |
121 |
cfg = default_policy_config
122 |
123 |
if exp_name is not None:
124 |
cfg.exp_name = exp_name
125 |
self.cfg = compile_config(cfg, policy=SACPolicy)
126 |
self.exp_name = self.cfg.exp_name
127 |
if env is None:
128 |
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
129 |
130 |
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
131 |
self.env = env
132 |
133 |
134 |
self.seed = seed
135 |
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
136 |
if not os.path.exists(self.exp_name):
137 |
138 |
save_config_py(self.cfg, os.path.join(self.exp_name, ''))
139 |
if model is None:
140 |
model = ContinuousQAC(**self.cfg.policy.model)
141 |
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
142 |
self.policy = SACPolicy(self.cfg.policy, model=model)
143 |
if policy_state_dict is not None:
144 |
145 |
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
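The way ``__init__`` reconciles ``env_id`` and ``cfg`` can be sketched with plain dicts. This is an illustration only: ``SUPPORTED_ENV_CFG`` is a hypothetical one-entry registry standing in for ``supported_env_cfg``, ``resolve_cfg`` is not a function of the library, and ``ValueError`` replaces the assertions of the original.

```python
from typing import Optional

# Hypothetical registry standing in for ``supported_env_cfg``.
SUPPORTED_ENV_CFG = {
    "LunarLanderContinuous-v2": {"env": {"env_id": "LunarLanderContinuous-v2"}},
}


def resolve_cfg(env_id: Optional[str] = None, cfg: Optional[dict] = None) -> dict:
    """Pick the configuration to use from ``env_id`` and/or ``cfg``."""
    if env_id is None and cfg is None:
        raise ValueError("Please specify env_id or cfg.")
    supported = list(SUPPORTED_ENV_CFG)
    if env_id is not None:
        if env_id not in SUPPORTED_ENV_CFG:
            raise ValueError(f"Please use supported envs: {supported}")
        if cfg is None:
            # Only env_id given: fall back to the registered default config.
            return SUPPORTED_ENV_CFG[env_id]
        # Both given: they have to agree.
        if cfg.get("env", {}).get("env_id") != env_id:
            raise ValueError("env_id in cfg should be the same as env_id in args.")
        return cfg
    # Only cfg given: it must name a supported environment itself.
    cfg_env_id = cfg.get("env", {}).get("env_id")
    if cfg_env_id is None:
        raise ValueError("Please specify env_id in cfg.")
    if cfg_env_id not in SUPPORTED_ENV_CFG:
        raise ValueError(f"Please use supported envs: {supported}")
    return cfg
```

Note that when both arguments are given, the code checks that the two ids match rather than silently preferring one of them.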
146 |
147 |
def train(
148 |
149 |
step: int = int(1e7),
150 |
collector_env_num: int = None,
151 |
evaluator_env_num: int = None,
152 |
n_iter_save_ckpt: int = 1000,
153 |
context: Optional[str] = None,
154 |
debug: bool = False,
155 |
wandb_sweep: bool = False,
156 |
) -> TrainingReturn:
157 |
158 |
159 |
Train the agent with SAC algorithm for ``step`` environment steps with ``collector_env_num`` collector \
160 |
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
161 |
recorded and saved by wandb.
162 |
163 |
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
164 |
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
165 |
If not specified, it will be set according to the configuration.
166 |
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
167 |
If not specified, it will be set according to the configuration.
168 |
- n_iter_save_ckpt (:obj:`int`): The interval, in training iterations, between two saved checkpoints. \
169 |
Default to 1000.
170 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
171 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
172 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
173 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
174 |
subprocess environment manager will be used.
175 |
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
176 |
which is a hyper-parameter optimization process for seeking the best configurations. \
177 |
Default to False. If True, the wandb sweep id will be used as the experiment name.
178 |
179 |
- (:obj:`TrainingReturn`): The training result, of which the attributes are:
180 |
- wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
181 |
182 |
183 |
if debug:
184 |
185 |
186 |
# define env and policy
187 |
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
188 |
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
189 |
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
190 |
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
191 |
192 |
with task.start(ctx=OnlineRLContext()):
193 |
194 |
195 |
196 |
197 |
198 |
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
199 |
200 |
201 |
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
202 |
203 |
204 |
205 |
206 |
207 |
208 |
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
209 |
210 |
211 |
task.use(data_pusher(self.cfg, self.buffer_))
212 |
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
return TrainingReturn(wandb_url=task.ctx.wandb_url)
227 |
228 |
def deploy(
229 |
230 |
enable_save_replay: bool = False,
231 |
concatenate_all_replay: bool = False,
232 |
replay_save_path: str = None,
233 |
seed: Optional[Union[int, List]] = None,
234 |
debug: bool = False
235 |
) -> EvalReturn:
236 |
237 |
238 |
Deploy the agent with SAC algorithm by interacting with the environment, during which the replay video \
239 |
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
240 |
241 |
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
242 |
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
243 |
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
244 |
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
245 |
the replay video of each episode will be saved separately.
246 |
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
247 |
If not specified, the video will be saved in ``exp_name/videos``.
248 |
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
249 |
Default to None. If not specified, ``self.seed`` will be used. \
250 |
If ``seed`` is an integer, the agent will be deployed once. \
251 |
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
252 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
253 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
254 |
subprocess environment manager will be used.
255 |
256 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
257 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
258 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
259 |
260 |
261 |
if debug:
262 |
263 |
# define env and policy
264 |
env = self.env.clone(caller='evaluator')
265 |
266 |
if seed is not None and isinstance(seed, int):
267 |
seeds = [seed]
268 |
elif seed is not None and isinstance(seed, list):
269 |
seeds = seed
270 |
271 |
seeds = [self.seed]
272 |
273 |
returns = []
274 |
images = []
275 |
if enable_save_replay:
276 |
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
277 |
278 |
279 |
logging.warning('No video would be generated during the deploy.')
280 |
if concatenate_all_replay:
281 |
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
282 |
concatenate_all_replay = False
283 |
284 |
def single_env_forward_wrapper(forward_fn, cuda=True):
285 |
286 |
forward_fn = model_wrap(forward_fn, wrapper_name='base').forward
287 |
288 |
def _forward(obs):
289 |
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
290 |
obs = ttorch.as_tensor(obs).unsqueeze(0)
291 |
if cuda and torch.cuda.is_available():
292 |
obs = obs.cuda()
293 |
(mu, sigma) = forward_fn(obs, mode='compute_actor')['logit']
294 |
action = torch.tanh(mu).detach().cpu().numpy()[0] # deterministic_eval
295 |
return action
296 |
297 |
return _forward
298 |
299 |
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
300 |
301 |
# reset first to make sure the env is in the initial state
302 |
# env will be reset again in the main loop
303 |
304 |
305 |
for seed in seeds:
306 |
env.seed(seed, dynamic_seed=False)
307 |
return_ = 0.
308 |
step = 0
309 |
obs = env.reset()
310 |
images.append(render(env)[None]) if concatenate_all_replay else None
311 |
while True:
312 |
action = forward_fn(obs)
313 |
obs, rew, done, info = env.step(action)
314 |
images.append(render(env)[None]) if concatenate_all_replay else None
315 |
return_ += rew
316 |
step += 1
317 |
if done:
318 |
319 |
logging.info(f'SAC deploy is finished, final episode return with {step} steps is: {return_}')
320 |
321 |
322 |
323 |
324 |
if concatenate_all_replay:
325 |
images = np.concatenate(images, axis=0)
326 |
import imageio
327 |
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
328 |
329 |
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
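The deterministic evaluation action used in ``_forward`` above is the ``tanh`` of the actor mean, which keeps every action component strictly inside ``(-1, 1)``. The sketch below shows this with plain Python floats instead of tensors. ``sampled_action`` is a hypothetical stochastic counterpart added only for contrast; it is not taken from the code above.

```python
import math
import random
from typing import List, Sequence


def deterministic_action(mu: Sequence[float]) -> List[float]:
    """Squash the actor mean with tanh, ignoring sigma (evaluation mode)."""
    return [math.tanh(m) for m in mu]


def sampled_action(mu: Sequence[float], sigma: Sequence[float], rng: random.Random) -> List[float]:
    """Hypothetical stochastic variant: add Gaussian noise before squashing."""
    return [math.tanh(m + s * rng.gauss(0.0, 1.0)) for m, s in zip(mu, sigma)]


if __name__ == "__main__":
    mu, sigma = [0.3, -1.2], [0.5, 0.1]
    print(deterministic_action(mu))
    print(sampled_action(mu, sigma, random.Random(0)))
```

With ``sigma`` set to zero the two functions coincide, which is why dropping the noise term gives a reproducible evaluation policy.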
330 |
331 |
def collect_data(
332 |
333 |
env_num: int = 8,
334 |
save_data_path: Optional[str] = None,
335 |
n_sample: Optional[int] = None,
336 |
n_episode: Optional[int] = None,
337 |
context: Optional[str] = None,
338 |
debug: bool = False
339 |
) -> None:
340 |
341 |
342 |
Collect ``n_sample`` samples with SAC algorithm using ``env_num`` collector environments \
(collecting by ``n_episode`` is not implemented yet). \
343 |
The collected data will be saved in ``save_data_path`` if specified; otherwise it will be saved in \
``exp_name/demo_data``.
344 |
345 |
346 |
- env_num (:obj:`int`): The number of collector environments. Default to 8.
347 |
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
348 |
If not specified, the data will be saved in ``exp_name/demo_data``.
349 |
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
350 |
If not specified, ``n_episode`` must be specified.
351 |
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
352 |
If not specified, ``n_sample`` must be specified.
353 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
354 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
355 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
356 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
357 |
subprocess environment manager will be used.
358 |
359 |
360 |
if debug:
361 |
362 |
if n_episode is not None:
363 |
raise NotImplementedError
364 |
# define env and policy
365 |
env_num = env_num if env_num else self.cfg.env.collector_env_num
366 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
367 |
368 |
if save_data_path is None:
369 |
save_data_path = os.path.join(self.exp_name, 'demo_data')
370 |
371 |
# main execution task
372 |
with task.start(ctx=OnlineRLContext()):
373 |
374 |
375 |
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
376 |
377 |
378 |
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
379 |
380 |
381 |
f'SAC collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
382 |
383 |
384 |
def batch_evaluate(
385 |
386 |
env_num: int = 4,
387 |
n_evaluator_episode: int = 4,
388 |
context: Optional[str] = None,
389 |
debug: bool = False
390 |
) -> EvalReturn:
391 |
392 |
393 |
Evaluate the agent with SAC algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
394 |
environments. The evaluation result will be returned.
395 |
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
396 |
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
397 |
will only create one evaluator environment to evaluate the agent and save the replay video.
398 |
399 |
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
400 |
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
401 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
402 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
403 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
404 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
405 |
subprocess environment manager will be used.
406 |
407 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
408 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
409 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
410 |
411 |
412 |
if debug:
413 |
414 |
# define env and policy
415 |
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
416 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
417 |
418 |
# reset first to make sure the env is in the initial state
419 |
# env will be reset again in the main loop
420 |
421 |
422 |
423 |
evaluate_cfg = self.cfg
424 |
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
425 |
426 |
# main execution task
427 |
with task.start(ctx=OnlineRLContext()):
428 |
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
429 |
430 |
431 |
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
432 |
433 |
434 |
def best(self) -> 'SACAgent':
435 |
436 |
437 |
Load the best model from the checkpoint directory, \
438 |
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
439 |
The return value is the agent with the best model.
440 |
441 |
- (:obj:`SACAgent`): The agent with the best model.
442 |
443 |
>>> agent = SACAgent(env_id='LunarLanderContinuous-v2')
444 |
>>> agent.train()
445 |
>>> agent = agent.best()
446 |
447 |
.. note::
448 |
The best model is the model with the highest evaluation return. If this method is called, the current \
449 |
model will be replaced by the best model.
450 |
451 |
452 |
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
453 |
# Load best model if it exists
454 |
if os.path.exists(best_model_file_path):
455 |
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
456 |
457 |
return self
@@ -0,0 +1,461 @@
1 |
from typing import Optional, Union, List
2 |
from ditk import logging
3 |
from easydict import EasyDict
4 |
import os
5 |
import numpy as np
6 |
import torch
7 |
import treetensor.torch as ttorch
8 |
from ding.framework import task, OnlineRLContext
9 |
from ding.framework.middleware import CkptSaver, \
10 |
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
11 |
OffPolicyLearner, final_ctx_saver, nstep_reward_enhancer, eps_greedy_handler
12 |
from ding.envs import BaseEnv
13 |
from ding.envs import setup_ding_env_manager
14 |
from ding.policy import SQLPolicy
15 |
from ding.utils import set_pkg_seed
16 |
from ding.utils import get_env_fps, render
17 |
from ding.config import save_config_py, compile_config
18 |
from ding.model import DQN
19 |
from ding.model import model_wrap
20 |
from ding.data import DequeBuffer
21 |
from ding.bonus.common import TrainingReturn, EvalReturn
22 |
from ding.config.example.SQL import supported_env_cfg
23 |
from ding.config.example.SQL import supported_env
24 |
25 |
26 |
class SQLAgent:
27 |
28 |
29 |
Agent class for training, evaluation and deployment of the reinforcement learning algorithm \
30 |
Soft Q-Learning (SQL).
31 |
For more information about the system design of RL agent, please refer to \
32 |
33 |
34 |
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
35 |
36 |
supported_env_list = list(supported_env_cfg.keys())
37 |
38 |
39 |
List of supported envs.
40 |
41 |
>>> from ding.bonus.sql import SQLAgent
42 |
>>> print(SQLAgent.supported_env_list)
43 |
44 |
45 |
def __init__(
46 |
47 |
env_id: str = None,
48 |
env: BaseEnv = None,
49 |
seed: int = 0,
50 |
exp_name: str = None,
51 |
model: Optional[torch.nn.Module] = None,
52 |
cfg: Optional[Union[EasyDict, dict]] = None,
53 |
policy_state_dict: str = None,
54 |
) -> None:
55 |
56 |
57 |
Initialize agent for SQL algorithm.
58 |
59 |
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
60 |
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
61 |
If both ``env_id`` and ``cfg`` are specified, ``env_id`` in ``cfg.env`` must be the same as ``env_id``. \
62 |
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
63 |
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
64 |
If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
65 |
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
66 |
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
67 |
- seed (:obj:`int`): The random seed, which is set before running the program. \
68 |
Default to 0.
69 |
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
70 |
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
71 |
- model (:obj:`torch.nn.Module`): The model of SQL algorithm, which should be an instance of class \
72 |
:class:`ding.model.DQN`. \
73 |
If not specified, a default model will be generated according to the configuration.
74 |
- cfg (:obj:`Union[EasyDict, dict]`): The configuration of SQL algorithm, which is a dict. \
75 |
Default to None. If not specified, the default configuration will be used. \
76 |
The default configuration can be found in ``ding/config/example/SQL/``.
77 |
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
78 |
If specified, the policy will be loaded from this file. Default to None.
79 |
80 |
.. note::
81 |
An RL Agent Instance can be initialized in two basic ways. \
82 |
For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
83 |
and we want to train an agent with SQL algorithm with default configuration. \
84 |
Then we can initialize the agent in the following ways:
85 |
>>> agent = SQLAgent(env_id='LunarLander-v2')
86 |
or, if we want, we can specify the env_id in the configuration:
87 |
>>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
88 |
>>> agent = SQLAgent(cfg=cfg)
89 |
There are also other arguments to specify the agent when initializing.
90 |
For example, if we want to specify the environment instance:
91 |
>>> env = CustomizedEnv('LunarLander-v2')
92 |
>>> agent = SQLAgent(cfg=cfg, env=env)
93 |
or, if we want to specify the model:
94 |
>>> model = DQN(**cfg.policy.model)
95 |
>>> agent = SQLAgent(cfg=cfg, model=model)
96 |
or, if we want to reload the policy from a saved policy state dict:
97 |
>>> agent = SQLAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
98 |
Make sure that the configuration is consistent with the saved policy state dict.
99 |
100 |
101 |
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
102 |
103 |
if cfg is not None and not isinstance(cfg, EasyDict):
104 |
cfg = EasyDict(cfg)
105 |
106 |
if env_id is not None:
107 |
assert env_id in SQLAgent.supported_env_list, "Please use supported envs: {}".format(
108 |
109 |
110 |
if cfg is None:
111 |
cfg = supported_env_cfg[env_id]
112 |
113 |
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
114 |
115 |
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
116 |
assert cfg.env.env_id in SQLAgent.supported_env_list, "Please use supported envs: {}".format(
117 |
118 |
119 |
default_policy_config = EasyDict({"policy": SQLPolicy.default_config()})
120 |
121 |
cfg = default_policy_config
122 |
123 |
if exp_name is not None:
124 |
cfg.exp_name = exp_name
125 |
self.cfg = compile_config(cfg, policy=SQLPolicy)
126 |
self.exp_name = self.cfg.exp_name
127 |
if env is None:
128 |
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
129 |
130 |
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
131 |
self.env = env
132 |
133 |
134 |
self.seed = seed
135 |
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
136 |
if not os.path.exists(self.exp_name):
137 |
138 |
save_config_py(self.cfg, os.path.join(self.exp_name, ''))
139 |
if model is None:
140 |
model = DQN(**self.cfg.policy.model)
141 |
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
142 |
self.policy = SQLPolicy(self.cfg.policy, model=model)
143 |
if policy_state_dict is not None:
144 |
145 |
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
146 |
147 |
def train(
148 |
149 |
step: int = int(1e7),
150 |
collector_env_num: int = None,
151 |
evaluator_env_num: int = None,
152 |
n_iter_save_ckpt: int = 1000,
153 |
context: Optional[str] = None,
154 |
debug: bool = False,
155 |
wandb_sweep: bool = False,
156 |
) -> TrainingReturn:
157 |
158 |
159 |
Train the agent with SQL algorithm for ``step`` environment steps with ``collector_env_num`` collector \
160 |
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
161 |
recorded and saved by wandb.
162 |
163 |
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
164 |
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
165 |
If not specified, it will be set according to the configuration.
166 |
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
167 |
If not specified, it will be set according to the configuration.
168 |
- n_iter_save_ckpt (:obj:`int`): The number of training iterations between two checkpoint saves. \
169 |
Default to 1000.
170 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
171 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
172 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
173 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
174 |
subprocess environment manager will be used.
175 |
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
176 |
which is a hyper-parameter optimization process for seeking the best configurations. \
177 |
Default to False. If True, the wandb sweep id will be used as the experiment name.
178 |
179 |
- (:obj:`TrainingReturn`): The training result, of which the attributes are:
180 |
- wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
181 |
182 |
183 |
if debug:
184 |
185 |
186 |
# define env and policy
187 |
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
188 |
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
189 |
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
190 |
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
191 |
192 |
with task.start(ctx=OnlineRLContext()):
193 |
194 |
195 |
196 |
197 |
198 |
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
199 |
200 |
201 |
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
210 |
211 |
212 |
if "nstep" in self.cfg.policy and self.cfg.policy.nstep > 1:
213 |
214 |
task.use(data_pusher(self.cfg, self.buffer_))
215 |
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
return TrainingReturn(wandb_url=task.ctx.wandb_url)
230 |
231 |
def deploy(
232 |
233 |
enable_save_replay: bool = False,
234 |
concatenate_all_replay: bool = False,
235 |
replay_save_path: str = None,
236 |
seed: Optional[Union[int, List]] = None,
237 |
debug: bool = False
238 |
) -> EvalReturn:
239 |
240 |
241 |
Deploy the agent with SQL algorithm by interacting with the environment, during which the replay video \
242 |
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
243 |
244 |
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
245 |
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
246 |
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
247 |
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
248 |
the replay video of each episode will be saved separately.
249 |
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
250 |
If not specified, the video will be saved in ``exp_name/videos``.
251 |
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
252 |
Default to None. If not specified, ``self.seed`` will be used. \
253 |
If ``seed`` is an integer, the agent will be deployed once. \
254 |
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
255 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
256 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
257 |
subprocess environment manager will be used.
258 |
259 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
260 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
261 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
262 |
263 |
264 |
if debug:
265 |
266 |
# define env and policy
267 |
env = self.env.clone(caller='evaluator')
268 |
269 |
if seed is not None and isinstance(seed, int):
270 |
seeds = [seed]
271 |
elif seed is not None and isinstance(seed, list):
272 |
seeds = seed
273 |
274 |
seeds = [self.seed]
275 |
276 |
returns = []
277 |
images = []
278 |
if enable_save_replay:
279 |
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
280 |
281 |
282 |
logging.warning('No video would be generated during the deploy.')
283 |
if concatenate_all_replay:
284 |
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
285 |
concatenate_all_replay = False
286 |
287 |
def single_env_forward_wrapper(forward_fn, cuda=True):
288 |
289 |
forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
290 |
291 |
def _forward(obs):
292 |
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
293 |
obs = ttorch.as_tensor(obs).unsqueeze(0)
294 |
if cuda and torch.cuda.is_available():
295 |
obs = obs.cuda()
296 |
action = forward_fn(obs)["action"]
297 |
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
298 |
action = action.squeeze(0).detach().cpu().numpy()
299 |
return action
300 |
301 |
return _forward
302 |
303 |
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
304 |
305 |
# reset first to make sure the env is in the initial state
306 |
# env will be reset again in the main loop
307 |
308 |
309 |
for seed in seeds:
310 |
env.seed(seed, dynamic_seed=False)
311 |
return_ = 0.
312 |
step = 0
313 |
obs = env.reset()
314 |
images.append(render(env)[None]) if concatenate_all_replay else None
315 |
while True:
316 |
action = forward_fn(obs)
317 |
obs, rew, done, info = env.step(action)
318 |
images.append(render(env)[None]) if concatenate_all_replay else None
319 |
return_ += rew
320 |
step += 1
321 |
if done:
322 |
323 |
logging.info(f'SQL deploy is finished, final episode return with {step} steps is: {return_}')
324 |
325 |
326 |
327 |
328 |
if concatenate_all_replay:
329 |
images = np.concatenate(images, axis=0)
330 |
import imageio
331 |
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
332 |
333 |
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
334 |
335 |
def collect_data(
336 |
337 |
env_num: int = 8,
338 |
save_data_path: Optional[str] = None,
339 |
n_sample: Optional[int] = None,
340 |
n_episode: Optional[int] = None,
341 |
context: Optional[str] = None,
342 |
debug: bool = False
343 |
) -> None:
344 |
345 |
346 |
Collect data with SQL algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
347 |
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
348 |
349 |
350 |
- env_num (:obj:`int`): The number of collector environments. Default to 8.
351 |
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
352 |
If not specified, the data will be saved in ``exp_name/demo_data``.
353 |
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
354 |
If not specified, ``n_episode`` must be specified.
355 |
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
356 |
If not specified, ``n_sample`` must be specified.
357 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
358 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
359 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
360 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
361 |
subprocess environment manager will be used.
362 |
363 |
364 |
if debug:
365 |
366 |
if n_episode is not None:
367 |
raise NotImplementedError
368 |
# define env and policy
369 |
env_num = env_num if env_num else self.cfg.env.collector_env_num
370 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
371 |
372 |
if save_data_path is None:
373 |
save_data_path = os.path.join(self.exp_name, 'demo_data')
374 |
375 |
# main execution task
376 |
with task.start(ctx=OnlineRLContext()):
377 |
378 |
379 |
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
380 |
381 |
382 |
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
383 |
384 |
385 |
f'SQL collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
386 |
387 |
388 |
def batch_evaluate(
389 |
390 |
env_num: int = 4,
391 |
n_evaluator_episode: int = 4,
392 |
context: Optional[str] = None,
393 |
debug: bool = False
394 |
) -> EvalReturn:
395 |
396 |
397 |
Evaluate the agent with SQL algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
398 |
environments. The evaluation result will be returned.
399 |
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
400 |
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
401 |
will only create one evaluator environment to evaluate the agent and save the replay video.
402 |
403 |
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
404 |
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
405 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
406 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
407 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
408 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
409 |
subprocess environment manager will be used.
410 |
411 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
412 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
413 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
414 |
415 |
416 |
if debug:
417 |
418 |
# define env and policy
419 |
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
420 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
421 |
422 |
# reset first to make sure the env is in the initial state
423 |
# env will be reset again in the main loop
424 |
425 |
426 |
427 |
evaluate_cfg = self.cfg
428 |
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
429 |
430 |
# main execution task
431 |
with task.start(ctx=OnlineRLContext()):
432 |
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
433 |
434 |
435 |
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
436 |
437 |
438 |
def best(self) -> 'SQLAgent':
439 |
440 |
441 |
Load the best model from the checkpoint directory, \
442 |
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
443 |
The return value is the agent with the best model.
444 |
445 |
- (:obj:`SQLAgent`): The agent with the best model.
446 |
447 |
>>> agent = SQLAgent(env_id='LunarLander-v2')
448 |
>>> agent.train()
449 |
>>> agent =
450 |
451 |
.. note::
452 |
The best model is the model with the highest evaluation return. If this method is called, the current \
453 |
model will be replaced by the best model.
454 |
455 |
456 |
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
457 |
# Load best model if it exists
458 |
if os.path.exists(best_model_file_path):
459 |
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
460 |
461 |
return self
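The ``deploy`` method above accepts ``seed`` as ``None``, a single integer, or a list of integers, runs one episode per seed, and reports the mean and standard deviation of the episode returns. The following is a minimal, standard-library sketch of that bookkeeping only. The helper names ``normalize_seeds`` and ``summarize_returns`` are hypothetical and not part of DI-engine, and unlike the method above (which falls back to ``self.seed`` for any other type) this sketch raises ``TypeError`` on unsupported input.

```python
from statistics import fmean, pstdev
from typing import List, Optional, Sequence, Tuple, Union


def normalize_seeds(seed: Optional[Union[int, List[int]]], default_seed: int) -> List[int]:
    """Turn the ``seed`` argument into a list with one entry per episode to run."""
    if seed is None:
        # No seed given: fall back to the agent-level seed.
        return [default_seed]
    if isinstance(seed, int):
        # A single integer means a single deployment episode.
        return [seed]
    if isinstance(seed, list):
        # A list means one deployment episode per seed.
        return list(seed)
    raise TypeError(f"seed must be None, an int, or a list of ints, got {type(seed).__name__}")


def summarize_returns(returns: Sequence[float]) -> Tuple[float, float]:
    """Return (mean, population standard deviation) of the episode returns."""
    # pstdev divides by N, which matches the default behaviour of np.std (ddof=0).
    return fmean(returns), pstdev(returns)
```

With a single seed the standard deviation is 0.0, so ``eval_value_std`` only becomes informative when a list of seeds is passed.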
@@ -0,0 +1,455 @@
1 |
from typing import Optional, Union, List
2 |
from ditk import logging
3 |
from easydict import EasyDict
4 |
import os
5 |
import numpy as np
6 |
import torch
7 |
import treetensor.torch as ttorch
8 |
from ding.framework import task, OnlineRLContext
9 |
from ding.framework.middleware import CkptSaver, \
10 |
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
11 |
OffPolicyLearner, final_ctx_saver
12 |
from ding.envs import BaseEnv
13 |
from ding.envs import setup_ding_env_manager
14 |
from ding.policy import TD3Policy
15 |
from ding.utils import set_pkg_seed
16 |
from ding.utils import get_env_fps, render
17 |
from ding.config import save_config_py, compile_config
18 |
from ding.model import ContinuousQAC
19 |
from import DequeBuffer
20 |
from ding.bonus.common import TrainingReturn, EvalReturn
21 |
from ding.config.example.TD3 import supported_env_cfg
22 |
from ding.config.example.TD3 import supported_env
23 |
24 |
25 |
class TD3Agent:
26 |
27 |
28 |
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
29 |
Twin Delayed Deep Deterministic Policy Gradient (TD3).
30 |
For more information about the system design of RL agent, please refer to \
31 |
32 |
33 |
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
34 |
35 |
supported_env_list = list(supported_env_cfg.keys())
36 |
37 |
38 |
List of supported envs.
39 |
40 |
>>> from ding.bonus.td3 import TD3Agent
41 |
>>> print(TD3Agent.supported_env_list)
42 |
43 |
44 |
def __init__(
45 |
46 |
env_id: str = None,
47 |
env: BaseEnv = None,
48 |
seed: int = 0,
49 |
exp_name: str = None,
50 |
model: Optional[torch.nn.Module] = None,
51 |
cfg: Optional[Union[EasyDict, dict]] = None,
52 |
policy_state_dict: str = None,
53 |
) -> None:
54 |
55 |
56 |
Initialize agent for TD3 algorithm.
57 |
58 |
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
59 |
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
60 |
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
61 |
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
62 |
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
63 |
If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
64 |
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
65 |
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
66 |
- seed (:obj:`int`): The random seed, which is set before running the program. \
67 |
Default to 0.
68 |
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
69 |
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
70 |
- model (:obj:`torch.nn.Module`): The model of TD3 algorithm, which should be an instance of class \
71 |
:class:`ding.model.ContinuousQAC`. \
72 |
If not specified, a default model will be generated according to the configuration.
73 |
- cfg (:obj:`Union[EasyDict, dict]`): The configuration of TD3 algorithm, which is a dict. \
74 |
Default to None. If not specified, the default configuration will be used. \
75 |
The default configuration can be found in ``ding/config/example/TD3/``.
76 |
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
77 |
If specified, the policy will be loaded from this file. Default to None.
78 |
79 |
.. note::
80 |
An RL Agent Instance can be initialized in two basic ways. \
81 |
For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
82 |
and we want to train an agent with TD3 algorithm with default configuration. \
83 |
Then we can initialize the agent in the following ways:
84 |
>>> agent = TD3Agent(env_id='LunarLanderContinuous-v2')
85 |
or, if we want to specify the env_id in the configuration:
86 |
>>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
87 |
>>> agent = TD3Agent(cfg=cfg)
88 |
There are also other arguments to specify the agent when initializing.
89 |
For example, if we want to specify the environment instance:
90 |
>>> env = CustomizedEnv('LunarLanderContinuous-v2')
91 |
>>> agent = TD3Agent(cfg=cfg, env=env)
92 |
or, if we want to specify the model:
93 |
>>> model = ContinuousQAC(**cfg.policy.model)
94 |
>>> agent = TD3Agent(cfg=cfg, model=model)
95 |
or, if we want to reload the policy from a saved policy state dict:
96 |
>>> agent = TD3Agent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
97 |
Make sure that the configuration is consistent with the saved policy state dict.
98 |
99 |
100 |
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
101 |
102 |
if cfg is not None and not isinstance(cfg, EasyDict):
103 |
cfg = EasyDict(cfg)
104 |
105 |
if env_id is not None:
106 |
assert env_id in TD3Agent.supported_env_list, "Please use supported envs: {}".format(
107 |
108 |
109 |
if cfg is None:
110 |
cfg = supported_env_cfg[env_id]
111 |
112 |
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
113 |
114 |
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
115 |
assert cfg.env.env_id in TD3Agent.supported_env_list, "Please use supported envs: {}".format(
116 |
117 |
118 |
default_policy_config = EasyDict({"policy": TD3Policy.default_config()})
119 |
120 |
cfg = default_policy_config
121 |
122 |
if exp_name is not None:
123 |
cfg.exp_name = exp_name
124 |
self.cfg = compile_config(cfg, policy=TD3Policy)
125 |
self.exp_name = self.cfg.exp_name
126 |
if env is None:
127 |
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
128 |
129 |
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
130 |
self.env = env
131 |
132 |
133 |
self.seed = seed
134 |
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
135 |
if not os.path.exists(self.exp_name):
136 |
137 |
save_config_py(self.cfg, os.path.join(self.exp_name, ''))
138 |
if model is None:
139 |
model = ContinuousQAC(**self.cfg.policy.model)
140 |
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
141 |
self.policy = TD3Policy(self.cfg.policy, model=model)
142 |
if policy_state_dict is not None:
143 |
144 |
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
145 |
146 |
def train(
147 |
148 |
step: int = int(1e7),
149 |
collector_env_num: int = None,
150 |
evaluator_env_num: int = None,
151 |
n_iter_save_ckpt: int = 1000,
152 |
context: Optional[str] = None,
153 |
debug: bool = False,
154 |
wandb_sweep: bool = False,
155 |
) -> TrainingReturn:
156 |
157 |
158 |
Train the agent with TD3 algorithm for ``step`` environment steps with ``collector_env_num`` collector \
159 |
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
160 |
recorded and saved by wandb.
161 |
162 |
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
163 |
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
164 |
If not specified, it will be set according to the configuration.
165 |
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
166 |
If not specified, it will be set according to the configuration.
167 |
- n_iter_save_ckpt (:obj:`int`): The number of training iterations between two checkpoint saves. \
168 |
Default to 1000.
169 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
170 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
171 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
172 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
173 |
subprocess environment manager will be used.
174 |
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
175 |
which is a hyper-parameter optimization process for seeking the best configurations. \
176 |
Default to False. If True, the wandb sweep id will be used as the experiment name.
177 |
178 |
- (:obj:`TrainingReturn`): The training result, of which the attributes are:
179 |
- wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
180 |
181 |
182 |
if debug:
183 |
184 |
185 |
# define env and policy
186 |
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
187 |
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
188 |
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
189 |
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
190 |
191 |
with task.start(ctx=OnlineRLContext()):
192 |
193 |
194 |
195 |
196 |
197 |
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
198 |
199 |
200 |
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
201 |
202 |
203 |
204 |
205 |
206 |
207 |
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
208 |
209 |
210 |
task.use(data_pusher(self.cfg, self.buffer_))
211 |
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
return TrainingReturn(wandb_url=task.ctx.wandb_url)
226 |
227 |
def deploy(
228 |
229 |
enable_save_replay: bool = False,
230 |
concatenate_all_replay: bool = False,
231 |
replay_save_path: str = None,
232 |
seed: Optional[Union[int, List]] = None,
233 |
debug: bool = False
234 |
) -> EvalReturn:
235 |
236 |
237 |
Deploy the agent with TD3 algorithm by interacting with the environment, during which the replay video \
238 |
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
239 |
240 |
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
241 |
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
242 |
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
243 |
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
244 |
the replay video of each episode will be saved separately.
245 |
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
246 |
If not specified, the video will be saved in ``exp_name/videos``.
247 |
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
248 |
Default to None. If not specified, ``self.seed`` will be used. \
249 |
If ``seed`` is an integer, the agent will be deployed once. \
250 |
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
251 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
252 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
253 |
subprocess environment manager will be used.
254 |
255 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
256 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
257 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
258 |
259 |
260 |
if debug:
261 |
262 |
# define env and policy
263 |
env = self.env.clone(caller='evaluator')
264 |
265 |
if seed is not None and isinstance(seed, int):
266 |
seeds = [seed]
267 |
elif seed is not None and isinstance(seed, list):
268 |
seeds = seed
269 |
270 |
seeds = [self.seed]
271 |
272 |
returns = []
273 |
images = []
274 |
if enable_save_replay:
275 |
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
276 |
277 |
278 |
logging.warning('No video would be generated during the deploy.')
279 |
if concatenate_all_replay:
280 |
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
281 |
concatenate_all_replay = False
282 |
283 |
def single_env_forward_wrapper(forward_fn, cuda=True):
284 |
285 |
def _forward(obs):
286 |
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
287 |
obs = ttorch.as_tensor(obs).unsqueeze(0)
288 |
if cuda and torch.cuda.is_available():
289 |
obs = obs.cuda()
290 |
action = forward_fn(obs, mode='compute_actor')["action"]
291 |
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
292 |
action = action.squeeze(0).detach().cpu().numpy()
293 |
return action
294 |
295 |
return _forward
296 |
297 |
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
298 |
299 |
# reset first to make sure the env is in the initial state
300 |
# env will be reset again in the main loop
301 |
302 |
303 |
for seed in seeds:
304 |
env.seed(seed, dynamic_seed=False)
305 |
return_ = 0.
306 |
step = 0
307 |
obs = env.reset()
308 |
images.append(render(env)[None]) if concatenate_all_replay else None
309 |
while True:
310 |
action = forward_fn(obs)
311 |
obs, rew, done, info = env.step(action)
312 |
images.append(render(env)[None]) if concatenate_all_replay else None
313 |
return_ += rew
314 |
step += 1
315 |
if done:
316 |
317 |
logging.info(f'TD3 deploy is finished, final episode return with {step} steps is: {return_}')
318 |
319 |
320 |
321 |
322 |
if concatenate_all_replay:
323 |
images = np.concatenate(images, axis=0)
324 |
import imageio
325 |
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
326 |
327 |
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
328 |
329 |
def collect_data(
330 |
331 |
env_num: int = 8,
332 |
save_data_path: Optional[str] = None,
333 |
n_sample: Optional[int] = None,
334 |
n_episode: Optional[int] = None,
335 |
context: Optional[str] = None,
336 |
debug: bool = False
337 |
) -> None:
338 |
339 |
340 |
Collect data with TD3 algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
341 |
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
342 |
343 |
344 |
- env_num (:obj:`int`): The number of collector environments. Default to 8.
345 |
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
346 |
If not specified, the data will be saved in ``exp_name/demo_data``.
347 |
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
348 |
If not specified, ``n_episode`` must be specified.
349 |
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
350 |
If not specified, ``n_sample`` must be specified.
351 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
352 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
353 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
354 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
355 |
subprocess environment manager will be used.
356 |
357 |
358 |
if debug:
359 |
360 |
if n_episode is not None:
361 |
raise NotImplementedError
362 |
# define env and policy
363 |
env_num = env_num if env_num else self.cfg.env.collector_env_num
364 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
365 |
366 |
if save_data_path is None:
367 |
save_data_path = os.path.join(self.exp_name, 'demo_data')
368 |
369 |
# main execution task
370 |
with task.start(ctx=OnlineRLContext()):
371 |
372 |
373 |
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
374 |
375 |
376 |
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
377 |
378 |
379 |
f'TD3 collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
380 |
381 |
382 |
def batch_evaluate(
383 |
384 |
env_num: int = 4,
385 |
n_evaluator_episode: int = 4,
386 |
context: Optional[str] = None,
387 |
debug: bool = False
388 |
) -> EvalReturn:
389 |
390 |
391 |
Evaluate the agent with TD3 algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
392 |
environments. The evaluation result will be returned.
393 |
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
394 |
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
395 |
will only create one evaluator environment to evaluate the agent and save the replay video.
396 |
397 |
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
398 |
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
399 |
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
400 |
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
401 |
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
402 |
If set True, base environment manager will be used for easy debugging. Otherwise, \
403 |
subprocess environment manager will be used.
404 |
405 |
- (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
406 |
- eval_value (:obj:`np.float32`): The mean of evaluation return.
407 |
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
408 |
409 |
410 |
if debug:
411 |
412 |
# define env and policy
413 |
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
414 |
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
415 |
416 |
# reset first to make sure the env is in the initial state
417 |
# env will be reset again in the main loop
418 |
419 |
420 |
421 |
evaluate_cfg = self.cfg
422 |
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
423 |
424 |
# main execution task
425 |
with task.start(ctx=OnlineRLContext()):
426 |
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
427 |
428 |
429 |
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
430 |
431 |
432 |
def best(self) -> 'TD3Agent':
433 |
434 |
435 |
Load the best model from the checkpoint directory, \
436 |
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
437 |
The return value is the agent with the best model.
438 |
439 |
- (:obj:`TD3Agent`): The agent with the best model.
440 |
441 |
>>> agent = TD3Agent(env_id='LunarLanderContinuous-v2')
442 |
>>> agent.train()
443 |
444 |
445 |
.. note::
446 |
The best model is the model with the highest evaluation return. If this method is called, the current \
447 |
model will be replaced by the best model.
448 |
449 |
450 |
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
451 |
# Load best model if it exists
452 |
if os.path.exists(best_model_file_path):
453 |
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
454 |
455 |
return self
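The `EvalReturn` fields documented for `batch_evaluate` are a mean and a standard deviation over episode returns. A minimal sketch of that aggregation follows, assuming plain Python floats and a population standard deviation (the library's exact choice is not shown in this diff); `EvalSummary` and `summarize_returns` are hypothetical names, not part of the library:

```python
import statistics
from typing import NamedTuple, Sequence


class EvalSummary(NamedTuple):
    # Hypothetical stand-in for the EvalReturn structure referenced in the diff.
    eval_value: float
    eval_value_std: float


def summarize_returns(episode_returns: Sequence[float]) -> EvalSummary:
    """Aggregate per-episode returns into a mean and a population standard deviation."""
    if not episode_returns:
        raise ValueError("at least one episode return is required")
    return EvalSummary(
        eval_value=statistics.fmean(episode_returns),
        eval_value_std=statistics.pstdev(episode_returns),
    )
```

With several evaluator environments, the per-episode returns from all of them would simply be concatenated before calling `summarize_returns`.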
@@ -0,0 +1,9 @@
1 |
import torch
2 |
3 |
4 |
def torch_ge_131():
5 |
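import re  # local import; compare numeric components instead of concatenated digits
return tuple(int(p) for p in (re.findall(r'\d+', torch.__version__.split('+')[0]) + ['0', '0'])[:3]) >= (1, 3, 1)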
6 |
7 |
8 |
def torch_ge_180():
9 |
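import re  # local import; compare numeric components instead of concatenated digits
return tuple(int(p) for p in (re.findall(r'\d+', torch.__version__.split('+')[0]) + ['0', '0'])[:3]) >= (1, 8, 0)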
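A version gate of this kind is easiest to reason about when the version string is parsed into an integer tuple. Concatenating every digit of a string such as `1.7.1+cu110` yields `171110`, which would compare as newer than `180`. The sketch below is one possible approach, not the library's API; `parse_version` and `version_ge` are hypothetical helpers and only the first three numeric components are considered:

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Return the first three numeric components of a version string, zero-padded."""
    public = version.split("+", 1)[0]  # drop a local suffix such as "+cu117"
    numbers = [int(p) for p in re.findall(r"\d+", public)[:3]]
    numbers += [0] * (3 - len(numbers))  # "1.8" is treated as "1.8.0"
    return tuple(numbers)


def version_ge(version: str, minimum: Tuple[int, int, int]) -> bool:
    """Compare versions component by component rather than as one concatenated number."""
    return parse_version(version) >= minimum
```

Tuple comparison is lexicographic, so `(1, 10, 0)` correctly sorts after `(1, 8, 0)`.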
@@ -0,0 +1,4 @@
1 |
from .config import Config, read_config, save_config, compile_config, compile_config_parallel, read_config_directly, \
2 |
read_config_with_system, save_config_py
3 |
from .utils import parallel_transform, parallel_transform_slurm
4 |
from .example import A2C, C51, DDPG, DQN, PG, PPOF, PPOOffPolicy, SAC, SQL, TD3
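The `read_config` family exported here loads a Python file and picks named top-level variables out of it. A minimal, standard-library-only sketch of that idea is shown below. It uses `importlib.util` instead of the temporary-directory import performed in `config.py`, and `load_py_config`, `read_main_and_create`, and `demo` are hypothetical names introduced for illustration:

```python
import importlib.util
import os
import tempfile
from typing import Any, Dict, Tuple


def load_py_config(path: str) -> Dict[str, Any]:
    """Execute a Python config file and return its public top-level names."""
    spec = importlib.util.spec_from_file_location("_user_config", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load config from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return {k: v for k, v in vars(module).items() if not k.startswith("_")}


def read_main_and_create(path: str) -> Tuple[dict, dict]:
    """Return the `main_config` and `create_config` variables declared in the file."""
    cfg = load_py_config(path)
    for key in ("main_config", "create_config"):
        if key not in cfg:
            raise KeyError(f"please declare a '{key}' variable in {path}")
    return cfg["main_config"], cfg["create_config"]


def demo() -> Tuple[dict, dict]:
    """Write a tiny config file to a temporary directory and read it back."""
    source = (
        "main_config = dict(exp_name='demo')\n"
        "create_config = dict(policy=dict(type='td3'))\n"
        "_hidden = 1\n"
    )
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo_config.py")
        with open(path, "w", encoding="utf-8") as f:
            f.write(source)
        return read_main_and_create(path)
```

Because the file is executed as ordinary Python, it should only be used with trusted config files.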
@@ -0,0 +1,579 @@
1 |
import os
2 |
import os.path as osp
3 |
import yaml
4 |
import json
5 |
import shutil
6 |
import sys
7 |
import time
8 |
import tempfile
9 |
import subprocess
10 |
import datetime
11 |
from importlib import import_module
12 |
from typing import Optional, Tuple
13 |
from easydict import EasyDict
14 |
from copy import deepcopy
15 |
16 |
from ding.utils import deep_merge_dicts, get_rank
17 |
from ding.envs import get_env_cls, get_env_manager_cls, BaseEnvManager
18 |
from ding.policy import get_policy_cls
19 |
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, Coordinator, \
20 |
AdvancedReplayBuffer, get_parallel_commander_cls, get_parallel_collector_cls, get_buffer_cls, \
21 |
get_serial_collector_cls, MetricSerialEvaluator, BattleInteractionSerialEvaluator
22 |
from ding.reward_model import get_reward_model_cls
23 |
from ding.world_model import get_world_model_cls
24 |
from .utils import parallel_transform, parallel_transform_slurm, parallel_transform_k8s, save_config_formatted
25 |
26 |
27 |
class Config(object):
28 |
29 |
30 |
Base class for config.
31 |
32 |
__init__, file_to_dict
33 |
34 |
35 |
36 |
37 |
def __init__(
38 |
39 |
cfg_dict: Optional[dict] = None,
40 |
cfg_text: Optional[str] = None,
41 |
filename: Optional[str] = None
42 |
) -> None:
43 |
44 |
45 |
Init method. Create config including dict type config and text type config.
46 |
47 |
- cfg_dict (:obj:`Optional[dict]`): dict type config
48 |
- cfg_text (:obj:`Optional[str]`): text type config
49 |
- filename (:obj:`Optional[str]`): config file name
50 |
51 |
if cfg_dict is None:
52 |
cfg_dict = {}
53 |
if not isinstance(cfg_dict, dict):
54 |
raise TypeError("invalid type for cfg_dict: {}".format(type(cfg_dict)))
55 |
self._cfg_dict = cfg_dict
56 |
if cfg_text:
57 |
text = cfg_text
58 |
elif filename:
59 |
with open(filename, 'r') as f:
60 |
text = f.read()
61 |
62 |
text = '.'
63 |
self._text = text
64 |
self._filename = filename
65 |
66 |
67 |
def file_to_dict(filename: str) -> 'Config': # noqa
68 |
69 |
70 |
Read config file and create config.
71 |
72 |
- filename (:obj:`str`): Path of the config file.
73 |
74 |
- cfg (:obj:`Config`): The config object created from the file.
75 |
76 |
cfg_dict, cfg_text = Config._file_to_dict(filename)
77 |
return Config(cfg_dict, cfg_text, filename=filename)
78 |
79 |
80 |
def _file_to_dict(filename: str) -> Tuple[dict, str]:
81 |
82 |
83 |
Read config file and convert the config file to dict type config and text type config.
84 |
85 |
- filename (:obj:`str`): Path of the config file.
86 |
87 |
- cfg_dict (:obj:`Optional[dict]`): dict type config
88 |
- cfg_text (:obj:`Optional[str]`): text type config
89 |
90 |
filename = osp.abspath(osp.expanduser(filename))
91 |
# TODO check exist
92 |
# TODO check suffix
93 |
ext_name = osp.splitext(filename)[-1]
94 |
with tempfile.TemporaryDirectory() as temp_config_dir:
95 |
temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=ext_name)
96 |
temp_config_name = osp.basename(temp_config_file.name)
97 |
98 |
99 |
100 |
temp_module_name = osp.splitext(temp_config_name)[0]
101 |
sys.path.insert(0, temp_config_dir)
102 |
# TODO validate py syntax
103 |
module = import_module(temp_module_name)
104 |
cfg_dict = {k: v for k, v in module.__dict__.items() if not k.startswith('_')}
105 |
del sys.modules[temp_module_name]
106 |
107 |
108 |
cfg_text = filename + '\n'
109 |
with open(filename, 'r') as f:
110 |
cfg_text += f.read()
111 |
112 |
return cfg_dict, cfg_text
113 |
114 |
115 |
def cfg_dict(self) -> dict:
116 |
return self._cfg_dict
117 |
118 |
119 |
def read_config_yaml(path: str) -> EasyDict:
120 |
121 |
122 |
Read configuration from a yaml file.
123 |
124 |
- path (:obj:`str`): Path of source yaml
125 |
126 |
- (:obj:`EasyDict`): Config data from this file with dict type
127 |
128 |
with open(path, "r") as f:
129 |
config_ = yaml.safe_load(f)
130 |
131 |
return EasyDict(config_)
132 |
133 |
134 |
def save_config_yaml(config_: dict, path: str) -> None:
135 |
136 |
137 |
Save configuration to a yaml file.
138 |
139 |
- config_ (:obj:`dict`): Config dict
140 |
- path (:obj:`str`): Path of target yaml
141 |
142 |
config_string = json.dumps(config_)
143 |
with open(path, "w") as f:
144 |
yaml.safe_dump(json.loads(config_string), f)
145 |
146 |
147 |
def save_config_py(config_: dict, path: str) -> None:
148 |
149 |
150 |
save configuration to python file
151 |
152 |
- config_ (:obj:`dict`): Config dict
153 |
- path (:obj:`str`): Path of target python file
154 |
155 |
# config_string = json.dumps(config_, indent=4)
156 |
config_string = str(config_)
157 |
from yapf.yapflib.yapf_api import FormatCode
158 |
config_string, _ = FormatCode(config_string)
159 |
config_string = config_string.replace('inf,', 'float("inf"),')
160 |
with open(path, "w") as f:
161 |
f.write('exp_config = ' + config_string)
162 |
163 |
164 |
def read_config_directly(path: str) -> dict:
165 |
166 |
167 |
Read configuration from a file path (currently only python files are supported) and return the result directly.
168 |
169 |
- path (:obj:`str`): Path of configuration file
170 |
171 |
- cfg (:obj:`dict`): Configuration dict.
172 |
173 |
suffix = path.split('.')[-1]
174 |
if suffix == 'py':
175 |
return Config.file_to_dict(path).cfg_dict
176 |
177 |
raise KeyError("invalid config file suffix: {}".format(suffix))
178 |
179 |
180 |
def read_config(path: str) -> Tuple[dict, dict]:
181 |
182 |
183 |
Read configuration from a file path (currently only python files are supported) and select the relevant parts.
184 |
185 |
- path (:obj:`str`): Path of configuration file
186 |
187 |
- cfg (:obj:`Tuple[dict, dict]`): A tuple of configuration dicts, split into two parts: `main_config` and \
188 |
`create_config`.
189 |
190 |
suffix = path.split('.')[-1]
191 |
if suffix == 'py':
192 |
cfg = Config.file_to_dict(path).cfg_dict
193 |
assert "main_config" in cfg, "Please make sure a 'main_config' variable is declared in config python file!"
194 |
assert "create_config" in cfg, "Please make sure a 'create_config' variable is declared in config python file!"
195 |
return cfg['main_config'], cfg['create_config']
196 |
197 |
raise KeyError("invalid config file suffix: {}".format(suffix))
198 |
199 |
200 |
def read_config_with_system(path: str) -> Tuple[dict, dict, dict]:
201 |
202 |
203 |
Read configuration from a file path (currently only python files are supported) and select the relevant parts.
204 |
205 |
- path (:obj:`str`): Path of configuration file
206 |
207 |
- cfg (:obj:`Tuple[dict, dict, dict]`): A tuple of configuration dicts, split into three parts: `main_config`, \
208 |
`create_config` and `system_config`.
209 |
210 |
suffix = path.split('.')[-1]
211 |
if suffix == 'py':
212 |
cfg = Config.file_to_dict(path).cfg_dict
213 |
assert "main_config" in cfg, "Please make sure a 'main_config' variable is declared in config python file!"
214 |
assert "create_config" in cfg, "Please make sure a 'create_config' variable is declared in config python file!"
215 |
assert "system_config" in cfg, "Please make sure a 'system_config' variable is declared in config python file!"
216 |
return cfg['main_config'], cfg['create_config'], cfg['system_config']
217 |
218 |
raise KeyError("invalid config file suffix: {}".format(suffix))
219 |
220 |
221 |
def save_config(config_: dict, path: str, type_: str = 'py', save_formatted: bool = False) -> None:
222 |
223 |
224 |
save configuration to python file or yaml file
225 |
226 |
- config_ (:obj:`dict`): Config dict
227 |
- path (:obj:`str`): Path of target yaml or target python file
228 |
- type_ (:obj:`str`): If type_ is ``yaml``, save configuration to yaml file. If type_ is ``py``, save\
229 |
configuration to python file.
230 |
- save_formatted (:obj:`bool`): If save_formatted is true, save formatted config to path.\
231 |
Formatted config can be read by serial_pipeline directly.
232 |
233 |
assert type_ in ['yaml', 'py'], type_
234 |
if type_ == 'yaml':
235 |
save_config_yaml(config_, path)
236 |
elif type_ == 'py':
237 |
save_config_py(config_, path)
238 |
if save_formatted:
239 |
formatted_path = osp.join(osp.dirname(path), 'formatted_' + osp.basename(path))
240 |
save_config_formatted(config_, formatted_path)
241 |
242 |
243 |
def compile_buffer_config(policy_cfg: EasyDict, user_cfg: EasyDict, buffer_cls: 'IBuffer') -> EasyDict: # noqa
244 |
245 |
def _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, buffer_cls):
246 |
247 |
if buffer_cls is None:
248 |
assert 'type' in policy_buffer_cfg, "please indicate buffer type in create_cfg"
249 |
buffer_cls = get_buffer_cls(policy_buffer_cfg)
250 |
buffer_cfg = deep_merge_dicts(buffer_cls.default_config(), policy_buffer_cfg)
251 |
buffer_cfg = deep_merge_dicts(buffer_cfg, user_buffer_cfg)
252 |
return buffer_cfg
253 |
254 |
policy_multi_buffer = policy_cfg.other.replay_buffer.get('multi_buffer', False)
255 |
user_multi_buffer = user_cfg.policy.get('other', {}).get('replay_buffer', {}).get('multi_buffer', False)
256 |
assert not user_multi_buffer or user_multi_buffer == policy_multi_buffer, "For multi_buffer, \
257 |
user_cfg({}) and policy_cfg({}) must be in accordance".format(user_multi_buffer, policy_multi_buffer)
258 |
multi_buffer = policy_multi_buffer
259 |
if not multi_buffer:
260 |
policy_buffer_cfg = policy_cfg.other.replay_buffer
261 |
user_buffer_cfg = user_cfg.policy.get('other', {}).get('replay_buffer', {})
262 |
return _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, buffer_cls)
263 |
264 |
return_cfg = EasyDict()
265 |
for buffer_name in policy_cfg.other.replay_buffer: # Only traverse keys in policy_cfg
266 |
if buffer_name == 'multi_buffer':
267 |
268 |
policy_buffer_cfg = policy_cfg.other.replay_buffer[buffer_name]
269 |
user_buffer_cfg = user_cfg.policy.get('other', {}).get('replay_buffer', {}).get(buffer_name, {})
270 |
if buffer_cls is None:
271 |
return_cfg[buffer_name] = _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, None)
272 |
273 |
return_cfg[buffer_name] = _compile_buffer_config(
274 |
policy_buffer_cfg, user_buffer_cfg, buffer_cls[buffer_name]
275 |
276 |
return_cfg[buffer_name].name = buffer_name
277 |
return return_cfg
278 |
279 |
280 |
def compile_collector_config(
281 |
policy_cfg: EasyDict,
282 |
user_cfg: EasyDict,
283 |
collector_cls: 'ISerialCollector' # noqa
284 |
) -> EasyDict:
285 |
policy_collector_cfg = policy_cfg.collect.collector
286 |
user_collector_cfg = user_cfg.policy.get('collect', {}).get('collector', {})
287 |
# step1: get collector class
288 |
# two cases: the collector type comes from create_cfg (already merged into policy_cfg) or from an explicit collector class; the class has higher priority
289 |
if collector_cls is None:
290 |
assert 'type' in policy_collector_cfg, "please indicate collector type in create_cfg"
291 |
# use type to get collector_cls
292 |
collector_cls = get_serial_collector_cls(policy_collector_cfg)
293 |
# step2: policy collector cfg merge to collector cfg
294 |
collector_cfg = deep_merge_dicts(collector_cls.default_config(), policy_collector_cfg)
295 |
# step3: user collector cfg merge to the step2 config
296 |
collector_cfg = deep_merge_dicts(collector_cfg, user_collector_cfg)
297 |
298 |
return collector_cfg
299 |
300 |
301 |
policy_config_template = dict(
302 |
303 |
304 |
305 |
306 |
307 |
308 |
policy_config_template = EasyDict(policy_config_template)
309 |
env_config_template = dict(manager=dict(), stop_value=int(1e10), n_evaluator_episode=4)
310 |
env_config_template = EasyDict(env_config_template)
311 |
312 |
313 |
def save_project_state(exp_name: str) -> None:
314 |
315 |
def _fn(cmd: str):
316 |
return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.strip().decode("utf-8")
317 |
318 |
if subprocess.run("git status", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0:
319 |
short_sha = _fn("git describe --always")
320 |
log = _fn("git log --stat -n 5")
321 |
diff = _fn("git diff")
322 |
with open(os.path.join(exp_name, "git_log.txt"), "w", encoding='utf-8') as f:
323 |
f.write(short_sha + '\n\n' + log)
324 |
with open(os.path.join(exp_name, "git_diff.txt"), "w", encoding='utf-8') as f:
325 |
326 |
327 |
328 |
def compile_config(
329 |
cfg: EasyDict,
330 |
env_manager: type = None,
331 |
policy: type = None,
332 |
learner: type = BaseLearner,
333 |
collector: type = None,
334 |
evaluator: type = InteractionSerialEvaluator,
335 |
buffer: type = None,
336 |
env: type = None,
337 |
reward_model: type = None,
338 |
world_model: type = None,
339 |
seed: int = 0,
340 |
auto: bool = False,
341 |
create_cfg: dict = None,
342 |
save_cfg: bool = True,
343 |
save_path: str = '',
344 |
renew_dir: bool = True,
345 |
) -> EasyDict:
346 |
347 |
348 |
Combine the input config information with other input information.
349 |
Compile config to make it easy to be called by other programs
350 |
351 |
- cfg (:obj:`EasyDict`): Input config dict which is to be used in the following pipeline
352 |
- env_manager (:obj:`type`): Env_manager class which is to be used in the following pipeline
353 |
- policy (:obj:`type`): Policy class which is to be used in the following pipeline
354 |
- learner (:obj:`type`): Input learner class, defaults to BaseLearner
355 |
- collector (:obj:`type`): Input collector class, defaults to BaseSerialCollector
356 |
- evaluator (:obj:`type`): Input evaluator class, defaults to InteractionSerialEvaluator
357 |
- buffer (:obj:`type`): Input buffer class, defaults to IBuffer
358 |
- env (:obj:`type`): Environment class which is to be used in the following pipeline
359 |
- reward_model (:obj:`type`): Reward model class which aims to offer various and valuable reward
360 |
- seed (:obj:`int`): Random number seed
361 |
- auto (:obj:`bool`): Compile create_config dict or not
362 |
- create_cfg (:obj:`dict`): Input create config dict
363 |
- save_cfg (:obj:`bool`): Save config or not
364 |
- save_path (:obj:`str`): Path of saving file
365 |
- renew_dir (:obj:`bool`): Whether to create a new directory for saving config.
366 |
367 |
- cfg (:obj:`EasyDict`): Config after compiling
368 |
369 |
cfg, create_cfg = deepcopy(cfg), deepcopy(create_cfg)
370 |
if auto:
371 |
assert create_cfg is not None
372 |
# for compatibility
373 |
if 'collector' not in create_cfg:
374 |
create_cfg.collector = EasyDict(dict(type='sample'))
375 |
if 'replay_buffer' not in create_cfg:
376 |
create_cfg.replay_buffer = EasyDict(dict(type='advanced'))
377 |
buffer = AdvancedReplayBuffer
378 |
if env is None:
379 |
if 'env' in create_cfg:
380 |
env = get_env_cls(create_cfg.env)
381 |
382 |
env = None
383 |
create_cfg.env = {'type': 'ding_env_wrapper_generated'}
384 |
if env_manager is None:
385 |
env_manager = get_env_manager_cls(create_cfg.env_manager)
386 |
if policy is None:
387 |
policy = get_policy_cls(create_cfg.policy)
388 |
if 'default_config' in dir(env):
389 |
env_config = env.default_config()
390 |
391 |
env_config = EasyDict() # env does not have default_config
392 |
env_config = deep_merge_dicts(env_config_template, env_config)
393 |
394 |
env_config.manager = deep_merge_dicts(env_manager.default_config(), env_config.manager)
395 |
396 |
policy_config = policy.default_config()
397 |
policy_config = deep_merge_dicts(policy_config_template, policy_config)
398 |
399 |
400 |
if 'evaluator' in create_cfg:
401 |
402 |
403 |
404 |
policy_config.other.commander = BaseSerialCommander.default_config()
405 |
if 'reward_model' in create_cfg:
406 |
reward_model = get_reward_model_cls(create_cfg.reward_model)
407 |
reward_model_config = reward_model.default_config()
408 |
409 |
reward_model_config = EasyDict()
410 |
if 'world_model' in create_cfg:
411 |
world_model = get_world_model_cls(create_cfg.world_model)
412 |
world_model_config = world_model.default_config()
413 |
414 |
415 |
world_model_config = EasyDict()
416 |
417 |
if 'default_config' in dir(env):
418 |
env_config = env.default_config()
419 |
420 |
env_config = EasyDict() # env does not have default_config
421 |
env_config = deep_merge_dicts(env_config_template, env_config)
422 |
if env_manager is None:
423 |
env_manager = BaseEnvManager # for compatibility
424 |
env_config.manager = deep_merge_dicts(env_manager.default_config(), env_config.manager)
425 |
policy_config = policy.default_config()
426 |
policy_config = deep_merge_dicts(policy_config_template, policy_config)
427 |
if reward_model is None:
428 |
reward_model_config = EasyDict()
429 |
430 |
reward_model_config = reward_model.default_config()
431 |
if world_model is None:
432 |
world_model_config = EasyDict()
433 |
434 |
world_model_config = world_model.default_config()
435 |
436 |
policy_config.learn.learner = deep_merge_dicts(
437 |
438 |
439 |
440 |
if create_cfg is not None or collector is not None:
441 |
policy_config.collect.collector = compile_collector_config(policy_config, cfg, collector)
442 |
if evaluator:
443 |
policy_config.eval.evaluator = deep_merge_dicts(
444 |
445 |
446 |
447 |
if create_cfg is not None or buffer is not None:
448 |
policy_config.other.replay_buffer = compile_buffer_config(policy_config, cfg, buffer)
449 |
default_config = EasyDict({'env': env_config, 'policy': policy_config})
450 |
if len(reward_model_config) > 0:
451 |
default_config['reward_model'] = reward_model_config
452 |
if len(world_model_config) > 0:
453 |
default_config['world_model'] = world_model_config
454 |
cfg = deep_merge_dicts(default_config, cfg)
455 |
if 'unroll_len' in cfg.policy:
456 |
cfg.policy.collect.unroll_len = cfg.policy.unroll_len
457 |
cfg.seed = seed
458 |
# check important key in config
459 |
if evaluator in [InteractionSerialEvaluator, BattleInteractionSerialEvaluator]: # env interaction evaluation
460 |
cfg.policy.eval.evaluator.stop_value = cfg.env.stop_value
461 |
cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
462 |
if 'exp_name' not in cfg:
463 |
cfg.exp_name = 'default_experiment'
464 |
if save_cfg and get_rank() == 0:
465 |
if os.path.exists(cfg.exp_name) and renew_dir:
466 |
cfg.exp_name += datetime.datetime.now().strftime("_%y%m%d_%H%M%S")
467 |
468 |
469 |
except FileExistsError:
470 |
471 |
472 |
save_path = os.path.join(cfg.exp_name, save_path)
473 |
save_config(cfg, save_path, save_formatted=True)
474 |
return cfg
475 |
476 |
477 |
def compile_config_parallel(
478 |
cfg: EasyDict,
479 |
create_cfg: EasyDict,
480 |
system_cfg: EasyDict,
481 |
seed: int = 0,
482 |
save_cfg: bool = True,
483 |
save_path: str = '',
484 |
platform: str = 'local',
485 |
coordinator_host: Optional[str] = None,
486 |
learner_host: Optional[str] = None,
487 |
collector_host: Optional[str] = None,
488 |
coordinator_port: Optional[int] = None,
489 |
learner_port: Optional[int] = None,
490 |
collector_port: Optional[int] = None,
491 |
) -> EasyDict:
492 |
493 |
494 |
Combine the input parallel mode configuration information with other input information. Compile config\
495 |
to make it easy to be called by other programs
496 |
497 |
- cfg (:obj:`EasyDict`): Input main config dict
498 |
- create_cfg (:obj:`dict`): Input create config dict, including type parameters, such as environment type
499 |
- system_cfg (:obj:`dict`): Input system config dict, including system parameters, such as file path,\
500 |
communication mode, use multiple GPUs or not
501 |
- seed (:obj:`int`): Random number seed
502 |
- save_cfg (:obj:`bool`): Save config or not
503 |
- save_path (:obj:`str`): Path of saving file
504 |
- platform (:obj:`str`): Where to run the program: 'local', 'slurm' or 'k8s'
505 |
- coordinator_host (:obj:`Optional[str]`): Input coordinator's host when platform is slurm
506 |
- learner_host (:obj:`Optional[str]`): Input learner's host when platform is slurm
507 |
- collector_host (:obj:`Optional[str]`): Input collector's host when platform is slurm
508 |
509 |
- cfg (:obj:`EasyDict`): Config after compiling
510 |
511 |
# for compatibility
512 |
if 'replay_buffer' not in create_cfg:
513 |
create_cfg.replay_buffer = EasyDict(dict(type='advanced'))
514 |
# env
515 |
env = get_env_cls(create_cfg.env)
516 |
if 'default_config' in dir(env):
517 |
env_config = env.default_config()
518 |
519 |
env_config = EasyDict() # env does not have default_config
520 |
env_config = deep_merge_dicts(env_config_template, env_config)
521 |
522 |
523 |
env_manager = get_env_manager_cls(create_cfg.env_manager)
524 |
env_config.manager = env_manager.default_config()
525 |
526 |
527 |
# policy
528 |
policy = get_policy_cls(create_cfg.policy)
529 |
policy_config = policy.default_config()
530 |
policy_config = deep_merge_dicts(policy_config_template, policy_config)
531 |
532 |
533 |
collector = get_parallel_collector_cls(create_cfg.collector)
534 |
policy_config.collect.collector = collector.default_config()
535 |
536 |
policy_config.learn.learner = BaseLearner.default_config()
537 |
538 |
commander = get_parallel_commander_cls(create_cfg.commander)
539 |
policy_config.other.commander = commander.default_config()
540 |
541 |
542 |
policy_config.other.replay_buffer = compile_buffer_config(policy_config, cfg, None)
543 |
544 |
default_config = EasyDict({'env': env_config, 'policy': policy_config})
545 |
cfg = deep_merge_dicts(default_config, cfg)
546 |
547 |
cfg.policy.other.commander.path_policy = system_cfg.path_policy # league may use 'path_policy'
548 |
549 |
# system
550 |
for k in ['comm_learner', 'comm_collector']:
551 |
system_cfg[k] = create_cfg[k]
552 |
if platform == 'local':
553 |
cfg = parallel_transform(EasyDict({'main': cfg, 'system': system_cfg}))
554 |
elif platform == 'slurm':
555 |
cfg = parallel_transform_slurm(
556 |
557 |
'main': cfg,
558 |
'system': system_cfg
559 |
}), coordinator_host, learner_host, collector_host
560 |
561 |
elif platform == 'k8s':
562 |
cfg = parallel_transform_k8s(
563 |
564 |
'main': cfg,
565 |
'system': system_cfg
566 |
567 |
568 |
569 |
570 |
571 |
572 |
raise KeyError("unsupported platform type: {}".format(platform))
573 |
cfg.system.coordinator = deep_merge_dicts(Coordinator.default_config(), cfg.system.coordinator)
574 |
# seed
575 |
cfg.seed = seed
576 |
577 |
if save_cfg:
578 |
save_config(cfg, save_path)
579 |
return cfg
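The compile functions in this file build the final config by layering: component defaults first, then policy defaults, then the user config, with later layers winning. The sketch below illustrates that ordering with a plain recursive merge. It is not `ding.utils.deep_merge_dicts`, whose exact semantics may differ; `deep_merge` and `compile_layers` are hypothetical names:

```python
from copy import deepcopy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict in which `override` wins and nested dicts are merged recursively."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


def compile_layers(*layers: Dict[str, Any]) -> Dict[str, Any]:
    """Merge layers left to right, so the last layer has the highest priority."""
    result: Dict[str, Any] = {}
    for layer in layers:
        result = deep_merge(result, layer)
    return result
```

For example, `compile_layers(component_default, policy_default, user_cfg)` keeps every key that only a lower layer defines, while any key the user sets replaces the default.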
@@ -0,0 +1,17 @@
1 |
from easydict import EasyDict
2 |
from . import gym_bipedalwalker_v3
3 |
from . import gym_lunarlander_v2
4 |
5 |
supported_env_cfg = {
6 |
gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.cfg,
7 |
gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
8 |
9 |
10 |
supported_env_cfg = EasyDict(supported_env_cfg)
11 |
12 |
supported_env = {
13 |
gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.env,
14 |
gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
15 |
16 |
17 |
supported_env = EasyDict(supported_env)
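The two dictionaries above form a registry keyed by each module's `cfg.env.env_id`. A small sketch of the same pattern with an explicit lookup error is given below. The stand-in modules, their env ids, and the factory values are assumptions for illustration only, since the real config bodies are not visible in this diff:

```python
from types import SimpleNamespace
from typing import Any, Dict, Iterable, Tuple

# Hypothetical stand-ins for per-environment config modules.
lunarlander = SimpleNamespace(cfg={"env": {"env_id": "LunarLander-v2"}}, env="lunarlander_factory")
bipedalwalker = SimpleNamespace(cfg={"env": {"env_id": "BipedalWalker-v3"}}, env="bipedalwalker_factory")


def build_registry(modules: Iterable[Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Build config and env registries keyed by each module's env_id."""
    cfgs: Dict[str, Any] = {}
    envs: Dict[str, Any] = {}
    for module in modules:
        env_id = module.cfg["env"]["env_id"]
        if env_id in cfgs:
            raise ValueError(f"duplicate env_id: {env_id}")
        cfgs[env_id] = module.cfg
        envs[env_id] = module.env
    return cfgs, envs


def lookup(registry: Dict[str, Any], env_id: str) -> Any:
    """Return the registered entry or raise a KeyError that lists the supported ids."""
    try:
        return registry[env_id]
    except KeyError:
        raise KeyError(f"unsupported env_id {env_id!r}; supported: {sorted(registry)}") from None
```

Keying both dictionaries by the id stored in the config keeps the registry and the config from drifting apart when an environment is renamed.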
@@ -0,0 +1,43 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
38 |
39 |
40 |
41 |
cfg = EasyDict(cfg)
42 |
43 |
env = ding.envs.gym_env.env
@@ -0,0 +1,38 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
33 |
34 |
35 |
36 |
cfg = EasyDict(cfg)
37 |
38 |
env = ding.envs.gym_env.env
@@ -0,0 +1,23 @@
1 |
from easydict import EasyDict
2 |
from . import gym_lunarlander_v2
3 |
from . import gym_pongnoframeskip_v4
4 |
from . import gym_qbertnoframeskip_v4
5 |
from . import gym_spaceInvadersnoframeskip_v4
6 |
7 |
supported_env_cfg = {
8 |
gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
9 |
gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg,
10 |
gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg,
11 |
gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg,
12 |
13 |
14 |
supported_env_cfg = EasyDict(supported_env_cfg)
15 |
16 |
supported_env = {
17 |
gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
18 |
gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env,
19 |
gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env,
20 |
gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env,
21 |
22 |
23 |
supported_env = EasyDict(supported_env)
@@ -0,0 +1,52 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
encoder_hidden_size_list=[512, 64],
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
), replay_buffer=dict(replay_buffer_size=100000, )
43 |
44 |
45 |
46 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
47 |
48 |
49 |
50 |
cfg = EasyDict(cfg)
51 |
52 |
env = ding.envs.gym_env.env
@@ -0,0 +1,54 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
obs_shape=[4, 84, 84],
21 |
22 |
encoder_hidden_size_list=[128, 128, 512],
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
collect=dict(n_sample=100, ),
36 |
eval=dict(evaluator=dict(eval_freq=4000, )),
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
replay_buffer=dict(replay_buffer_size=100000, ),
45 |
46 |
47 |
48 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
49 |
50 |
51 |
52 |
cfg = EasyDict(cfg)
53 |
54 |
env = ding.envs.gym_env.env
@@ -0,0 +1,54 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
obs_shape=[4, 84, 84],
21 |
22 |
encoder_hidden_size_list=[128, 128, 512],
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
collect=dict(n_sample=100, ),
36 |
eval=dict(evaluator=dict(eval_freq=4000, )),
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
replay_buffer=dict(replay_buffer_size=400000, ),
45 |
46 |
47 |
48 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
49 |
50 |
51 |
52 |
cfg = EasyDict(cfg)
53 |
54 |
env = ding.envs.gym_env.env
@@ -0,0 +1,54 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
obs_shape=[4, 84, 84],
21 |
22 |
encoder_hidden_size_list=[128, 128, 512],
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
collect=dict(n_sample=100, ),
36 |
eval=dict(evaluator=dict(eval_freq=4000, )),
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
replay_buffer=dict(replay_buffer_size=400000, ),
45 |
46 |
47 |
48 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
49 |
50 |
51 |
52 |
cfg = EasyDict(cfg)
53 |
54 |
env = ding.envs.gym_env.env
@@ -0,0 +1,29 @@
1 |
from easydict import EasyDict
2 |
from . import gym_bipedalwalker_v3
3 |
from . import gym_halfcheetah_v3
4 |
from . import gym_hopper_v3
5 |
from . import gym_lunarlandercontinuous_v2
6 |
from . import gym_pendulum_v1
7 |
from . import gym_walker2d_v3
8 |
9 |
supported_env_cfg = {
10 |
gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.cfg,
11 |
gym_halfcheetah_v3.cfg.env.env_id: gym_halfcheetah_v3.cfg,
12 |
gym_hopper_v3.cfg.env.env_id: gym_hopper_v3.cfg,
13 |
gym_lunarlandercontinuous_v2.cfg.env.env_id: gym_lunarlandercontinuous_v2.cfg,
14 |
gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.cfg,
15 |
gym_walker2d_v3.cfg.env.env_id: gym_walker2d_v3.cfg,
16 |
17 |
18 |
supported_env_cfg = EasyDict(supported_env_cfg)
19 |
20 |
supported_env = {
21 |
gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.env,
22 |
gym_halfcheetah_v3.cfg.env.env_id: gym_halfcheetah_v3.env,
23 |
gym_hopper_v3.cfg.env.env_id: gym_hopper_v3.env,
24 |
gym_lunarlandercontinuous_v2.cfg.env.env_id: gym_lunarlandercontinuous_v2.env,
25 |
gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.env,
26 |
gym_walker2d_v3.cfg.env.env_id: gym_walker2d_v3.env,
27 |
28 |
29 |
supported_env = EasyDict(supported_env)
@@ -0,0 +1,45 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
learner=dict(hook=dict(log_show_after_iter=1000, ))
34 |
35 |
collect=dict(n_sample=64, ),
36 |
other=dict(replay_buffer=dict(replay_buffer_size=300000, ), ),
37 |
38 |
39 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
40 |
41 |
42 |
43 |
cfg = EasyDict(cfg)
44 |
45 |
env = ding.envs.gym_env.env
@@ -0,0 +1,53 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
norm_obs=dict(use_norm=False, ),
10 |
norm_reward=dict(use_norm=False, ),
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
45 |
46 |
47 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
48 |
49 |
50 |
51 |
cfg = EasyDict(cfg)
52 |
53 |
env = ding.envs.gym_env.env
@@ -0,0 +1,53 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
norm_obs=dict(use_norm=False, ),
10 |
norm_reward=dict(use_norm=False, ),
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
45 |
46 |
47 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
48 |
49 |
50 |
51 |
cfg = EasyDict(cfg)
52 |
53 |
env = ding.envs.gym_env.env
@@ -0,0 +1,60 @@
1 |
from easydict import EasyDict
2 |
from functools import partial
3 |
import ding.envs.gym_env
4 |
5 |
cfg = dict(
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
ignore_done=False, # TODO(pu)
31 |
# (int) Number of actor network updates per critic network update.
32 |
# See "Delayed Policy Updates" in the original TD3 paper.
33 |
# Default 1 for DDPG, 2 for TD3.
34 |
35 |
# (bool) Whether to add noise on target network's action.
36 |
# See "Target Policy Smoothing Regularization" in the original TD3 paper.
37 |
# Default True for TD3, False for DDPG.
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
collector=dict(collect_print_freq=1000, ),
49 |
50 |
eval=dict(evaluator=dict(eval_freq=100, ), ),
51 |
other=dict(replay_buffer=dict(replay_buffer_size=20000, ), ),
52 |
53 |
54 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
55 |
56 |
57 |
58 |
cfg = EasyDict(cfg)
59 |
60 |
env = partial(ding.envs.gym_env.env, continuous=True)
@@ -0,0 +1,52 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
collector=dict(collect_print_freq=1000, ),
38 |
39 |
eval=dict(evaluator=dict(eval_freq=100, )),
40 |
41 |
42 |
43 |
), ),
44 |
45 |
46 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
47 |
48 |
49 |
50 |
cfg = EasyDict(cfg)
51 |
52 |
env = ding.envs.gym_env.env
@@ -0,0 +1,53 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
norm_obs=dict(use_norm=False, ),
10 |
norm_reward=dict(use_norm=False, ),
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
45 |
46 |
47 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
48 |
49 |
50 |
51 |
cfg = EasyDict(cfg)
52 |
53 |
env = ding.envs.gym_env.env
@@ -0,0 +1,23 @@
1 |
from easydict import EasyDict
2 |
from . import gym_lunarlander_v2
3 |
from . import gym_pongnoframeskip_v4
4 |
from . import gym_qbertnoframeskip_v4
5 |
from . import gym_spaceInvadersnoframeskip_v4
6 |
7 |
supported_env_cfg = {
8 |
gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
9 |
gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg,
10 |
gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg,
11 |
gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg,
12 |
13 |
14 |
supported_env_cfg = EasyDict(supported_env_cfg)
15 |
16 |
supported_env = {
17 |
gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
18 |
gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env,
19 |
gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env,
20 |
gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env,
21 |
22 |
23 |
supported_env = EasyDict(supported_env)
@@ -0,0 +1,53 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
# Frequency of target network update.
24 |
25 |
26 |
27 |
28 |
29 |
encoder_hidden_size_list=[512, 64],
30 |
# Whether to use dueling head.
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
), replay_buffer=dict(replay_buffer_size=100000, )
44 |
45 |
46 |
47 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
48 |
49 |
50 |
51 |
cfg = EasyDict(cfg)
52 |
53 |
env = ding.envs.gym_env.env
@@ -0,0 +1,50 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
# Frequency of target network update.
26 |
27 |
28 |
29 |
obs_shape=[4, 84, 84],
30 |
31 |
encoder_hidden_size_list=[128, 128, 512],
32 |
33 |
collect=dict(n_sample=96, ),
34 |
35 |
36 |
37 |
38 |
39 |
40 |
), replay_buffer=dict(replay_buffer_size=100000, )
41 |
42 |
43 |
44 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
45 |
46 |
47 |
48 |
cfg = EasyDict(cfg)
49 |
50 |
env = ding.envs.gym_env.env
@@ -0,0 +1,50 @@
1 |
from easydict import EasyDict
2 |
import ding.envs.gym_env
3 |
4 |
cfg = dict(
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
# Frequency of target network update.
26 |
27 |
28 |
29 |
obs_shape=[4, 84, 84],
30 |
31 |
encoder_hidden_size_list=[128, 128, 512],
32 |
33 |
collect=dict(n_sample=100, ),
34 |
35 |
36 |
37 |
38 |
39 |
40 |
), replay_buffer=dict(replay_buffer_size=400000, )
41 |
42 |
43 |
44 |
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
45 |
46 |
47 |
48 |
cfg = EasyDict(cfg)
49 |
50 |
env = ding.envs.gym_env.env