diff --git a/inference.py b/inference.py index d3ff634..5313b6f 100644 --- a/inference.py +++ b/inference.py @@ -131,10 +131,8 @@ def compute_f1(self): tq = tqdm(self.test_loader) for batch_num, batch in enumerate(tq): tokenized_batch = self.tokenize(batch) - ids = tokenized_batch['input_ids'].to(self.device) - mask = tokenized_batch['attention_mask'].to(self.device) y = batch['label'] - y_pred = self.predict({'input_ids': ids, 'attention_mask': mask}) + y_pred = self.predict(tokenized_batch) preds.extend(list(y_pred)) true.extend(list(y.numpy())) tq.set_description( @@ -144,6 +142,7 @@ def compute_f1(self): return f1 def predict(self, input): + input = {k:v.to(self.device) for k,v in input.items()} logits = self.model(**input).logits return torch.argmax(logits, axis=1).cpu() @@ -194,6 +193,8 @@ def load_model(self): return compiled_model, tokenizer def predict(self, input): + input.pop('token_type_ids') + input = {k: v.to(self.device) for k, v in input.items()} logits = self.model(input)[self.output_layer] return np.argmax(logits, axis=1) @@ -204,19 +205,23 @@ def measure_size_mb(self): def measure_num_params(self): return 0 - FORMATS = {'pt': Evaluator, 'openvino': OpenVinoEvaluator} if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config_path", type=str, required=True) parser.add_argument("--save_report", type=bool, default=False) - parser.add_argument("--format", type=str, default='pt') + parser.add_argument("--format", type=str, default=None) args = parser.parse_args() config_path = args.config_path save_report = args.save_report format = args.format + if not format: + with open(config_path) as config_js: + config = json.load(config_js) + format = config.get('convert_to_format', 'pt') + assert format in FORMATS, f'Unknown format {format}.' \ f'The model can only be loaded from the following formats: {",".join(FORMATS.keys())}' EvaluatorClass = FORMATS[format]