fix source prefix
Browse files- pipeline.py +2 -2
pipeline.py
CHANGED
@@ -125,9 +125,9 @@ class QASRL_Pipeline(Text2TextGenerationPipeline):
|
|
125 |
def _get_source_prefix(self, predicate_type: Optional[str]):
|
126 |
if not self.is_t5_model or self.data_args.source_prefix is None:
|
127 |
return ''
|
128 |
-
if "
|
129 |
if predicate_type is None:
|
130 |
-
raise ValueError("source_prefix includes '
|
131 |
if self.data_args.source_prefix == "Generate QAs for <predicate_type> QASRL: ": # backwrad compatibility - "Generate QAs for <predicate_type> QASRL: " alone was a sign for a longer prefix
|
132 |
return f"Generate QAs for {predicate_type} QASRL: "
|
133 |
else:
|
|
|
125 |
def _get_source_prefix(self, predicate_type: Optional[str]):
|
126 |
if not self.is_t5_model or self.data_args.source_prefix is None:
|
127 |
return ''
|
128 |
+
if "<predicate_type>" in self.data_args.source_prefix:
|
129 |
if predicate_type is None:
|
130 |
+
raise ValueError("source_prefix includes '<predicate_type>' but input has no `predicate_type`.")
|
131 |
if self.data_args.source_prefix == "Generate QAs for <predicate_type> QASRL: ": # backwrad compatibility - "Generate QAs for <predicate_type> QASRL: " alone was a sign for a longer prefix
|
132 |
return f"Generate QAs for {predicate_type} QASRL: "
|
133 |
else:
|