JAX / Flax์์์ ๐งจ Stable Diffusion!
๐ค Hugging Face [Diffusers] (https://github.com/huggingface/diffusers) ๋ ๋ฒ์ 0.5.1๋ถํฐ Flax๋ฅผ ์ง์ํฉ๋๋ค! ์ด๋ฅผ ํตํด Colab, Kaggle, Google Cloud Platform์์ ์ฌ์ฉํ ์ ์๋ ๊ฒ์ฒ๋ผ Google TPU์์ ์ด๊ณ ์ ์ถ๋ก ์ด ๊ฐ๋ฅํฉ๋๋ค.
์ด ๋ ธํธ๋ถ์ JAX / Flax๋ฅผ ์ฌ์ฉํด ์ถ๋ก ์ ์คํํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค. Stable Diffusion์ ์๋ ๋ฐฉ์์ ๋ํ ์์ธํ ๋ด์ฉ์ ์ํ๊ฑฐ๋ GPU์์ ์คํํ๋ ค๋ฉด ์ด [๋ ธํธ๋ถ] ](https://huggingface.co/docs/diffusers/stable_diffusion)์ ์ฐธ์กฐํ์ธ์.
๋จผ์ , TPU ๋ฐฑ์๋๋ฅผ ์ฌ์ฉํ๊ณ ์๋์ง ํ์ธํฉ๋๋ค. Colab์์ ์ด ๋ ธํธ๋ถ์ ์คํํ๋ ๊ฒฝ์ฐ, ๋ฉ๋ด์์ ๋ฐํ์์ ์ ํํ ๋ค์ โ๋ฐํ์ ์ ํ ๋ณ๊ฒฝโ ์ต์ ์ ์ ํํ ๋ค์ ํ๋์จ์ด ๊ฐ์๊ธฐ ์ค์ ์์ TPU๋ฅผ ์ ํํฉ๋๋ค.
JAX๋ TPU ์ ์ฉ์ ์๋์ง๋ง ๊ฐ TPU ์๋ฒ์๋ 8๊ฐ์ TPU ๊ฐ์๊ธฐ๊ฐ ๋ณ๋ ฌ๋ก ์๋ํ๊ธฐ ๋๋ฌธ์ ํด๋น ํ๋์จ์ด์์ ๋ ๋น์ ๋ฐํ๋ค๋ ์ ์ ์์๋์ธ์.
Setup
๋จผ์ diffusers๊ฐ ์ค์น๋์ด ์๋์ง ํ์ธํฉ๋๋ค.
!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy !pip install diffusers
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
"TPU" in device_type
), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
Found 8 JAX devices of type Cloud TPU.
๊ทธ๋ฐ ๋ค์ ๋ชจ๋ dependencies๋ฅผ ๊ฐ์ ธ์ต๋๋ค.
import numpy as np
import jax
import jax.numpy as jnp
from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ
TPU ์ฅ์น๋ ํจ์จ์ ์ธ half-float ์ ํ์ธ bfloat16์ ์ง์ํฉ๋๋ค. ํ ์คํธ์๋ ์ด ์ ํ์ ์ฌ์ฉํ์ง๋ง ๋์ float32๋ฅผ ์ฌ์ฉํ์ฌ ์ ์ฒด ์ ๋ฐ๋(full precision)๋ฅผ ์ฌ์ฉํ ์๋ ์์ต๋๋ค.
dtype = jnp.bfloat16
Flax๋ ํจ์ํ ํ๋ ์์ํฌ์ด๋ฏ๋ก ๋ชจ๋ธ์ ๋ฌด์ํ(stateless)ํ์ด๋ฉฐ ๋งค๊ฐ๋ณ์๋ ๋ชจ๋ธ ์ธ๋ถ์ ์ ์ฅ๋ฉ๋๋ค. ์ฌ์ ํ์ต๋ Flax ํ์ดํ๋ผ์ธ์ ๋ถ๋ฌ์ค๋ฉด ํ์ดํ๋ผ์ธ ์์ฒด์ ๋ชจ๋ธ ๊ฐ์ค์น(๋๋ ๋งค๊ฐ๋ณ์)๊ฐ ๋ชจ๋ ๋ฐํ๋ฉ๋๋ค. ์ ํฌ๋ bf16 ๋ฒ์ ์ ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ๊ณ ์์ผ๋ฏ๋ก ์ ํ ๊ฒฝ๊ณ ๊ฐ ํ์๋์ง๋ง ๋ฌด์ํด๋ ๋ฉ๋๋ค.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
variant="bf16",
dtype=dtype,
)
์ถ๋ก
TPU์๋ ์ผ๋ฐ์ ์ผ๋ก 8๊ฐ์ ๋๋ฐ์ด์ค๊ฐ ๋ณ๋ ฌ๋ก ์๋ํ๋ฏ๋ก ๋ณด์ ํ ๋๋ฐ์ด์ค ์๋งํผ ํ๋กฌํํธ๋ฅผ ๋ณต์ ํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ ๊ฐ๊ฐ ํ๋์ ์ด๋ฏธ์ง ์์ฑ์ ๋ด๋นํ๋ 8๊ฐ์ ๋๋ฐ์ด์ค์์ ํ ๋ฒ์ ์ถ๋ก ์ ์ํํฉ๋๋ค. ๋ฐ๋ผ์ ํ๋์ ์นฉ์ด ํ๋์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๋ฐ ๊ฑธ๋ฆฌ๋ ์๊ฐ๊ณผ ๋์ผํ ์๊ฐ์ 8๊ฐ์ ์ด๋ฏธ์ง๋ฅผ ์ป์ ์ ์์ต๋๋ค.
ํ๋กฌํํธ๋ฅผ ๋ณต์ ํ๊ณ ๋๋ฉด ํ์ดํ๋ผ์ธ์ prepare_inputs
ํจ์๋ฅผ ํธ์ถํ์ฌ ํ ํฐํ๋ ํ
์คํธ ID๋ฅผ ์ป์ต๋๋ค. ํ ํฐํ๋ ํ
์คํธ์ ๊ธธ์ด๋ ๊ธฐ๋ณธ CLIP ํ
์คํธ ๋ชจ๋ธ์ ๊ตฌ์ฑ์ ๋ฐ๋ผ 77ํ ํฐ์ผ๋ก ์ค์ ๋ฉ๋๋ค.
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
(8, 77)
๋ณต์ฌ(Replication) ๋ฐ ์ ๋ ฌํ
๋ชจ๋ธ ๋งค๊ฐ๋ณ์์ ์
๋ ฅ๊ฐ์ ์ฐ๋ฆฌ๊ฐ ๋ณด์ ํ 8๊ฐ์ ๋ณ๋ ฌ ์ฅ์น์ ๋ณต์ฌ(Replication)๋์ด์ผ ํฉ๋๋ค. ๋งค๊ฐ๋ณ์ ๋์
๋๋ฆฌ๋ flax.jax_utils.replicate
(๋์
๋๋ฆฌ๋ฅผ ์ํํ๋ฉฐ ๊ฐ์ค์น์ ๋ชจ์์ ๋ณ๊ฒฝํ์ฌ 8๋ฒ ๋ฐ๋ณตํ๋ ํจ์)๋ฅผ ์ฌ์ฉํ์ฌ ๋ณต์ฌ๋ฉ๋๋ค. ๋ฐฐ์ด์ shard
๋ฅผ ์ฌ์ฉํ์ฌ ๋ณต์ ๋ฉ๋๋ค.
p_params = replicate(params)
prompt_ids = shard(prompt_ids) prompt_ids.shape
(8, 1, 77)
์ด shape์ 8๊ฐ์ ๋๋ฐ์ด์ค ๊ฐ๊ฐ์ด shape (1, 77)
์ jnp ๋ฐฐ์ด์ ์
๋ ฅ๊ฐ์ผ๋ก ๋ฐ๋๋ค๋ ์๋ฏธ์
๋๋ค. ์ฆ 1์ ๋๋ฐ์ด์ค๋น batch(๋ฐฐ์น) ํฌ๊ธฐ์
๋๋ค. ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ถฉ๋ถํ TPU์์๋ ํ ๋ฒ์ ์ฌ๋ฌ ์ด๋ฏธ์ง(์นฉ๋น)๋ฅผ ์์ฑํ๋ ค๋ ๊ฒฝ์ฐ 1๋ณด๋ค ํด ์ ์์ต๋๋ค.
์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์ค๋น๊ฐ ๊ฑฐ์ ์๋ฃ๋์์ต๋๋ค! ์ด์ ์์ฑ ํจ์์ ์ ๋ฌํ ๋์ ์์ฑ๊ธฐ๋ง ๋ง๋ค๋ฉด ๋ฉ๋๋ค. ์ด๊ฒ์ ๋์๋ฅผ ๋ค๋ฃจ๋ ๋ชจ๋ ํจ์์ ๋์ ์์ฑ๊ธฐ๊ฐ ์์ด์ผ ํ๋ค๋, ๋์์ ๋ํด ๋งค์ฐ ์ง์งํ๊ณ ๋ ๋จ์ ์ธ Flax์ ํ์ค ์ ์ฐจ์ ๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด ์ฌ๋ฌ ๋ถ์ฐ๋ ๊ธฐ๊ธฐ์์ ํ๋ จํ ๋์๋ ์ฌํ์ฑ์ด ๋ณด์ฅ๋ฉ๋๋ค.
์๋ ํฌํผ ํจ์๋ ์๋๋ฅผ ์ฌ์ฉํ์ฌ ๋์ ์์ฑ๊ธฐ๋ฅผ ์ด๊ธฐํํฉ๋๋ค. ๋์ผํ ์๋๋ฅผ ์ฌ์ฉํ๋ ํ ์ ํํ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์ ์์ต๋๋ค. ๋์ค์ ๋ ธํธ๋ถ์์ ๊ฒฐ๊ณผ๋ฅผ ํ์ํ ๋์ ๋ค๋ฅธ ์๋๋ฅผ ์์ ๋กญ๊ฒ ์ฌ์ฉํ์ธ์.
def create_key(seed=0):
return jax.random.PRNGKey(seed)
rng๋ฅผ ์ป์ ๋ค์ 8๋ฒ โ๋ถํ โํ์ฌ ๊ฐ ๋๋ฐ์ด์ค๊ฐ ๋ค๋ฅธ ์ ๋๋ ์ดํฐ๋ฅผ ์์ ํ๋๋ก ํฉ๋๋ค. ๋ฐ๋ผ์ ๊ฐ ๋๋ฐ์ด์ค๋ง๋ค ๋ค๋ฅธ ์ด๋ฏธ์ง๊ฐ ์์ฑ๋๋ฉฐ ์ ์ฒด ํ๋ก์ธ์ค๋ฅผ ์ฌํํ ์ ์์ต๋๋ค.
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
JAX ์ฝ๋๋ ๋งค์ฐ ๋น ๋ฅด๊ฒ ์คํ๋๋ ํจ์จ์ ์ธ ํํ์ผ๋ก ์ปดํ์ผํ ์ ์์ต๋๋ค. ํ์ง๋ง ํ์ ํธ์ถ์์ ๋ชจ๋ ์ ๋ ฅ์ด ๋์ผํ ๋ชจ์์ ๊ฐ๋๋ก ํด์ผ ํ๋ฉฐ, ๊ทธ๋ ์ง ์์ผ๋ฉด JAX๊ฐ ์ฝ๋๋ฅผ ๋ค์ ์ปดํ์ผํด์ผ ํ๋ฏ๋ก ์ต์ ํ๋ ์๋๋ฅผ ํ์ฉํ ์ ์์ต๋๋ค.
jit = True
๋ฅผ ์ธ์๋ก ์ ๋ฌํ๋ฉด Flax ํ์ดํ๋ผ์ธ์ด ์ฝ๋๋ฅผ ์ปดํ์ผํ ์ ์์ต๋๋ค. ๋ํ ๋ชจ๋ธ์ด ์ฌ์ฉ ๊ฐ๋ฅํ 8๊ฐ์ ๋๋ฐ์ด์ค์์ ๋ณ๋ ฌ๋ก ์คํ๋๋๋ก ๋ณด์ฅํฉ๋๋ค.
๋ค์ ์ ์ ์ฒ์ ์คํํ๋ฉด ์ปดํ์ผํ๋ ๋ฐ ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆฌ์ง๋ง ์ดํ ํธ์ถ(์ ๋ ฅ์ด ๋ค๋ฅธ ๊ฒฝ์ฐ์๋)์ ํจ์ฌ ๋นจ๋ผ์ง๋๋ค. ์๋ฅผ ๋ค์ด, ํ ์คํธํ์ ๋ TPU v2-8์์ ์ปดํ์ผํ๋ ๋ฐ 1๋ถ ์ด์ ๊ฑธ๋ฆฌ์ง๋ง ์ดํ ์ถ๋ก ์คํ์๋ ์ฝ 7์ด๊ฐ ๊ฑธ๋ฆฝ๋๋ค.
%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
Wall time: 1min 29s
๋ฐํ๋ ๋ฐฐ์ด์ shape์ (8, 1, 512, 512, 3)
์
๋๋ค. ์ด๋ฅผ ์ฌ๊ตฌ์ฑํ์ฌ ๋ ๋ฒ์งธ ์ฐจ์์ ์ ๊ฑฐํ๊ณ 512 ร 512 ร 3์ ์ด๋ฏธ์ง 8๊ฐ๋ฅผ ์ป์ ๋ค์ PIL๋ก ๋ณํํฉ๋๋ค.
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
์๊ฐํ
์ด๋ฏธ์ง๋ฅผ ๊ทธ๋ฆฌ๋์ ํ์ํ๋ ๋์ฐ๋ฏธ ํจ์๋ฅผ ๋ง๋ค์ด ๋ณด๊ฒ ์ต๋๋ค.
def image_grid(imgs, rows, cols):
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
image_grid(images, 2, 4)
๋ค๋ฅธ ํ๋กฌํํธ ์ฌ์ฉ
๋ชจ๋ ๋๋ฐ์ด์ค์์ ๋์ผํ ํ๋กฌํํธ๋ฅผ ๋ณต์ ํ ํ์๋ ์์ต๋๋ค. ํ๋กฌํํธ 2๊ฐ๋ฅผ ๊ฐ๊ฐ 4๋ฒ์ฉ ์์ฑํ๊ฑฐ๋ ํ ๋ฒ์ 8๊ฐ์ ์๋ก ๋ค๋ฅธ ํ๋กฌํํธ๋ฅผ ์์ฑํ๋ ๋ฑ ์ํ๋ ๊ฒ์ ๋ฌด์์ด๋ ํ ์ ์์ต๋๋ค. ํ๋ฒ ํด๋ณด์ธ์!
๋จผ์ ์ ๋ ฅ ์ค๋น ์ฝ๋๋ฅผ ํธ๋ฆฌํ ํจ์๋ก ๋ฆฌํฉํฐ๋งํ๊ฒ ์ต๋๋ค:
prompts = [
"Labrador in the style of Hokusai",
"Painting of a squirrel skating in New York",
"HAL-9000 in the style of Van Gogh",
"Times Square under water, with fish and a dolphin swimming around",
"Ancient Roman fresco showing a man working on his laptop",
"Close-up photograph of young black woman against urban background, high quality, bokeh",
"Armchair in the shape of an avocado",
"Clown astronaut in space, with Earth in the background",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)
๋ณ๋ ฌํ(parallelization)๋ ์ด๋ป๊ฒ ์๋ํ๋๊ฐ?
์์ diffusers
Flax ํ์ดํ๋ผ์ธ์ด ๋ชจ๋ธ์ ์๋์ผ๋ก ์ปดํ์ผํ๊ณ ์ฌ์ฉ ๊ฐ๋ฅํ ๋ชจ๋ ๊ธฐ๊ธฐ์์ ๋ณ๋ ฌ๋ก ์คํํ๋ค๊ณ ๋ง์๋๋ ธ์ต๋๋ค. ์ด์ ๊ทธ ํ๋ก์ธ์ค๋ฅผ ๊ฐ๋ตํ๊ฒ ์ดํด๋ณด๊ณ ์๋ ๋ฐฉ์์ ๋ณด์ฌ๋๋ฆฌ๊ฒ ์ต๋๋ค.
JAX ๋ณ๋ ฌํ๋ ์ฌ๋ฌ ๊ฐ์ง ๋ฐฉ๋ฒ์ผ๋ก ์ํํ ์ ์์ต๋๋ค. ๊ฐ์ฅ ์ฌ์ด ๋ฐฉ๋ฒ์ jax.pmap ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ๋จ์ผ ํ๋ก๊ทธ๋จ, ๋ค์ค ๋ฐ์ดํฐ(SPMD) ๋ณ๋ ฌํ๋ฅผ ๋ฌ์ฑํ๋ ๊ฒ์
๋๋ค. ์ฆ, ๋์ผํ ์ฝ๋์ ๋ณต์ฌ๋ณธ์ ๊ฐ๊ฐ ๋ค๋ฅธ ๋ฐ์ดํฐ ์
๋ ฅ์ ๋ํด ์ฌ๋ฌ ๊ฐ ์คํํ๋ ๊ฒ์
๋๋ค. ๋ ์ ๊ตํ ์ ๊ทผ ๋ฐฉ์๋ ๊ฐ๋ฅํ๋ฏ๋ก ๊ด์ฌ์ด ์์ผ์๋ค๋ฉด JAX ๋ฌธ์์ pjit
ํ์ด์ง์์ ์ด ์ฃผ์ ๋ฅผ ์ดํด๋ณด์๊ธฐ ๋ฐ๋๋๋ค!
jax.pmap
์ ๋ ๊ฐ์ง ๊ธฐ๋ฅ์ ์ํํฉ๋๋ค:
jax.jit()
๋ฅผ ํธ์ถํ ๊ฒ์ฒ๋ผ ์ฝ๋๋ฅผ ์ปดํ์ผ(๋๋jit
)ํฉ๋๋ค. ์ด ์์ ์pmap
์ ํธ์ถํ ๋๊ฐ ์๋๋ผ pmapped ํจ์๊ฐ ์ฒ์ ํธ์ถ๋ ๋ ์ํ๋ฉ๋๋ค.- ์ปดํ์ผ๋ ์ฝ๋๊ฐ ์ฌ์ฉ ๊ฐ๋ฅํ ๋ชจ๋ ๊ธฐ๊ธฐ์์ ๋ณ๋ ฌ๋ก ์คํ๋๋๋ก ํฉ๋๋ค.
์๋ ๋ฐฉ์์ ๋ณด์ฌ๋๋ฆฌ๊ธฐ ์ํด ์ด๋ฏธ์ง ์์ฑ์ ์คํํ๋ ๋น๊ณต๊ฐ ๋ฉ์๋์ธ ํ์ดํ๋ผ์ธ์ _generate
๋ฉ์๋๋ฅผ pmap
ํฉ๋๋ค. ์ด ๋ฉ์๋๋ ํฅํ Diffusers
๋ฆด๋ฆฌ์ค์์ ์ด๋ฆ์ด ๋ณ๊ฒฝ๋๊ฑฐ๋ ์ ๊ฑฐ๋ ์ ์๋ค๋ ์ ์ ์ ์ํ์ธ์.
p_generate = pmap(pipeline._generate)
pmap
์ ์ฌ์ฉํ ํ ์ค๋น๋ ํจ์ p_generate
๋ ๊ฐ๋
์ ์ผ๋ก ๋ค์์ ์ํํฉ๋๋ค:
- ๊ฐ ์ฅ์น์์ ๊ธฐ๋ณธ ํจ์
pipeline._generate
์ ๋ณต์ฌ๋ณธ์ ํธ์ถํฉ๋๋ค. - ๊ฐ ์ฅ์น์ ์
๋ ฅ ์ธ์์ ๋ค๋ฅธ ๋ถ๋ถ์ ๋ณด๋
๋๋ค. ์ด๊ฒ์ด ๋ฐ๋ก ์ค๋ฉ์ด ์ฌ์ฉ๋๋ ์ด์ ์
๋๋ค. ์ด ๊ฒฝ์ฐ
prompt_ids
์ shape์(8, 1, 77, 768)
์ ๋๋ค. ์ด ๋ฐฐ์ด์ 8๊ฐ๋ก ๋ถํ ๋๊ณ_generate
์ ๊ฐ ๋ณต์ฌ๋ณธ์(1, 77, 768)
์ shape์ ๊ฐ์ง ์ ๋ ฅ์ ๋ฐ๊ฒ ๋ฉ๋๋ค.
๋ณ๋ ฌ๋ก ํธ์ถ๋๋ค๋ ์ฌ์ค์ ์์ ํ ๋ฌด์ํ๊ณ _generate
๋ฅผ ์ฝ๋ฉํ ์ ์์ต๋๋ค. batch(๋ฐฐ์น) ํฌ๊ธฐ(์ด ์์ ์์๋ 1
)์ ์ฝ๋์ ์ ํฉํ ์ฐจ์๋ง ์ ๊ฒฝ ์ฐ๋ฉด ๋๋ฉฐ, ๋ณ๋ ฌ๋ก ์๋ํ๊ธฐ ์ํด ์๋ฌด๊ฒ๋ ๋ณ๊ฒฝํ ํ์๊ฐ ์์ต๋๋ค.
ํ์ดํ๋ผ์ธ ํธ์ถ์ ์ฌ์ฉํ ๋์ ๋ง์ฐฌ๊ฐ์ง๋ก, ๋ค์ ์ ์ ์ฒ์ ์คํํ ๋๋ ์๊ฐ์ด ๊ฑธ๋ฆฌ์ง๋ง ๊ทธ ์ดํ์๋ ํจ์ฌ ๋นจ๋ผ์ง๋๋ค.
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
images.shape
(8, 1, 512, 512, 3)
JAX๋ ๋น๋๊ธฐ ๋์คํจ์น๋ฅผ ์ฌ์ฉํ๊ณ ๊ฐ๋ฅํ ํ ๋นจ๋ฆฌ ์ ์ด๊ถ์ Python ๋ฃจํ์ ๋ฐํํ๊ธฐ ๋๋ฌธ์ ์ถ๋ก ์๊ฐ์ ์ ํํ๊ฒ ์ธก์ ํ๊ธฐ ์ํด block_until_ready()
๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์์ง ๊ตฌ์ฒดํ๋์ง ์์ ๊ณ์ฐ ๊ฒฐ๊ณผ๋ฅผ ์ฌ์ฉํ๋ ค๋ ๊ฒฝ์ฐ ์๋์ผ๋ก ์ฐจ๋จ์ด ์ํ๋๋ฏ๋ก ์ฝ๋์์ ์ด ํจ์๋ฅผ ์ฌ์ฉํ ํ์๊ฐ ์์ต๋๋ค.