Open zhangjun opened 2 years ago
def build_engine(weights):
    """Build a TensorRT engine whose network is a single ``conv3x3`` layer.

    The network declares one input (``ModelData.INPUT_NAME``), passes it through
    ``conv3x3`` together with the ``conv1.weight`` tensor taken from ``weights``,
    and marks the first output of that layer as the network output.

    An optimization profile describing the allowed input shapes is attached to a
    builder config, and the engine is built *with that config* so the profile
    and the workspace limit actually take effect.

    Args:
        weights: Mapping indexed by parameter name. ``weights['conv1.weight']``
            must support ``.cpu().numpy()``.

    Returns:
        Whatever ``builder.build_engine(network, config)`` returns.
        NOTE(review): TensorRT is documented to return ``None`` on build
        failure - callers should check for that.
    """
    # For more information on TRT basics, refer to the introductory samples.
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(common.EXPLICIT_BATCH) as network:
        # Build-time settings live on the config object; it is the config that
        # is handed to the builder below.
        config = builder.create_builder_config()
        config.max_workspace_size = common.GiB(1)

        # Populate the network using weights from the PyTorch model.
        # populate_network(network, weights)
        input_tensor = network.add_input(
            name=ModelData.INPUT_NAME, dtype=ModelData.DTYPE, shape=ModelData.INPUT_SHAPE)
        conv1_w = weights['conv1.weight'].cpu().numpy()
        out_planes = 64
        conv1 = conv3x3(input_tensor, out_planes, network, conv1_w)
        network.mark_output(tensor=conv1.get_output(0))

        # (min, opt, max) shapes for the network input. The profile must be
        # keyed by the same name the input was registered under.
        profile = builder.create_optimization_profile()
        profile.set_shape(ModelData.INPUT_NAME, (1, 64, 1, 1), (1, 64, 150, 250), (1, 64, 200, 300))
        config.add_optimization_profile(profile)

        # Build and return an engine. Building through the config is required:
        # build_cuda_engine(network) would silently ignore the profile above.
        engine = builder.build_engine(network, config)
        return engine
IPluginV2
//!
//! Abstract plugin interface.
//!
//! Every member function except getTensorRTVersion() is declared pure virtual,
//! so a concrete plugin has to override all of them. The declaration order of
//! the virtual functions is kept exactly as-is.
//!
class IPluginV2
{
public:
    //! NOTE(review): declared here without a body and not pure virtual, so a
    //! definition has to be provided elsewhere.
    virtual int32_t getTensorRTVersion() const noexcept;

    //! Plugin type and version, returned as character strings.
    virtual AsciiChar const* getPluginType() const noexcept = 0;
    virtual AsciiChar const* getPluginVersion() const noexcept = 0;

    //! Output count, and the dimensions of output `index` given `nbInputDims` input dimensions.
    virtual int32_t getNbOutputs() const noexcept = 0;
    virtual Dims getOutputDimensions(int32_t index, Dims const* inputs, int32_t nbInputDims) noexcept = 0;

    //! Format query / configuration for a (type, format) pair.
    virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept = 0;
    virtual void configureWithFormat(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims,
        int32_t nbOutputs, DataType type, PluginFormat format, int32_t maxBatchSize) noexcept = 0;

    //! Life-cycle hooks; initialize() reports an int32_t status.
    virtual int32_t initialize() noexcept = 0;
    virtual void terminate() noexcept = 0;

    //! Workspace size in bytes (size_t) for the given maximum batch size.
    virtual size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept = 0;

    //! Execution entry point; receives input/output pointer arrays, a workspace
    //! pointer and the CUDA stream, and reports an int32_t status.
    virtual int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace,
        cudaStream_t stream) noexcept = 0;

    //! Serialization: byte count, then the write into a caller-provided buffer.
    virtual size_t getSerializationSize() const noexcept = 0;
    virtual void serialize(void* buffer) const noexcept = 0;

    //! Object management.
    virtual void destroy() noexcept = 0;
    virtual IPluginV2* clone() const noexcept = 0;

    //! Namespace accessors.
    virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
    virtual AsciiChar const* getPluginNamespace() const noexcept = 0;

    IPluginV2() = default;
    virtual ~IPluginV2() noexcept = default;

protected:
    // @cond SuppressDoxyWarnings
    // Copy and move are defaulted but protected: only derived classes may use
    // them. The assignment operators are lvalue-ref-qualified.
    IPluginV2(IPluginV2 const&) = default;
    IPluginV2(IPluginV2&&) = default;
    IPluginV2& operator=(IPluginV2 const&) & = default;
    IPluginV2& operator=(IPluginV2&&) & = default;
    // @endcond
};
IPluginV2Ext:在 IPluginV2 的基础上,支持不同的输出数据类型以及跨 batch 广播(broadcast across batch)
//!
//! Extension of IPluginV2 that adds per-output data types, broadcast-across-batch
//! queries, a richer configurePlugin() and optional context attach/detach hooks.
//!
class IPluginV2Ext : public IPluginV2
{
public:
    //! Data type of output `index`, given the `nbInputs` input types.
    virtual nvinfer1::DataType getOutputDataType(
        int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept = 0;

    //! Broadcast-across-batch queries for a given output / input index.
    virtual bool isOutputBroadcastAcrossBatch(
        int32_t outputIndex, bool const* inputIsBroadcasted, int32_t nbInputs) const noexcept = 0;
    virtual bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept = 0;

    //! Configuration call carrying dimensions, data types and broadcast flags
    //! for all inputs and outputs, plus the float format and maximum batch size.
    virtual void configurePlugin(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
        DataType const* inputTypes, DataType const* outputTypes, bool const* inputIsBroadcast,
        bool const* outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept = 0;

    IPluginV2Ext() = default;
    ~IPluginV2Ext() override = default;

    //! Optional hook; the default implementation ignores all three arguments and does nothing.
    virtual void attachToContext(
        cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, IGpuAllocator* /*allocator*/) noexcept
    {
    }

    //! Optional hook; the default implementation does nothing.
    virtual void detachFromContext() noexcept {}

    //! Narrows the return type of IPluginV2::clone() to IPluginV2Ext*; still pure virtual.
    IPluginV2Ext* clone() const noexcept override = 0;

protected:
    // @cond SuppressDoxyWarnings
    // Defaulted, protected copy/move; assignment is lvalue-ref-qualified.
    IPluginV2Ext(IPluginV2Ext const&) = default;
    IPluginV2Ext(IPluginV2Ext&&) = default;
    IPluginV2Ext& operator=(IPluginV2Ext const&) & = default;
    IPluginV2Ext& operator=(IPluginV2Ext&&) & = default;
    // @endcond
};
//!
//! Extension of IPluginV2Ext whose configuration and format queries are
//! expressed through PluginTensorDesc arrays.
//!
class IPluginV2IOExt : public IPluginV2Ext
{
public:
    //! Configuration call taking `nbInput` input descriptors and `nbOutput` output descriptors.
    virtual void configurePlugin(
        PluginTensorDesc const* in, int32_t nbInput, PluginTensorDesc const* out, int32_t nbOutput) noexcept = 0;

    //! Format-combination query for position `pos` within `inOut`.
    //! NOTE(review): presumably `inOut` holds nbInputs + nbOutputs descriptors - confirm against the TensorRT docs.
    virtual bool supportsFormatCombination(
        int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept = 0;

    IPluginV2IOExt() = default;
    ~IPluginV2IOExt() override = default;

protected:
    // Defaulted, protected copy/move; assignment is lvalue-ref-qualified.
    IPluginV2IOExt(IPluginV2IOExt const&) = default;
    IPluginV2IOExt(IPluginV2IOExt&&) = default;
    IPluginV2IOExt& operator=(IPluginV2IOExt const&) & = default;
    IPluginV2IOExt& operator=(IPluginV2IOExt&&) & = default;
};
//!
//! Abstract factory interface for plugins: it can create a plugin from a
//! PluginFieldCollection or deserialize one from a byte buffer. All member
//! functions are pure virtual.
//!
class IPluginCreator
{
public:
    //! Name and version handled by this creator, as character strings.
    virtual AsciiChar const* getPluginName() const noexcept = 0;
    virtual AsciiChar const* getPluginVersion() const noexcept = 0;

    //! Collection of fields accepted by createPlugin().
    virtual PluginFieldCollection const* getFieldNames() noexcept = 0;

    //! Creates a plugin named `name` from the field collection `fc`.
    virtual IPluginV2* createPlugin(AsciiChar const* name, PluginFieldCollection const* fc) noexcept = 0;

    //! Creates a plugin named `name` from `serialLength` bytes at `serialData`.
    virtual IPluginV2* deserializePlugin(
        AsciiChar const* name, void const* serialData, size_t serialLength) noexcept = 0;

    //! Namespace accessors.
    virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
    virtual AsciiChar const* getPluginNamespace() const noexcept = 0;

    IPluginCreator() = default;
    virtual ~IPluginCreator() = default;

protected:
    // @cond SuppressDoxyWarnings
    // Each declaration sits on its own line so the line comment above cannot
    // swallow the copy/move members or the closing brace of the class.
    IPluginCreator(IPluginCreator const&) = default;
    IPluginCreator(IPluginCreator&&) = default;
    IPluginCreator& operator=(IPluginCreator const&) & = default;
    IPluginCreator& operator=(IPluginCreator&&) & = default;
    // @endcond
};
* IPluginRegistry
//!
//! Registry of IPluginCreator objects. The interface is non-copyable and
//! non-movable (copy/move operations are deleted), and its destructor is
//! protected, so it cannot be destroyed through this interface by outside code.
//!
class IPluginRegistry
{
public:
    //! Registers `creator` under `pluginNamespace`; the bool result reports the outcome.
    virtual bool registerCreator(IPluginCreator& creator, AsciiChar const* const pluginNamespace) noexcept = 0;

    //! Returns an array of creator pointers; the element count is written to `*numCreators`.
    virtual IPluginCreator* const* getPluginCreatorList(int32_t* const numCreators) const noexcept = 0;

    //! Looks up a creator by name and version; `pluginNamespace` defaults to the empty string.
    virtual IPluginCreator* getPluginCreator(AsciiChar const* const pluginName,
        AsciiChar const* const pluginVersion, AsciiChar const* const pluginNamespace = "") noexcept = 0;

    IPluginRegistry() = default;
    IPluginRegistry(IPluginRegistry const&) = delete;
    IPluginRegistry(IPluginRegistry&&) = delete;
    IPluginRegistry& operator=(IPluginRegistry const&) & = delete;
    IPluginRegistry& operator=(IPluginRegistry&&) & = delete;

protected:
    virtual ~IPluginRegistry() noexcept = default;

public:
    //! Removes a previously registered creator. It is declared after the
    //! destructor, as in the original; the position is kept so the order of
    //! virtual functions does not change.
    virtual bool deregisterCreator(IPluginCreator const& creator) noexcept = 0;
};
*
*
## common
- IGpuAllocator
//!
//! Allocator interface. allocate() and free() are pure virtual; reallocate()
//! and deallocate() come with default implementations.
//!
class IGpuAllocator
{
public:
    //! Allocates `size` bytes with the requested `alignment` and `flags`, returning the block's address.
    virtual void* allocate(uint64_t const size, uint64_t const alignment, AllocatorFlags const flags) noexcept = 0;

    //! Releases `memory`. Marked deprecated; deallocate() below forwards to it by default.
    TRT_DEPRECATED virtual void free(void* const memory) noexcept = 0;

    virtual ~IGpuAllocator() = default;
    IGpuAllocator() = default;

    //! Default implementation ignores its arguments and returns nullptr.
    virtual void* reallocate(void* /*baseAddr*/, uint64_t /*alignment*/, uint64_t /*newSize*/) noexcept
    {
        return nullptr;
    }

    //! Default implementation calls free(memory) and always returns true.
    virtual bool deallocate(void* const memory) noexcept
    {
        this->free(memory);
        return true;
    }

protected:
    // Defaulted, protected copy/move; assignment is lvalue-ref-qualified.
    IGpuAllocator(IGpuAllocator const&) = default;
    IGpuAllocator(IGpuAllocator&&) = default;
    IGpuAllocator& operator=(IGpuAllocator const&) & = default;
    IGpuAllocator& operator=(IGpuAllocator&&) & = default;
};
- ILogger
//!
//! Logging interface: implementers override log() to receive messages tagged
//! with a Severity.
//!
class ILogger
{
public:
    //! Message severity. The numeric values are fixed explicitly, 0 through 4.
    enum class Severity : int32_t
    {
        kINTERNAL_ERROR = 0,
        kERROR = 1,
        kWARNING = 2,
        kINFO = 3,
        kVERBOSE = 4,
    };

    //! Receives one message `msg` together with its `severity`.
    virtual void log(Severity severity, AsciiChar const* msg) noexcept = 0;

    ILogger() = default;
    virtual ~ILogger() = default;

protected:
    // Defaulted, protected copy/move; assignment is lvalue-ref-qualified.
    ILogger(ILogger const&) = default;
    ILogger(ILogger&&) = default;
    ILogger& operator=(ILogger const&) & = default;
    ILogger& operator=(ILogger&&) & = default;
};
## 基础数据结构
- DataType
//! Element type tag. The underlying type is int32_t and every enumerator has
//! an explicitly assigned value.
enum class DataType : int32_t
{
    kFLOAT = 0,
    kHALF = 1,
    kINT8 = 2,
    kINT32 = 3,
    kBOOL = 4
};
- Dims
//! Fixed-capacity description of a tensor's dimensions.
//! Each member sits on its own line so the `//!` documentation comments cannot
//! comment out the declarations that follow them.
class Dims32
{
public:
    //! The maximum rank (number of dimensions) supported for a tensor.
    static constexpr int32_t MAX_DIMS{8};
    //! The rank (number of dimensions).
    int32_t nbDims;
    //! The extent of each dimension.
    int32_t d[MAX_DIMS];
};

using Dims = Dims32;
- TensorFormat
//! Tensor memory-format tag. The underlying type is int32_t and every
//! enumerator has an explicitly assigned value.
//! The enumerators are placed on their own lines so the `//!` documentation
//! for kLINEAR cannot comment them out.
enum class TensorFormat : int32_t
{
    //! Row major linear format.
    //! For a tensor with dimensions {N, C, H, W} or {numbers, channels,
    //! columns, rows}, the dimensional index corresponds to {3, 2, 1, 0}
    //! and thus the order is W minor.
    //!
    //! For DLA usage, the tensor sizes are limited to C,H,W in the range [1,8192].
    //!
    kLINEAR = 0,
    kCHW2 = 1,
    kHWC8 = 2,
    kCHW4 = 3,
    kCHW16 = 4,
    kCHW32 = 5,
    kDHWC8 = 6,
    kCDHW32 = 7,
    kHWC = 8,
    kDLA_LINEAR = 9,
    kDLA_HWC4 = 10,
    kHWC16 = 11
};

using PluginFormat = TensorFormat;
- ITensor
class ITensor : public INoCopy { public: void setName(const char name) noexcept { mImpl->setName(name); } const char getName() const noexcept { return mImpl->getName(); } void setDimensions(Dims dimensions) noexcept { mImpl->setDimensions(dimensions); } Dims getDimensions() const noexcept { return mImpl->getDimensions(); } void setType(DataType type) noexcept { mImpl->setType(type); } DataType getType() const noexcept { return mImpl->getType(); } bool setDynamicRange(float min, float max) noexcept { return mImpl->setDynamicRange(min, max); } bool isNetworkInput() const noexcept { return mImpl->isNetworkInput(); } bool isNetworkOutput() const noexcept { return mImpl->isNetworkOutput(); } void setBroadcastAcrossBatch(bool broadcastAcrossBatch) noexcept { mImpl->setBroadcastAcrossBatch(broadcastAcrossBatch); } bool getBroadcastAcrossBatch() const noexcept { return mImpl->getBroadcastAcrossBatch(); } TensorLocation getLocation() const noexcept { return mImpl->getLocation(); } void setLocation(TensorLocation location) noexcept { mImpl->setLocation(location); } bool dynamicRangeIsSet() const noexcept { return mImpl->dynamicRangeIsSet(); } void resetDynamicRange() noexcept { mImpl->resetDynamicRange(); } float getDynamicRangeMin() const noexcept { return mImpl->getDynamicRangeMin(); } float getDynamicRangeMax() const noexcept { return mImpl->getDynamicRangeMax(); } void setAllowedFormats(TensorFormats formats) noexcept { mImpl->setAllowedFormats(formats); } TensorFormats getAllowedFormats() const noexcept { return mImpl->getAllowedFormats(); } bool isShapeTensor() const noexcept { return mImpl->isShapeTensor(); } bool isExecutionTensor() const noexcept { return mImpl->isExecutionTensor(); }
protected: apiv::VTensor* mImpl; virtual ~ITensor() noexcept = default; };
-
https://github.com/281011824/TRT/blob/main/CMakeLists.txt