acecalisto3 commited on
Commit
2637490
·
verified ·
1 Parent(s): af7f748

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py CHANGED
@@ -33,6 +33,61 @@ if 'current_state' not in st.session_state:
33
  'workspace_chat': {}
34
  }
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  class AIAgent:
37
  def __init__(self, name, description, skills):
38
  self.name = name
 
33
  'workspace_chat': {}
34
  }
35
 
36
+ class InstructModel:
37
+ def __init__(self):
38
+ """Initialize the Mixtral-8x7B-Instruct model"""
39
+ try:
40
+ self.model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
41
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
42
+ self.model = AutoModelForCausalLM.from_pretrained(
43
+ self.model_name,
44
+ torch_dtype=torch.float16,
45
+ device_map="auto"
46
+ )
47
+ except Exception as e:
48
+ raise EnvironmentError(f"Failed to load model: {str(e)}")
49
+
50
+ def generate_response(self, prompt: str) -> str:
51
+ """Generate a response using the Mixtral model"""
52
+ try:
53
+ # Format the prompt according to Mixtral's expected format
54
+ formatted_prompt = f"<s>[INST] {prompt} [/INST]"
55
+
56
+ # Tokenize input
57
+ inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
58
+
59
+ # Generate response
60
+ outputs = self.model.generate(
61
+ inputs.input_ids,
62
+ max_new_tokens=512,
63
+ temperature=0.7,
64
+ top_p=0.95,
65
+ do_sample=True,
66
+ pad_token_id=self.tokenizer.eos_token_id
67
+ )
68
+
69
+ # Decode and clean up response
70
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
71
+
72
+ # Remove the prompt from the response
73
+ response = response.replace(formatted_prompt, "").strip()
74
+
75
+ return response
76
+
77
+ except Exception as e:
78
+ raise Exception(f"Error generating response: {str(e)}")
79
+
80
+ def __del__(self):
81
+ """Cleanup when the model is no longer needed"""
82
+ try:
83
+ del self.model
84
+ del self.tokenizer
85
+ torch.cuda.empty_cache()
86
+ except:
87
+ pass
88
+
89
+
90
+
91
  class AIAgent:
92
  def __init__(self, name, description, skills):
93
  self.name = name