Commit: Update modeling_ovis.py
Changed file: modeling_ovis.py (+4 −3)
@@ -289,9 +289,10 @@ class Ovis(OvisPreTrainedModel):
         attn_kwargs = dict()
         if self.config.llm_attn_implementation:
             if self.config.llm_attn_implementation == "flash_attention_2":
-                assert (is_flash_attn_2_available() and
-                        version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.6.3")), \
-                    "Using `flash_attention_2` requires having `flash_attn>=2.6.3` installed."
+                assert is_flash_attn_2_available()  # kyujin modified
+                #assert (is_flash_attn_2_available() and
+                #        version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.6.3")), \
+                #    "Using `flash_attention_2` requires having `flash_attn>=2.6.3` installed."
             attn_kwargs["attn_implementation"] = self.config.llm_attn_implementation
         self.llm = AutoModelForCausalLM.from_config(self.config.llm_config, **attn_kwargs)
         assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"