import logging
import os
import onnx
import tensorrt as trt
from typing import List
from collections import OrderedDict
from onnx import shape_inference

logger = logging.getLogger(__name__)


def _build_serialized_engine(builder, network, config):
    """Build the engine and return its serialized form, or None on failure."""
    # build_serialized_network is the supported API since TensorRT 8;
    # build_engine was deprecated there and removed in TensorRT 10.
    if hasattr(builder, "build_serialized_network"):
        return builder.build_serialized_network(network, config)
    engine = builder.build_engine(network, config)
    return None if engine is None else engine.serialize()


def vit_tagging_t2t(
    input_path="simple_model.onnx",
    output_path="vit.trt",
    input_name="input",
    input_shape=(3, 224, 224),
    min_batch=1,
    opt_batch=70,
    max_batch=70,
):
    """Convert an ONNX model into an FP16-enabled TensorRT engine file.

    The ONNX file is parsed into an explicit-batch network. A single
    optimization profile is registered for ``input_name``. Its batch
    dimension ranges over ``min_batch``..``max_batch`` (tuned for
    ``opt_batch``), and the remaining dimensions are fixed to
    ``input_shape``. The first network input and output are kept in FP32.

    Args:
        input_path: Path of the ONNX model to convert.
        output_path: Destination path for the serialized engine.
        input_name: Name of the network input the profile applies to.
        input_shape: Per-sample shape of that input, without the batch dim.
        min_batch: Smallest batch size the engine must accept.
        opt_batch: Batch size TensorRT optimizes kernels for.
        max_batch: Largest batch size the engine must accept.

    Raises:
        ValueError: If the batch sizes do not satisfy
            ``1 <= min_batch <= opt_batch <= max_batch``.
        RuntimeError: If the ONNX file cannot be parsed or the engine
            build fails.
    """
    if not 1 <= min_batch <= opt_batch <= max_batch:
        raise ValueError(
            "Expected 1 <= min_batch <= opt_batch <= max_batch, got "
            f"{min_batch}, {opt_batch}, {max_batch}"
        )
    sample_shape = tuple(input_shape)

    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    trt_logger = trt.Logger()
    with trt.Builder(trt_logger) as builder, \
            builder.create_network(explicit_batch) as network, \
            builder.create_builder_config() as config, \
            trt.OnnxParser(network, trt_logger) as parser:
        config.set_flag(trt.BuilderFlag.FP16)

        with open(input_path, "rb") as f:
            if not parser.parse(f.read()):
                errors = [
                    str(parser.get_error(idx)) for idx in range(parser.num_errors)
                ]
                for err in errors:
                    logger.error(err)
                # Keep the parser diagnostics in the exception so they are not lost.
                raise RuntimeError(
                    "Failed to parse the ONNX file.\n" + "\n".join(errors)
                )

        profile = builder.create_optimization_profile()
        profile.set_shape(
            input_name,
            min=(min_batch, *sample_shape),
            opt=(opt_batch, *sample_shape),
            max=(max_batch, *sample_shape),
        )
        config.add_optimization_profile(profile)

        # NOTE: an earlier, disabled experiment forced ReduceMean/Pow layers to
        # FP32 (with STRICT_TYPES) on top of the FP16 flag; it is not applied here.
        network.get_input(0).dtype = trt.float32
        network.get_output(0).dtype = trt.float32

        serialized = _build_serialized_engine(builder, network, config)
        if serialized is None:
            raise RuntimeError("Failed to build the TensorRT engine.")

        with open(output_path, "wb") as f:
            f.write(serialized)
    logger.info("Saved TensorRT engine to %s", output_path)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    vit_tagging_t2t()