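"""Build an LTX-Video pipeline that runs a Q8-quantized transformer.

Loads the quantized transformer weights from `sayakpaul/q8-ltx-video`, patches
each `LTXVideoTransformerBlock.forward` for the mixed-precision setup, wraps
the transformer forward in a dynamically graphed callable, and warms the
pipeline up with precomputed prompt embeddings. `compute_hash` derives a
stable SHA-256 key from a text string.
"""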
import hashlib
import types
from argparse import Namespace
from typing import Optional, Tuple

import torch
from diffusers import LTXPipeline
from q8_kernels.graph.graph import make_dynamic_graphed_callable

from inference import load_q8_transformer

# Patched forward for `LTXVideoTransformerBlock`. It mirrors the stock diffusers
# implementation, except that `ff_output` is cast back to the input dtype to
# account for the mixed-precision setup in `load_transformer` below.
def patched_ltx_transformer_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    batch_size = hidden_states.size(0)
    norm_hidden_states = self.norm1(hidden_states)

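    # AdaLN-style modulation: combine the learned scale/shift table with the
    # timestep embedding, then split into attention and feed-forward parameters.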
    num_ada_params = self.scale_shift_table.shape[0]
    ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

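    # Self-attention over the video tokens, with rotary position embeddings.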
    attn_hidden_states = self.attn1(
        hidden_states=norm_hidden_states,
        encoder_hidden_states=None,
        image_rotary_emb=image_rotary_emb,
    )
    hidden_states = hidden_states + attn_hidden_states * gate_msa

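    # Cross-attention against the text-encoder hidden states.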
    attn_hidden_states = self.attn2(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        image_rotary_emb=None,
        attention_mask=encoder_attention_mask,
    )
    hidden_states = hidden_states + attn_hidden_states
    norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp

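    # The one change vs. the stock diffusers forward: cast the feed-forward
    # output back to the dtype of its (modulated) input.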
    ff_output = self.ff(norm_hidden_states).to(norm_hidden_states.dtype)
    hidden_states = hidden_states + ff_output * gate_mlp

    return hidden_states

def load_transformer():
    args = Namespace()
    args.q8_transformer_path = "sayakpaul/q8-ltx-video"
    transformer = load_q8_transformer(args)

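    # Mixed precision: the model overall runs in bf16, but the transformer
    # blocks run in fp32.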
    transformer.to(torch.bfloat16)
    for b in transformer.transformer_blocks:
        b.to(dtype=torch.float)

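    # ...except the per-block AdaLN scale/shift tables, which go back to bf16.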
    for n, m in transformer.transformer_blocks.named_parameters():
        if "scale_shift_table" in n:
            m.data = m.data.to(torch.bfloat16)
    
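    # Swap in the patched forward on every block.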
    for b in transformer.transformer_blocks:
        b.forward = types.MethodType(patched_ltx_transformer_forward, b)
    
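    # Capture the forward pass as a dynamically graphed callable (CUDA graphs)
    # to reduce kernel-launch overhead.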
    transformer.forward = make_dynamic_graphed_callable(transformer.forward)
    return transformer 

def warmup_transformer(pipe):
    prompt_embeds = torch.load("prompt_embeds.pt", map_location="cuda", weights_only=True)
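    # A few full passes so compilation and graph capture happen before real use.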
    for _ in range(5):
        _ = pipe(
            **prompt_embeds,
            output_type="latent",  
            width=768,
            height=512,
            num_frames=121
        )

def prepare_pipeline():
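    # text_encoder=None: prompts are supplied as precomputed embeddings instead.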
    pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", text_encoder=None, torch_dtype=torch.bfloat16)
    pipe.transformer = load_transformer()
    pipe = pipe.to("cuda")
    pipe.transformer.compile()
    pipe.set_progress_bar_config(disable=True)

    warmup_transformer(pipe)
    return pipe


def compute_hash(text: str) -> str:
    # SHA-256 hex digest of the UTF-8 encoded text.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()
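

# Hypothetical usage sketch (not part of the original script): build the
# pipeline once, then use `compute_hash` to key prompts, e.g. for caching
# embeddings or outputs. Assumes `prompt_embeds.pt` exists, as in
# `warmup_transformer` above.
if __name__ == "__main__":
    pipe = prepare_pipeline()

    prompt = "A serene lake at sunrise"
    print(f"Prompt cache key: {compute_hash(prompt)}")

    prompt_embeds = torch.load("prompt_embeds.pt", map_location="cuda", weights_only=True)
    out = pipe(
        **prompt_embeds,
        output_type="latent",
        width=768,
        height=512,
        num_frames=121,
    )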