akshaytrikha commited on
Commit
6d27b42
·
verified ·
1 Parent(s): 9ad6efd

Revert app.py to non ZeroGPU

Browse files
Files changed (1) hide show
  1. app.py +15 -26
app.py CHANGED
@@ -2,33 +2,23 @@ import gradio as gr
2
  from transformers import GPT2LMHeadModel, AutoTokenizer, pipeline
3
  import torch
4
 
5
- def create_model():
6
- # Initialize model with device-agnostic loading
7
- model = GPT2LMHeadModel.from_pretrained(
8
- "./model/gpt2-355M",
9
- device_map="auto", # Enables automatic device mapping
10
- torch_dtype=torch.float16 # Use float16 for better memory efficiency
11
- )
12
-
13
- tokenizer = AutoTokenizer.from_pretrained(
14
- "gpt2",
15
- pad_token='<|endoftext|>'
16
- )
17
-
18
- return pipeline(
19
- "text-generation",
20
- model=model,
21
- tokenizer=tokenizer,
22
- device_map="auto",
23
- config={"max_length": 140}
24
- )
25
-
26
- # Initialize the pipeline
27
- trump = create_model()
28
 
29
  def generate(text):
30
  result = trump(text, num_return_sequences=1)
31
- return result[0]["generated_text"].replace('"', '')
32
 
33
  examples = [
34
  ["Why does the lying news media"],
@@ -43,5 +33,4 @@ demo = gr.Interface(
43
  examples=examples
44
  )
45
 
46
- if __name__ == "__main__":
47
- demo.launch()
 
2
  from transformers import GPT2LMHeadModel, AutoTokenizer, pipeline
3
  import torch
4
 
5
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
6
+
7
+ # load pretrained + finetuned GPT2
8
+ model = GPT2LMHeadModel.from_pretrained("./model/gpt2-355M")
9
+ model = model.to(device)
10
+
11
+ # create tokenizer
12
+ tokenizer = AutoTokenizer.from_pretrained(
13
+ "gpt2",
14
+ pad_token='<|endoftext|>'
15
+ )
16
+
17
+ trump = pipeline("text-generation", model=model, tokenizer=tokenizer, config={"max_length":140})
 
 
 
 
 
 
 
 
 
 
18
 
19
  def generate(text):
20
  result = trump(text, num_return_sequences=1)
21
+ return result[0]["generated_text"].replace('"', '') # remove quotation marks
22
 
23
  examples = [
24
  ["Why does the lying news media"],
 
33
  examples=examples
34
  )
35
 
36
+ demo.launch()