Fred808 committed
Commit c36fb16 · verified · 1 Parent(s): e0e5738

Update app.py

Files changed (1):
  1. app.py (+5, -21)
app.py CHANGED
@@ -11,26 +11,10 @@ app = FastAPI()
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Load the Falcon-7B model with 8-bit quantization (if CUDA is available)
-model_id = "tiiuae/falcon-7b-instruct"
+# Load the GPT-2 model and tokenizer
+model_id = "gpt2"  # Use GPT-2
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-# Check if CUDA is available
-if torch.cuda.is_available():
-    # Load the model with 8-bit quantization for GPU
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        revision="main",  # Pin to a specific revision
-        load_in_8bit=True,
-        device_map="auto"
-    )
-else:
-    # Fallback to CPU or full precision
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        revision="main",  # Pin to a specific revision
-        device_map="auto"
-    )
+model = AutoModelForCausalLM.from_pretrained(model_id)
 
 # Create a text generation pipeline
 pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
@@ -38,8 +22,8 @@ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 # Define request body schema
 class TextGenerationRequest(BaseModel):
     prompt: str
-    max_new_tokens: int = 50
-    temperature: float = 0.7
+    max_new_tokens: int = 50  # Reduce this for faster responses
+    temperature: float = 0.7  # Lower for more deterministic outputs
     top_k: int = 50
     top_p: float = 0.9
     do_sample: bool = True
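
For context, here is a minimal sketch of what app.py plausibly looks like after this commit. Only the model setup, the pipeline, and the TextGenerationRequest schema come from the diff; the imports, the /generate route, the handler name, and the response shape are assumptions, since the endpoint itself sits outside both hunks.

# Sketch of app.py after this commit. The endpoint below is an assumption:
# the route path, handler name, and response shape are not in the diff.
import logging

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the GPT-2 model and tokenizer (from the diff)
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Create a text generation pipeline (from the diff)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Request body schema (from the diff)
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.9
    do_sample: bool = True

@app.post("/generate")  # assumed route; not shown in the diff
def generate(request: TextGenerationRequest):
    # Forward the sampling parameters to the pipeline; transformers passes
    # unrecognized keyword arguments through to model.generate().
    outputs = pipe(
        request.prompt,
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        do_sample=request.do_sample,
    )
    return {"generated_text": outputs[0]["generated_text"]}

Assuming the app is served with uvicorn (uvicorn app:app), such an endpoint could be exercised with, for example:
curl -X POST localhost:8000/generate -H "Content-Type: application/json" -d '{"prompt": "Hello"}'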