adamo1139 committed
Commit 54ae8cb · verified · 1 Parent(s): a7be7fd

Upload yi-34b-dpo-rawrr-v2-2-hf.py

Files changed (1)
  1. yi-34b-dpo-rawrr-v2-2-hf.py +133 -0
yi-34b-dpo-rawrr-v2-2-hf.py ADDED
from unsloth import FastLanguageModel
from datasets import Dataset, load_dataset
from dataclasses import dataclass, field
from typing import Dict, Optional
import torch
max_seq_length = 4096 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "larryvrh/Yi-34B-200K-Llamafied", # Choose ANY! eg mistralai/Mistral-7B-Instruct-v0.2
    max_seq_length = max_seq_length,
    attn_implementation="flash_attention_2",
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)


#@title Alignment Handbook utils
import os
import re
from typing import List, Literal, Optional

from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError


#DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
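# With the ChatML template above, a single user turn rendered by
# tokenizer.apply_chat_template([{"role": "user", "content": "Hi"}], tokenize=False, add_generation_prompt=True)
# produces the string:
#   <|im_start|>user
#   Hi<|im_end|>
#   <|im_start|>assistant
# (followed by a trailing newline), which is the prompt shape used below.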
def chatml_format(example):
    # Format system
    if len(example['system']) > 0:
        message = {"role": "system", "content": example['system']}
        system = tokenizer.apply_chat_template([message], tokenize=False)
    else:
        system = ""

    # Format instruction
    message = {"role": "user", "content": example['prompt']}
    prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)

    # Format chosen answer
    chosen = example['chosen'] + "<|im_end|>\n"

    # Format rejected answer
    rejected = example['rejected'] + "<|im_end|>\n"

    return {
        "prompt": system + prompt,
        "chosen": chosen,
        "rejected": rejected,
    }

# Load dataset
dataset = load_dataset("adamo1139/rawrr_v2", split="train")

import pprint
pprint.pprint("""NOT a formatted dataset
""")
pprint.pprint(dataset[250])
pprint.pprint(dataset[260])
pprint.pprint(dataset[270])
pprint.pprint(dataset[280])
pprint.pprint(dataset[290])
# Save columns
original_columns = dataset.column_names

# Format dataset
dataset = dataset.map(
    chatml_format,
    remove_columns=original_columns
)

# Print sample
pprint.pprint("""formatted dataset""")
pprint.pprint(dataset[250])
pprint.pprint(dataset[260])
pprint.pprint(dataset[270])
pprint.pprint(dataset[280])
pprint.pprint(dataset[290])

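# After the map() above, every example carries three plain-text fields that DPOTrainer
# tokenizes itself:
#   prompt:   optional "<|im_start|>system\n...<|im_end|>\n" block, then
#             "<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n"
#   chosen:   the preferred completion, terminated with "<|im_end|>\n"
#   rejected: the dispreferred completion, terminated with "<|im_end|>\n"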
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 32,
    lora_dropout = 0, # Currently only supports dropout = 0
    bias = "none", # Currently only supports bias = "none"
    use_gradient_checkpointing = True,
    random_state = 3407,
    use_rslora = False, # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
from trl import DPOTrainer

dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,
    args = TrainingArguments(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 16,
        warmup_ratio = 0.05,
        num_train_epochs = 1,
        learning_rate = 0.000045,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.0,
        lr_scheduler_type = "linear",
        seed = 42,
        save_strategy = "steps",
        save_steps = 100,
        save_total_limit = 10,
        output_dir = "rawrr_v2_run2",
    ),
    beta = 0.1,
    train_dataset = dataset,
    # eval_dataset = raw_datasets["test"],
    tokenizer = tokenizer,
    max_length = 700,
    max_prompt_length = 400,
)
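# Effective batch size: per_device_train_batch_size (1) x gradient_accumulation_steps (16)
# = 16 preference pairs per optimizer step on a single GPU.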
dpo_trainer.train()
model.save_pretrained("yi-34b-200k_rawrr_v2_run2") # Local saving
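
The final save_pretrained call on the PEFT-wrapped model writes only the LoRA adapter weights to yi-34b-200k_rawrr_v2_run2, not merged full-model weights. A minimal sketch of loading that adapter back for generation with plain transformers + peft follows; the prompt text and generation settings are illustrative and not part of the committed script.

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# Load the same base model used for training, then attach the saved adapter.
base = AutoModelForCausalLM.from_pretrained(
    "larryvrh/Yi-34B-200K-Llamafied",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, "yi-34b-200k_rawrr_v2_run2")  # adapter dir saved above
tokenizer = AutoTokenizer.from_pretrained("larryvrh/Yi-34B-200K-Llamafied")

# Prompt in the same ChatML layout the adapter was trained on.
prompt = "<|im_start|>user\nWhat is DPO?<|im_end|>\n<|im_start|>assistant\n"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(out[0], skip_special_tokens=False))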