plugin  0.1.0
model_base.h
1 /*
2 // Copyright (C) 2020-2024 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16 
17 #pragma once
18 #include <map>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include <openvino/openvino.hpp>
24 
25 #include <utils/args_helper.hpp>
26 #include <utils/config_factory.h>
27 #include <utils/ocv_common.hpp>
28 
29 struct InferenceResult;
30 struct InputData;
31 struct InternalModelData;
32 struct ResultBase;
33 
34 class ModelBase {
35 public:
36  ModelBase(const std::string& modelFileName, const std::string& layout = "")
37  : modelFileName(modelFileName),
38  inputsLayouts(parseLayoutString(layout)) {}
39 
40  virtual ~ModelBase() {}
41 
42  virtual std::shared_ptr<InternalModelData> preprocess(const InputData& inputData, ov::InferRequest& request) = 0;
43  virtual ov::CompiledModel compileModel(const ModelConfig& config, ov::Core& core);
44  virtual void onLoadCompleted(const std::vector<ov::InferRequest>& requests) {}
45  virtual std::unique_ptr<ResultBase> postprocess(InferenceResult& infResult) = 0;
46 
47  const std::vector<std::string>& getOutputsNames() const {
48  return outputsNames;
49  }
50  const std::vector<std::string>& getInputsNames() const {
51  return inputsNames;
52  }
53 
54  std::string getModelFileName() {
55  return modelFileName;
56  }
57 
58  void setInputsPreprocessing(bool reverseInputChannels,
59  const std::string& meanValues,
60  const std::string& scaleValues) {
61  this->inputTransform = InputTransform(reverseInputChannels, meanValues, scaleValues);
62  }
63 
64 protected:
65  virtual void prepareInputsOutputs(std::shared_ptr<ov::Model>& model) = 0;
66  virtual void setBatch(std::shared_ptr<ov::Model>& model);
67 
68  std::shared_ptr<ov::Model> prepareModel(ov::Core& core);
69 
70  InputTransform inputTransform = InputTransform();
71  std::vector<std::string> inputsNames;
72  std::vector<std::string> outputsNames;
73  ov::CompiledModel compiledModel;
74  std::string modelFileName;
75  ModelConfig config = {};
76  std::map<std::string, ov::Layout> inputsLayouts;
77  ov::Layout getInputLayout(const ov::Output<ov::Node>& input);
78 };
Definition: ocv_common.hpp:271
Definition: model_base.h:34
std::shared_ptr< ov::Model > prepareModel(ov::Core &core)
Definition: model_base.cpp:28
A header file with common functionality shared by the samples.
A header file with common functionality shared by the samples, using OpenCV.
Definition: results.h:53
Definition: input_data.h:20
Definition: internal_model_data.h:19
Definition: config_factory.h:26
Definition: results.h:29