""" |
|
This file is used for T2I generation, it also compute the clip similarity between the generated images and the input prompt |
|
""" |
import builtins
import math
import os
import tempfile
import time

import accelerate
import einops
import ml_collections
import numpy as np
import torch
import torch.nn as nn
from absl import app, flags, logging
from ml_collections import config_flags
from PIL import Image
from torch import multiprocessing as mp

import libs.autoencoder
import utils
from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
from libs.clip import FrozenCLIPEmbedder
from libs.t5 import T5Embedder
from tools.clip_score import ClipSocre


def unpreprocess(x):
    # Map images from the model's [-1, 1] range back to [0, 1].
    x = 0.5 * (x + 1.)
    x.clamp_(0., 1.)
    return x


def get_caption(llm, text_model, _batch_prompt):
    # Encode a batch of prompts and return the token embeddings, the token
    # mask, the raw tokens, and the captions themselves.
    _batch_con = _batch_prompt
    if llm == "clip":
        _latent, _latent_and_others = text_model.encode(_batch_con)
        _con = _latent_and_others['token_embedding'].detach()
    elif llm == "t5":
        _latent, _latent_and_others = text_model.get_text_embeddings(_batch_con)
        # A fixed 10.0 factor rescales the T5 token embeddings before conditioning.
        _con = (_latent_and_others['token_embedding'] * 10.0).detach()
    else:
        raise NotImplementedError
    _con_mask = _latent_and_others['token_mask'].detach()
    _batch_token = _latent_and_others['tokens'].detach()
    _batch_caption = _batch_con
    return _con, _con_mask, _batch_token, _batch_caption


def evaluate(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)
    if accelerator.is_main_process:
        utils.set_logger(log_level='info', fname=config.output_path)
    else:
        utils.set_logger(log_level='error')
        builtins.print = lambda *args: None  # silence prints on non-main processes

    nnet = utils.get_nnet(**config.nnet)
    nnet = accelerator.prepare(nnet)
    logging.info(f'load nnet from {config.nnet_path}')
    accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
    nnet.eval()

    # Select the text encoder that matches the context dimension the nnet expects.
    if config.nnet.model_args.clip_dim == 4096:
        llm = "t5"
        text_model = T5Embedder(device=device)
    elif config.nnet.model_args.clip_dim == 768:
        llm = "clip"
        text_model = FrozenCLIPEmbedder()
        text_model.eval()
        text_model.to(device)
    else:
        raise NotImplementedError

    # The prompt is fixed, so its encoding is computed once and reused for every batch.
    context_generator = get_caption(llm, text_model, _batch_prompt=[config.prompt] * config.sample.mini_batch_size)

    autoencoder = libs.autoencoder.get_model(**config.autoencoder)
    autoencoder.to(device)

    @torch.cuda.amp.autocast()
    def encode(_batch):
        return autoencoder.encode(_batch)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    bdv_nnet = None
    ClipSocre_model = ClipSocre(device=device)

    logging.info(config.sample)
    logging.info(f'sample: n_samples={config.sample.n_samples}, mode=t2i, mixed_precision={config.mixed_precision}')

    def ode_fm_solver_sample(nnet_ema, _n_samples, _sample_steps, bdv_nnet=bdv_nnet, context=None, caption=None,
                             testbatch_img_blurred=None, two_stage_generation=-1, token=None, token_mask=None,
                             return_clipScore=False, ClipSocre_model=None):
        with torch.no_grad():
            del testbatch_img_blurred

            _z_gaussian = torch.randn(_n_samples, *config.z_shape, device=device)

            # Run the nnet in text-encoder mode to obtain the initial latent from the context.
            if 'dimr' in config.nnet.name or 'dit' in config.nnet.name:
                _z_x0, _mu, _log_var = nnet_ema(context, text_encoder=True, shape=_z_gaussian.shape, mask=token_mask)
                _z_init = _z_x0.reshape(_z_gaussian.shape)
            else:
                raise NotImplementedError

            # Use the --cfg flag if given; otherwise fall back to the scale from the config.
            assert config.sample.scale > 1
            if config.cfg != -1:
                _cfg = config.cfg
            else:
                _cfg = config.sample.scale

            has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")

            _sample_steps = config.sample.sample_steps  # overrides the _sample_steps argument

            ode_solver = ODEEulerFlowMatchingSolver(nnet_ema, bdv_model_fn=bdv_nnet, step_size_type="step_in_dsigma", guidance_scale=_cfg)
            _z, _ = ode_solver.sample(x_T=_z_init, batch_size=_n_samples, sample_steps=_sample_steps,
                                      unconditional_guidance_scale=_cfg, has_null_indicator=has_null_indicator)

            image_unprocessed = decode(_z)
            clip_score = ClipSocre_model.calculate_clip_score(caption, image_unprocessed)

            return image_unprocessed, clip_score
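
    # For reference, a single Euler step of the flow-matching ODE has the form
    #   x_{t+dt} = x_t + v_theta(x_t, t) * dt,
    # where v_theta is the learned velocity field. The actual step-size schedule
    # ("step_in_dsigma") and the guided update live inside ODEEulerFlowMatchingSolver;
    # this note is a generic sketch, not that implementation.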
    def sample_fn(_n_samples, return_caption=False, return_clipScore=False, ClipSocre_model=None, config=None):
        _context, _token_mask, _token, _caption = context_generator
        assert _context.size(0) == _n_samples
        assert return_clipScore
        assert not return_caption
        return ode_fm_solver_sample(nnet, _n_samples, config.sample.sample_steps, bdv_nnet=bdv_nnet, context=_context,
                                    token=_token, token_mask=_token_mask, return_clipScore=return_clipScore,
                                    ClipSocre_model=ClipSocre_model, caption=_caption)

    with tempfile.TemporaryDirectory() as temp_path:
        path = config.img_save_path or config.sample.path or temp_path
        if accelerator.is_main_process:
            os.makedirs(path, exist_ok=True)
        logging.info(f'Samples are saved in {path}')

        clip_score_list = utils.sample2dir_wCLIP(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size,
                                                 sample_fn, unpreprocess, return_clipScore=True,
                                                 ClipSocre_model=ClipSocre_model, config=config)
        if clip_score_list is not None:
            _clip_score_list = torch.cat(clip_score_list)
            if accelerator.is_main_process:
                logging.info(f'nnet_path={config.nnet_path}, clip_score({len(_clip_score_list)} samples)={_clip_score_list.mean().item()}')


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
flags.DEFINE_string("prompt", None, "The prompt used for generation.")
flags.DEFINE_string("output_path", None, "The path to output log.")
flags.DEFINE_float("cfg", -1, "cfg scale; the scale defined in the config file is used if not assigned.")
flags.DEFINE_string("img_save_path", None, "The path to image log.")


def main(argv):
    config = FLAGS.config
    config.nnet_path = FLAGS.nnet_path
    config.prompt = FLAGS.prompt
    config.output_path = FLAGS.output_path
    config.img_save_path = FLAGS.img_save_path
    config.cfg = FLAGS.cfg
    evaluate(config)


if __name__ == "__main__":
    app.run(main)