题目要求:基于MaskRCNN的目标检测推理模型,兼容之前YOLACT,YOLOv3,YOLOv5,YOLOv7模型。

与上述快速推理模型(one-stage)不同,MaskRCNN是一种two-stage目标检测模型,主要的优点在于精度更高。MaskRCNN由Faster RCNN改进而来,主要的改进在于提出RoIAlign,相比RoI Pooling,减小了量化误差,使得目标的定位和分割更精确。MaskRCNN是two-stage目标检测中的典型代表,同时又是实例分割中的重要模型。
MaskRCNN论文链接:MaskRCNN
MaskRCNN代码链接:MaskRCNN_Detectron
模型下载:TensorFlow_MaskRCNN_Inceptionv2

  • 结果展示
    在这里插入图片描述

  • 代码示例 (模块代码,其他部分见兼容模型)

#include <fstream>
#include <sstream>
#include <iostream>
#include <string>
 
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
 
#include "config.cpp"

using namespace cv;
using namespace dnn;
using namespace std;

/**
 * Mask R-CNN detector built on cv::dnn::Net.
 *
 * Construct it from a net_config, then call detect() on a frame: every
 * detection whose score exceeds the confidence threshold is drawn in place
 * (bounding box, label and blended instance mask).
 */
class maskRCNN
{
public:
    /// Reads the thresholds from `config` and loads the network.
    maskRCNN(net_config& config);

    /// Runs inference on `frame` and draws the detections onto it.
    void detect(Mat& frame);
private:
    // Scalar members get defaults so none of them is ever read uninitialised.
    // inpWidth / inpHeight are not assigned by any method of this class.
    int inpWidth = 0;
    int inpHeight = 0;

    float confThreshold = 0.0f;   // minimum detection score to draw
    float nmsThreshold = 0.0f;    // stored from the config; not used by detect()
    float maskThreshold = 0.0f;   // cut-off applied to the mask before drawing
    Net net;
    Mat mask;                     // mask of the detection currently being drawn

    /// Draws one detection (box, label, mask) onto `frame`.
    void drawPred(float conf, int left, int top, int right, int bottom, Mat& frame, int classid);
};

/**
 * Builds the detector: copies the confidence and NMS thresholds out of
 * `config`, fixes the mask threshold at 0.3 and loads the TensorFlow graph
 * from the model/mask_rcnn directory.
 */
maskRCNN::maskRCNN(net_config& config)
    : confThreshold(config.confThreshold),
      nmsThreshold(config.nmsThreshold),
      maskThreshold(0.3f),
      net(readNetFromTensorflow("model/mask_rcnn/frozen_inference_graph.pb",
                                "model/mask_rcnn/graph.pbtxt"))
{
}

/**
 * Runs the network on `frame` and draws every accepted detection onto it.
 *
 * Two outputs are requested: "detection_out_final" (read as rows of seven
 * floats, of which column 1 is the class id, column 2 the score and columns
 * 3..6 the box corners scaled by the frame size) and "detection_masks"
 * (indexed by detection and class id).
 *
 * @param frame Image to analyse; modified in place. An empty frame is ignored.
 */
void maskRCNN::detect(Mat& frame)
{
    // Nothing to run the network on.
    if (frame.empty())
        return;

    Mat blob;
    blobFromImage(frame, blob, 1.0, Size(frame.cols, frame.rows), Scalar(), true, false);
    this->net.setInput(blob);

    vector<string> outNames(2);
    outNames[0] = "detection_out_final";
    outNames[1] = "detection_masks";
    vector<Mat> outs;
    this->net.forward(outs, outNames);

    Mat outDet = outs[0];
    Mat outSeg = outs[1];
    // The loop index addresses both tensors, so it must stay inside both.
    const int nDet = min(outDet.size[2], outSeg.size[0]);
    const int nClass = outSeg.size[1];
    const int maskRows = outSeg.size[2];
    const int maskCols = outSeg.size[3];
    outDet = outDet.reshape(1, static_cast<int>(outDet.total() / 7));

    // Confidence filtering only: no non-maximum suppression is applied here.
    for (int i = 0; i < nDet; i++) {
        const float score = outDet.at<float>(i, 2);
        if (!(score > confThreshold))
            continue;

        // The class id selects a mask channel; reject ids outside the tensor.
        const int classId = static_cast<int>(outDet.at<float>(i, 1));
        if (classId < 0 || classId >= nClass)
            continue;

        int left = static_cast<int>(frame.cols * outDet.at<float>(i, 3));
        int top = static_cast<int>(frame.rows * outDet.at<float>(i, 4));
        int right = static_cast<int>(frame.cols * outDet.at<float>(i, 5));
        int bottom = static_cast<int>(frame.rows * outDet.at<float>(i, 6));

        // Clamp the corners to the image.
        left = max(0, min(left, frame.cols - 1));
        top = max(0, min(top, frame.rows - 1));
        right = max(0, min(right, frame.cols - 1));
        bottom = max(0, min(bottom, frame.rows - 1));

        // An inverted box would give a non-positive size when the mask is resized.
        if (right < left || bottom < top)
            continue;

        // Extract the mask for this object. It is cloned so the member owns
        // its data instead of pointing into `outSeg`, which is destroyed
        // when this function returns.
        Mat objectMask(maskRows, maskCols, CV_32F, outSeg.ptr<float>(i, classId));
        this->mask = objectMask.clone();
        drawPred(score, left, top, right, bottom, frame, classId);
    }
}

void maskRCNN::drawPred(float conf, int left, int top, int right, int bottom, Mat& frame, int classid)
{
    //Draw a rectangle displaying the bounding box
    Rect box = Rect(left, top, right - left + 1, bottom - top + 1);
    rectangle(frame, Point(box.x, box.y), Point(box.x + box.width, box.y + box.height), Scalar(0, 0, 255), 2);

    //Get the label for the class name and its confidence
    string label = format("%.2f", conf);
    label = string(class_names[classid+1]) + ":" + label;

    //Display the label at the top of the bounding box
    int baseLine;
    Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
    top = max(top, labelSize.height);
    putText(frame, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0, 255, 0), 1);

    Scalar color =Scalar((uchar)colors[classid+1]);
    resize(this -> mask, this -> mask, Size(box.width, box.height));
	Mat mask = (this -> mask > maskThreshold);
	// mask ori image and color map
	Mat coloredRoi = (0.4 * color + 0.6 * frame(box));
	coloredRoi.convertTo(coloredRoi, CV_8UC3);
 
	// border refine
	vector<Mat> contours;
	Mat hierarchy;
	mask.convertTo(mask, CV_8U);
	findContours(mask, contours, hierarchy, RETR_CCOMP, CHAIN_APPROX_SIMPLE);
	drawContours(coloredRoi, contours, -1, color, 2, LINE_8, hierarchy, 100);
	coloredRoi.copyTo(frame(box), mask);
}

// **** unit test ****//
// int main()
// {
		// maskRCNN model(config);
		// model.detect(srcimg);
		// static const string kWinName = "Deep learning object detection in OpenCV";
		// namedWindow(kWinName, WINDOW_NORMAL);
		// imshow(kWinName, srcimg);
		// waitKey(0);
		// destroyAllWindows();
// }
Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐