whackthejacker committed on
Commit
2917bca
·
verified ·
1 Parent(s): 961806a

Update codegen.py

Browse files
Files changed (1) hide show
  1. codegen.py +77 -40
codegen.py CHANGED
@@ -1,45 +1,82 @@
1
-
2
  import transformers
3
  from transformers import pipeline
4
 
5
def generate(idea):
    """
    Generates code based on a given idea using the bigscience/T0_3B model.

    Args:
        idea (str): The idea for the code to be generated.

    Returns:
        str: The generated code, stripped of surrounding whitespace. If the
            decoded text echoes the prompt, only the part following the
            "# Code:" marker is returned; otherwise the whole text is.
    """
    # T0_3B is a T5-style encoder-decoder checkpoint, so it has to be loaded
    # with the seq2seq auto class; AutoModelForCausalLM rejects its config.
    model_name = "bigscience/T0_3B"
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

    # Build the prompt and generate.
    input_text = f"\n# Idea: {idea}\n# Code:\n"
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=1024,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        # temperature/top_k are only honoured when sampling is enabled.
        # early_stopping is a beam-search option and is dropped because no
        # beams are used here.
        do_sample=True,
        temperature=0.7,  # Adjust temperature for creativity
        top_k=50,  # Adjust top_k for diversity
    )
    decoded = tokenizer.decode(output_sequences[0], skip_special_tokens=True)

    # An encoder-decoder model does not echo the prompt, so the marker is
    # usually absent; fall back to the full text instead of raising IndexError.
    parts = decoded.split("\n# Code:")
    if len(parts) > 1:
        return parts[1].strip()
    return decoded.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # Example usage
43
# Demo run executed at import time: describe the task, generate code for it,
# and print the result.
idea = "Write a Python function to calculate the factorial of a number"
code = generate(idea)
print(code)
 
 
 
 
1
  import transformers
2
  from transformers import pipeline
3
 
4
class CodeGenerator:
    """Generates code from a natural-language idea with a Hugging Face model.

    The model and tokenizer are loaded once in ``__init__`` and reused for
    every ``generate_code`` call.
    """

    # Marker that separates the echoed prompt from the generated code.
    _CODE_MARKER = "\n# Code:"

    def __init__(self, model_name="bigscience/T0_3B"):
        """
        Initializes the CodeGenerator with a specified model.

        Args:
            model_name (str): The name of the model to be used for code generation.
        """
        # The default checkpoint (T0_3B) is a T5-style encoder-decoder model,
        # which AutoModelForCausalLM cannot load. Pick the auto class from the
        # checkpoint's config so both seq2seq and causal checkpoints work.
        config = transformers.AutoConfig.from_pretrained(model_name)
        if getattr(config, "is_encoder_decoder", False):
            model_cls = transformers.AutoModelForSeq2SeqLM
        else:
            model_cls = transformers.AutoModelForCausalLM
        self.model = model_cls.from_pretrained(model_name)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

    def generate_code(self, idea):
        """
        Generates code based on a given idea using the specified model.

        Args:
            idea (str): The idea for the code to be generated.

        Returns:
            str: The generated code.
        """
        input_text = self._format_input(idea)
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
        output_sequences = self._generate_output(input_ids)
        return self._extract_code(output_sequences)

    def _format_input(self, idea):
        """
        Formats the input text for the model.

        Args:
            idea (str): The idea for the code to be generated.

        Returns:
            str: Formatted input text.
        """
        return f"# Idea: {idea}\n# Code:\n"

    def _generate_output(self, input_ids):
        """
        Generates output sequences from the model.

        Args:
            input_ids (tensor): The input IDs for the model.

        Returns:
            tensor: The generated output sequences.
        """
        return self.model.generate(
            input_ids=input_ids,
            max_length=1024,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            # temperature/top_k are only honoured when sampling is enabled.
            # early_stopping is a beam-search option and is dropped because
            # no beams are used here.
            do_sample=True,
            temperature=0.7,
            top_k=50,
        )

    def _extract_code(self, output_sequences):
        """
        Extracts the generated code from the output sequences.

        Args:
            output_sequences (tensor): The generated output sequences.

        Returns:
            str: The extracted code. When the decoded text contains the
                "# Code:" marker, this is the text following its first
                occurrence (up to the next occurrence, if any); otherwise it
                is the whole decoded text. Surrounding whitespace is stripped.
        """
        decoded = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
        parts = decoded.split(self._CODE_MARKER)
        # Encoder-decoder models do not echo the prompt, so the marker is
        # often missing; return the full text rather than raising IndexError.
        if len(parts) > 1:
            return parts[1].strip()
        return decoded.strip()
76
 
77
  # Example usage
78
def _demo():
    """Generate code for a sample idea with the default model and print it."""
    prompt = "Write a Python function to calculate the factorial of a number"
    print(CodeGenerator().generate_code(prompt))


# Only run the demo when the file is executed as a script, not on import.
if __name__ == "__main__":
    _demo()