import hashlib
import types
from argparse import Namespace

import torch
from diffusers import LTXPipeline
from q8_kernels.graph.graph import make_dynamic_graphed_callable

from inference import load_q8_transformer


# Patched copy of `LTXVideoTransformerBlock.forward` to account for the
# type-casting of `ff_output`: the feed-forward output is cast back to the
# activation dtype before the residual addition.
def patched_ltx_transformer_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb=None,
    encoder_attention_mask=None,
) -> torch.Tensor:
    batch_size = hidden_states.size(0)
    norm_hidden_states = self.norm1(hidden_states)

    # AdaLN-style modulation parameters derived from the timestep embedding.
    num_ada_params = self.scale_shift_table.shape[0]
    ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

    # Self-attention.
    attn_hidden_states = self.attn1(
        hidden_states=norm_hidden_states,
        encoder_hidden_states=None,
        image_rotary_emb=image_rotary_emb,
    )
    hidden_states = hidden_states + attn_hidden_states * gate_msa

    # Cross-attention over the text conditioning.
    attn_hidden_states = self.attn2(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        image_rotary_emb=None,
        attention_mask=encoder_attention_mask,
    )
    hidden_states = hidden_states + attn_hidden_states
    norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp

    # Feed-forward; cast back to the activation dtype (this is the patch).
    ff_output = self.ff(norm_hidden_states).to(norm_hidden_states.dtype)
    hidden_states = hidden_states + ff_output * gate_mlp
    return hidden_states


def load_transformer():
    args = Namespace()
    args.q8_transformer_path = "sayakpaul/q8-ltx-video"
    transformer = load_q8_transformer(args)

    # Cast the model to bfloat16, keep the transformer blocks in float32,
    # but leave each block's `scale_shift_table` in bfloat16.
    transformer.to(torch.bfloat16)
    for block in transformer.transformer_blocks:
        block.to(dtype=torch.float)
    for name, param in transformer.transformer_blocks.named_parameters():
        if "scale_shift_table" in name:
            param.data = param.data.to(torch.bfloat16)

    # Swap in the patched block forward and wrap `transformer.forward` so
    # repeated calls can reuse dynamically captured CUDA graphs.
    for block in transformer.transformer_blocks:
        block.forward = types.MethodType(patched_ltx_transformer_forward, block)
    transformer.forward = make_dynamic_graphed_callable(transformer.forward)
    return transformer


def warmup_transformer(pipe):
    # Run a few passes with precomputed prompt embeddings so that graph capture
    # and compilation happen before the first real request.
    prompt_embeds = torch.load("prompt_embeds.pt", map_location="cuda", weights_only=True)
    for _ in range(5):
        _ = pipe(
            **prompt_embeds,
            output_type="latent",
            width=768,
            height=512,
            num_frames=121,
        )


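# -----------------------------------------------------------------------------
# Not part of the original file: a rough sketch of how `prompt_embeds.pt` could
# be produced offline, since the serving pipeline below is loaded with
# `text_encoder=None`. The dict keys and the `encode_prompt` return order are
# assumptions that should be checked against the installed `diffusers` version.
# -----------------------------------------------------------------------------
def precompute_prompt_embeds(prompt: str, path: str = "prompt_embeds.pt"):
    # Load only the text-encoding components; the transformer and VAE are not
    # needed for this step.
    text_pipe = LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video", transformer=None, vae=None, torch_dtype=torch.bfloat16
    ).to("cuda")
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = text_pipe.encode_prompt(prompt=prompt, negative_prompt="")
    torch.save(
        {
            "prompt_embeds": prompt_embeds,
            "prompt_attention_mask": prompt_attention_mask,
            "negative_prompt_embeds": negative_prompt_embeds,
            "negative_prompt_attention_mask": negative_prompt_attention_mask,
        },
        path,
    )

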
def prepare_pipeline():
    # The text encoder is skipped on purpose: prompt embeddings are expected to
    # be precomputed (see `warmup_transformer` above).
    pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", text_encoder=None, torch_dtype=torch.bfloat16)
    pipe.transformer = load_transformer()
    pipe = pipe.to("cuda")
    pipe.transformer.compile()
    pipe.set_progress_bar_config(disable=True)
    warmup_transformer(pipe)
    return pipe


def compute_hash(text: str) -> str:
    # Encode the text to bytes.
    text_bytes = text.encode("utf-8")
    # Create a SHA-256 hash object and update it with the text bytes.
    hash_object = hashlib.sha256()
    hash_object.update(text_bytes)
    # Return the hexadecimal representation of the hash.
    return hash_object.hexdigest()
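

# -----------------------------------------------------------------------------
# Not part of the original file: a hedged usage sketch tying the helpers above
# together. The prompt string and the latent-caching scheme are illustrative
# assumptions, not the Space's actual serving logic.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    pipe = prepare_pipeline()

    prompt = "A clownfish swimming through a coral reef"  # hypothetical prompt
    prompt_inputs = torch.load("prompt_embeds.pt", map_location="cuda", weights_only=True)

    # With `output_type="latent"`, `LTXPipelineOutput.frames` holds the latents.
    result = pipe(
        **prompt_inputs,
        output_type="latent",
        width=768,
        height=512,
        num_frames=121,
    )

    # Cache the result under a content-addressed name derived from the prompt.
    torch.save(result.frames, f"{compute_hash(prompt)}.pt")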