import time
from skimage import segmentation
import torch
import torch.nn as nn
import cv2
import os
import json
import numpy as np
# Image that is read for ROI selection and segmentation; also passed to
# convert2json as the "imagePath" entry of the exported annotation.
image_path = "tiger.jpg"
# File that convert2json writes the LabelMe-style annotation to.
json_path = "tiger.json"
def get_target_mean_value(target_path):
    """Let the user select a rectangle on an image and return its mean value.

    The image is loaded with ``cv2.imread`` and shown in an OpenCV window
    named ``input`` in which the user picks a rectangle (``cv2.selectROI``).
    The mean is taken over every pixel and every channel of the crop.

    Args:
        target_path: Path of the image file to open.

    Returns:
        The mean of the selected region, truncated to ``int``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not load ``target_path``.
        ValueError: If the selected region is empty (zero width or height),
            which is what OpenCV reports when the selection is cancelled.
    """
    img = cv2.imread(target_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError("cannot read image: %r" % (target_path,))

    # r is (x, y, width, height) of the rectangle drawn by the user.
    r = cv2.selectROI('input', img, False)
    print("input:", r)

    roi = img[int(r[1]):int(r[1] + r[3]), int(r[0]):int(r[0] + r[2])]
    if roi.size == 0:
        # np.mean of an empty crop is NaN, and int(NaN) fails with an
        # unhelpful message; report the real cause instead.
        raise ValueError("empty ROI selected: %r" % (tuple(r),))
    return int(np.mean(roi))
class Args(object):
    """Static configuration for the segmentation run.

    Every setting is a class attribute. The input image is read while the
    class body executes, so defining this class requires the image file to be
    readable; a ``FileNotFoundError`` is raised otherwise.
    """
    # Image to segment; taken from the module-level ``image_path``.
    input_image_path = image_path
    img = cv2.imread(input_image_path)
    if img is None:
        # cv2.imread returns None on failure; fail with a clear message
        # instead of an AttributeError on ``img.shape`` below.
        raise FileNotFoundError("cannot read image: %r" % (input_image_path,))
    # Size of the loaded image (third entry of the shape is discarded).
    height, width, _ = img.shape
    # Maximum number of optimisation iterations performed by run().
    train_epoch = 2 ** 6
    # Channel counts used for the network's convolution layers.
    mod_dim1 = 64
    mod_dim2 = 32
    # Exported by run() as the CUDA_VISIBLE_DEVICES environment variable.
    gpu_id = 0
    # run() stops early once fewer than this many distinct labels remain.
    min_label_num = 4
    # run() only refreshes its colour preview while the number of distinct
    # labels is below this bound.
    max_label_num = 256
class Net(nn.Module):
    """Four-stage convolutional network that keeps the spatial size.

    Each stage is ``Conv2d`` followed by ``BatchNorm2d``. Stages alternate
    between 3x3 (padding 1) and 1x1 (padding 0) kernels, all with stride 1.
    A ReLU separates consecutive stages; the last stage has no activation.

    Args:
        inp_dim: Number of channels of the input tensor.
        mod_dim1: Output channels of the 3x3 stages.
        mod_dim2: Output channels of the 1x1 stages (and of the network).
    """

    def __init__(self, inp_dim, mod_dim1, mod_dim2):
        super().__init__()
        # (in_channels, out_channels, kernel_size) of every stage, in order.
        stages = (
            (inp_dim, mod_dim1, 3),
            (mod_dim1, mod_dim2, 1),
            (mod_dim2, mod_dim1, 3),
            (mod_dim1, mod_dim2, 1),
        )
        layers = []
        for index, (c_in, c_out, ksize) in enumerate(stages):
            if index:
                # ReLU sits between stages only, never after the last one.
                layers.append(nn.ReLU(inplace=True))
            layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=ksize, stride=1,
                          padding=ksize // 2)
            )
            layers.append(nn.BatchNorm2d(c_out))
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.seq(x)
def run():
    """Segment the configured image without labels and return a preview.

    Steps:
      1. Pre-segment the image into superpixels with Felzenszwalb's method.
      2. For up to ``Args.train_epoch`` iterations: take the per-pixel argmax
         of the network output as pseudo labels, replace the labels inside
         every superpixel by that superpixel's most frequent label, and
         train the network against the result with cross-entropy and SGD.
      3. While fewer than ``Args.max_label_num`` labels are present, paint
         every label with a mean image colour and show it in the ``seg_pt``
         window.
      4. Stop early once fewer than ``Args.min_label_num`` labels remain.

    Side effects: sets ``CUDA_VISIBLE_DEVICES``, seeds the torch and NumPy
    random generators, opens an OpenCV preview window, prints the loss and
    timings, and writes ``seg_<image stem>_<seconds>s.jpg`` to the current
    directory.

    Returns:
        The last preview array, with the same shape as the input image (the
        input image itself if no preview was ever rendered).

    Raises:
        FileNotFoundError: If the input image cannot be read.
    """
    start_time0 = time.time()
    args = Args()

    # Choose the visible GPU before any call that could touch CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    # torch.manual_seed seeds the CPU generator as well as the CUDA ones.
    # The model below is constructed before it is moved to ``device``, so
    # seeding only CUDA would leave its initial weights unseeded.
    torch.manual_seed(1943)
    np.random.seed(1943)

    image = cv2.imread(args.input_image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(
            "cannot read image: %r" % (args.input_image_path,))

    # Superpixel pre-segmentation; seg_lab holds, per superpixel, the flat
    # pixel indices that belong to it.
    # Alternative: segmentation.slic(image, n_segments=10000, compactness=100)
    seg_map = segmentation.felzenszwalb(image, scale=32, sigma=0.5, min_size=64)
    seg_map = seg_map.flatten()
    seg_lab = [np.where(seg_map == u_label)[0]
               for u_label in np.unique(seg_map)]

    # Image as a (1, 3, H, W) float32 tensor scaled to [0, 1].
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    tensor = image.transpose((2, 0, 1))
    tensor = tensor.astype(np.float32) / 255.0
    tensor = tensor[np.newaxis, :, :, :]
    tensor = torch.from_numpy(tensor).to(device)

    model = Net(inp_dim=3, mod_dim1=args.mod_dim1, mod_dim2=args.mod_dim2).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=5e-2, momentum=0.9)

    image_flatten = image.reshape((-1, 3))
    # Placeholder palette; its length (max_label_num) differs from any label
    # count that passes the check below, so real mean colours replace it the
    # first time a preview is rendered.
    color_avg = np.random.randint(255, size=(args.max_label_num, 3))
    show = image

    start_time1 = time.time()
    model.train()
    for batch_idx in range(args.train_epoch):
        # Forward pass: one score vector of length mod_dim2 per pixel.
        optimizer.zero_grad()
        output = model(tensor)[0]
        output = output.permute(1, 2, 0).view(-1, args.mod_dim2)
        target = torch.argmax(output, 1)
        im_target = target.detach().cpu().numpy()

        # Refine: every superpixel takes its most frequent predicted label.
        for inds in seg_lab:
            u_labels, hist = np.unique(im_target[inds], return_counts=True)
            im_target[inds] = u_labels[np.argmax(hist)]

        # Backward pass against the refined pseudo labels.
        target = torch.from_numpy(im_target)
        target = target.to(device)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Preview: paint every label with a mean colour of its pixels.
        un_label, lab_inverse = np.unique(im_target, return_inverse=True, )
        if un_label.shape[0] < args.max_label_num:
            img_flatten = image_flatten.copy()
            # Colours are recomputed only when the number of labels changes;
            # otherwise the previous palette is reused.
            if len(color_avg) != un_label.shape[0]:
                color_avg = [np.mean(img_flatten[im_target == label], axis=0, dtype=np.int32) for label in un_label]
            for lab_id, color in enumerate(color_avg):
                img_flatten[lab_inverse == lab_id] = color
            show = img_flatten.reshape(image.shape)
        cv2.imshow("seg_pt", show)
        cv2.waitKey(1)

        print('Loss:', batch_idx, loss.item())
        if len(un_label) < args.min_label_num:
            break

    # "PyTorchInit" is the setup time before the loop, "TimeUsed" the loop.
    init_time = start_time1 - start_time0
    train_time = time.time() - start_time1
    print('PyTorchInit: %.2f\nTimeUsed: %.2f' % (init_time, train_time))

    # Build the output name from the file stem rather than a fixed slice,
    # which only worked for one particular directory/extension layout.
    stem = os.path.splitext(os.path.basename(args.input_image_path))[0]
    cv2.imwrite("seg_%s_%ds.jpg" % (stem, train_time), show)
    return show
def convert2json(img, img_path, value, output_path=None):
    """Threshold an image and save its outer contours as LabelMe polygons.

    The image is converted to greyscale, binarised at ``value``, dilated with
    a 10x10 kernel, and every external contour of the result is stored as one
    polygon shape labelled ``"object"``.

    Args:
        img: Image to convert; it is passed through ``cv2.COLOR_BGR2GRAY``.
        img_path: Stored verbatim as ``"imagePath"`` in the JSON.
        value: Threshold; with ``cv2.THRESH_BINARY`` pixels whose grey level
            is greater than it become foreground.
        output_path: File to write. Defaults to the module-level
            ``json_path`` when omitted.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    height, width = gray.shape[0:2]
    binary = cv2.threshold(gray, value, 255, cv2.THRESH_BINARY)[1]
    kernel = np.ones((10, 10), np.uint8)
    dilated_img = cv2.dilate(binary, kernel, iterations=1)
    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x return
    # (contours, hierarchy); the contours are second-to-last in both.
    contours = cv2.findContours(
        dilated_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    labelme_data = {
        "version": "5.3.1",
        "flags": {},
        "shapes": [
            {
                "label": "object",
                # Each contour is flattened to a list of [x, y] pairs.
                "points": contour.reshape(-1, 2).tolist(),
                "group_id": None,
                "shape_type": "polygon",
                "flags": {},
            }
            for contour in contours
        ],
        "imagePath": img_path,
        "imageData": None,
        "imageHeight": height,
        "imageWidth": width,
    }

    if output_path is None:
        output_path = json_path
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(labelme_data, f, indent=4)
    print("LabelMe JSON 文件已保存")
if __name__ == "__main__":
    # 1. The user marks a region; its mean value becomes the threshold.
    threshold = get_target_mean_value(image_path)
    # 2. Unsupervised segmentation produces the flat-coloured image.
    segmented = run()
    # 3. The thresholded segmentation is exported as LabelMe polygons.
    convert2json(segmented, image_path, threshold)

tiger.jpg
tiger.jpg
color_extraction.jpg
color_extraction.png

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐