Spaces:
Runtime error
Runtime error
Update qasrl_model_pipeline.py
Browse filesadd update_config to fix max_length problem and allow customization in __init__
- qasrl_model_pipeline.py +9 -1
qasrl_model_pipeline.py
CHANGED
@@ -61,7 +61,15 @@ class QASRL_Pipeline(Text2TextGenerationPipeline):
|
|
61 |
self.data_args.use_bilateral_predicate_marker = True
|
62 |
if "append_verb_form" not in vars(self.data_args):
|
63 |
self.data_args.append_verb_form = True
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def _sanitize_parameters(self, **kwargs):
|
67 |
preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {}
|
|
|
61 |
self.data_args.use_bilateral_predicate_marker = True
|
62 |
if "append_verb_form" not in vars(self.data_args):
|
63 |
self.data_args.append_verb_form = True
|
64 |
+
self._update_config(**kwargs)
|
65 |
+
|
66 |
+
def _update_config(self, **kwargs):
|
67 |
+
" Update self.model.config with initialization parameters and necessary defaults. "
|
68 |
+
# set default values that will always override model.config, but can overriden by __init__ kwargs
|
69 |
+
kwargs["max_length"] = kwargs.get("max_length", 80)
|
70 |
+
# override model.config with kwargs
|
71 |
+
for k,v in kwargs.items():
|
72 |
+
self.model.config.__dict__[k] = v
|
73 |
|
74 |
def _sanitize_parameters(self, **kwargs):
|
75 |
preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {}
|