# q8-ltx-video / app_utils.py
import hashlib
import types
from argparse import Namespace

import torch
from diffusers import LTXPipeline
from q8_kernels.graph.graph import make_dynamic_graphed_callable

from inference import load_q8_transformer

# Patched forward for `LTXVideoTransformerBlock`: identical to the upstream
# diffusers block except that `ff_output` is cast back to the dtype of the
# normalized hidden states before the final residual addition.
def patched_ltx_transformer_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb=None,
    encoder_attention_mask=None,
) -> torch.Tensor:
    batch_size = hidden_states.size(0)

    # AdaLN modulation: combine the learned `scale_shift_table` with the
    # timestep embedding to get shift/scale/gate values for the attention
    # (msa) and feed-forward (mlp) branches.
    norm_hidden_states = self.norm1(hidden_states)
    num_ada_params = self.scale_shift_table.shape[0]
    ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

    # Gated self-attention over the video tokens.
    attn_hidden_states = self.attn1(
        hidden_states=norm_hidden_states,
        encoder_hidden_states=None,
        image_rotary_emb=image_rotary_emb,
    )
    hidden_states = hidden_states + attn_hidden_states * gate_msa

    # Cross-attention over the text-encoder hidden states.
    attn_hidden_states = self.attn2(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        image_rotary_emb=None,
        attention_mask=encoder_attention_mask,
    )
    hidden_states = hidden_states + attn_hidden_states

    # Feed-forward branch. The `.to(...)` cast is the one change from the
    # upstream implementation: it keeps the residual addition in a single
    # dtype when the Q8 feed-forward returns a different one.
    norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
    ff_output = self.ff(norm_hidden_states).to(norm_hidden_states.dtype)
    hidden_states = hidden_states + ff_output * gate_mlp
    return hidden_states
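
# A hedged note on the cast above: in this setup (see `load_transformer`) the
# block weights run in fp32 while `scale_shift_table` stays in bf16, so the
# feed-forward output may come back in a dtype that differs from the residual
# stream. Minimal illustration with hypothetical tensors, not the real block:
#
#   ff_out = torch.randn(2, 8, dtype=torch.bfloat16)
#   residual = torch.randn(2, 8, dtype=torch.float32)
#   out = residual + ff_out.to(residual.dtype)
#   assert out.dtype == torch.float32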

def load_transformer():
    # Fetch the pre-quantized (Q8) LTX-Video transformer from the Hub.
    args = Namespace()
    args.q8_transformer_path = "sayakpaul/q8-ltx-video"
    transformer = load_q8_transformer(args)
    transformer.to(torch.bfloat16)

    # The Q8 kernels run the transformer blocks in fp32, except for the AdaLN
    # `scale_shift_table`, which is kept in bf16.
    for block in transformer.transformer_blocks:
        block.to(dtype=torch.float)
    for name, param in transformer.transformer_blocks.named_parameters():
        if "scale_shift_table" in name:
            param.data = param.data.to(torch.bfloat16)

    # Swap in the patched forward on every block, then wrap the transformer's
    # forward in a dynamically captured CUDA graph.
    for block in transformer.transformer_blocks:
        block.forward = types.MethodType(patched_ltx_transformer_forward, block)
    transformer.forward = make_dynamic_graphed_callable(transformer.forward)
    return transformer
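
# Sanity-check sketch (hypothetical; holds by construction after the casts
# above):
#
#   transformer = load_transformer()
#   block = transformer.transformer_blocks[0]
#   assert block.scale_shift_table.dtype == torch.bfloat16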

def warmup_transformer(pipe):
    # Run a few full denoising passes with precomputed prompt embeddings so
    # the dynamic CUDA graphs are captured (and compilation finishes) before
    # the first real request.
    prompt_embeds = torch.load("prompt_embeds.pt", map_location="cuda", weights_only=True)
    for _ in range(5):
        _ = pipe(
            **prompt_embeds,
            output_type="latent",
            width=768,
            height=512,
            num_frames=121,
        )
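
# "prompt_embeds.pt" is assumed to hold the keyword arguments `LTXPipeline`
# expects when it is loaded without a text encoder, e.g. (hypothetical
# contents):
#
#   {"prompt_embeds": ..., "prompt_attention_mask": ...,
#    "negative_prompt_embeds": ..., "negative_prompt_attention_mask": ...}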

def prepare_pipeline():
    # Load the pipeline without a text encoder (prompt embeddings are supplied
    # precomputed) and swap in the Q8 transformer before compiling and warming
    # it up.
    pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", text_encoder=None, torch_dtype=torch.bfloat16)
    pipe.transformer = load_transformer()
    pipe = pipe.to("cuda")
    pipe.transformer.compile()
    pipe.set_progress_bar_config(disable=True)
    warmup_transformer(pipe)
    return pipe
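
# End-to-end sketch (hypothetical; with `output_type="latent"` the returned
# `.frames` are latents that would still need a VAE decode to become video):
#
#   pipe = prepare_pipeline()
#   embeds = torch.load("prompt_embeds.pt", map_location="cuda", weights_only=True)
#   latents = pipe(**embeds, output_type="latent", width=768, height=512, num_frames=121).frames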

def compute_hash(text: str) -> str:
    # Stable SHA-256 hex digest of `text` (UTF-8 encoded).
    return hashlib.sha256(text.encode("utf-8")).hexdigest()
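
# Example (hypothetical cache layout): the hash gives a stable filename for
# caching generated videos per prompt.
#
#   key = compute_hash("a cat walking on grass")
#   video_path = f"outputs/{key}.mp4"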