yadonglu commited on
Commit
d5b3c6d
·
1 Parent(s): c1dc596

fix forward

Browse files
Files changed (1) hide show
  1. util/utils.py +4 -4
util/utils.py CHANGED
@@ -112,10 +112,10 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
112
  inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
113
  else:
114
  inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
115
- if 'florence' in model.config.name_or_path:
116
- generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
117
- else:
118
- generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
119
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
120
  generated_text = [gen.strip() for gen in generated_text]
121
  generated_texts.extend(generated_text)
 
112
  inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
113
  else:
114
  inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
115
+ # if 'florence' in model.config.name_or_path:
116
+ generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
117
+ # else:
118
+ # generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
119
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
120
  generated_text = [gen.strip() for gen in generated_text]
121
  generated_texts.extend(generated_text)