Git Product home page Git Product logo

Comments (6)

Yuliang-Liu avatar Yuliang-Liu commented on July 19, 2024 2

An Example:

OUTPUT_DIR: output/align/07x32
MODEL:
  META_ARCHITECTURE: "OneStage"
  ONE_STAGE_HEAD: "align"
  WEIGHT: "YOUR_MODEL"
  FCOS_ON: True
  BACKBONE:
    CONV_BODY: "R-50"
  NECK:
    CONV_BODY: "fpn-align"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
  RETINANET:
    USE_C5: False # FCOS uses P5 instead of C5
  ALIGN:
    POOLER_RESOLUTION: (7, 32)
    POOLER_CANONICAL_SCALE: 160
    POOLER_SCALES: (0.25, 0.125, 0.0625)
    PREDICTOR: "ctc" # "ctc" or "attention"
  FCOS:
    CENTER_SAMPLE: True
    POS_RADIUS: 1.5
    LOC_LOSS_TYPE: "giou"
DATASETS:
  TRAIN: ("YOUR_TRAINSET",)
  TEST: ("YOUR_TESTSET",)
  TEXT:
    NUM_CHARS: 25
    VOC_SIZE: 97
INPUT:
  MIN_SIZE_RANGE_TRAIN: (640, 800)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
  FLIP_PROB_TRAIN: 0.0
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  BASE_LR: 0.01
  WEIGHT_DECAY: 0.0001
  STEPS: (100000, 180000)
  MAX_ITER: 250000
  IMS_PER_BATCH: 2
  WARMUP_METHOD: "constant"
  CHECKPOINT_PERIOD: 2500
TEST:
  IMS_PER_BATCH: 1

@eyebies Simply change "ctc" to "attention" if you would like to fine-tune from the provided model.

from bezier_curve_text_spotting.

saicoco avatar saicoco commented on July 19, 2024 2

@deepseek You can use the following script to generate Bezier points for rotated boxes. I added steps that find the top and bottom edges so that eight points are generated for each rotated box:

# coding=utf-8
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from scipy import interpolate
from scipy.special import comb as n_over_k
import glob, os
import cv2

from skimage import data, color
from skimage.transform import rescale, resize, downscale_local_mean

import matplotlib.pyplot as plt
import math
import numpy as np
import random
# from scipy.optimize import leastsq
import torch
from torch import nn
from torch.nn import functional as F

from sklearn.model_selection import train_test_split 
from sklearn.linear_model import LinearRegression
from sklearn import metrics
from sklearn.metrics import mean_squared_error, r2_score

from shapely.geometry import *
from PIL import Image
import time
from bresenham import bresenham
import re 
from tqdm import tqdm

class Bezier(nn.Module):
    """Cubic Bezier curve whose two inner control points are trainable.

    The end points are fixed to the first and last rows of ``ps``; the inner
    control points start from ``ctps`` laid out as ``[x1, y1, x2, y2]``.
    Calling the module returns the sum, over the interior sample points
    ``ps[1:-1]``, of the Euclidean distance to the nearest of 81 curve points
    evaluated at evenly spaced t in [0, 1].
    """

    def __init__(self, ps, ctps):
        super().__init__()
        # Register the trainable coordinates in the order x1, x2, y1, y2.
        for attr_name, ctp_index in (('x1', 0), ('x2', 2), ('y1', 1), ('y2', 3)):
            setattr(self, attr_name,
                    nn.Parameter(torch.as_tensor(ctps[ctp_index], dtype=torch.float64)))
        # Fixed end points taken from the first and last sample.
        self.x0, self.y0 = ps[0, 0], ps[0, 1]
        self.x3, self.y3 = ps[-1, 0], ps[-1, 1]
        self.inner_ps = torch.as_tensor(ps[1:-1, :], dtype=torch.float64)
        self.t = torch.as_tensor(np.linspace(0, 1, 81))

    @staticmethod
    def _de_casteljau(p0, p1, p2, p3, t):
        """Evaluate one coordinate of a cubic Bezier at parameters *t*."""
        s = 1 - t
        first = s * p0 + t * p1
        middle = s * p1 + t * p2
        last = s * p2 + t * p3
        return s * (s * first + t * middle) + t * (s * middle + t * last)

    def forward(self):
        x0, x1, x2, x3, y0, y1, y2, y3 = self.control_points()
        curve = torch.stack(
            (self._de_casteljau(x0, x1, x2, x3, self.t),
             self._de_casteljau(y0, y1, y2, y3, self.t)),
            dim=1,
        )
        # Pairwise distances: rows are interior samples, columns are curve points.
        offsets = curve.unsqueeze(0) - self.inner_ps.unsqueeze(1)
        distances = (offsets ** 2).sum(dim=2).sqrt()
        nearest, _ = distances.min(dim=1)
        return nearest.sum()

    def control_points(self):
        """Return (x0, x1, x2, x3, y0, y1, y2, y3); inner values are Parameters."""
        return self.x0, self.x1, self.x2, self.x3, self.y0, self.y1, self.y2, self.y3

    def control_points_f(self):
        """Like control_points(), but with the inner Parameters as Python floats."""
        return (self.x0, self.x1.item(), self.x2.item(), self.x3,
                self.y0, self.y1.item(), self.y2.item(), self.y3)


def train(x, y, ctps, lr):
    """Refine the inner control points of a cubic Bezier with SGD.

    Args:
        x, y: sample coordinates; the first and last samples are the fixed
            curve end points.
        ctps: initial inner control points ``[x1, y1, x2, y2]``.
        lr: SGD learning rate. ``0.0`` skips optimisation entirely.

    Returns:
        ``(x0, x1, x2, x3, y0, y1, y2, y3)`` from ``Bezier.control_points_f``.
        If the loss ever becomes NaN, the initial control points are returned.
    """
    samples = np.vstack((np.array(x), np.array(y))).transpose()
    model = Bezier(samples, ctps)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    initial_points = model.control_points_f()

    if lr == 0.0:
        return model.control_points_f()

    # Halve the learning rate at these iterations (before that step's update).
    halving_steps = {400, 800}
    for step in range(1000):
        loss = model()
        if torch.isnan(loss):
            return initial_points
        if step in halving_steps:
            optimizer.param_groups[0]['lr'] *= 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model.control_points_f()

def draw(ps, control_points, t):
    """Interactively show sample points and the cubic Bezier they were fitted to.

    Args:
        ps: sample points, indexed as ``ps[:, 0]`` (x) and ``ps[:, 1]`` (y).
        control_points: ``(x0, x1, x2, x3, y0, y1, y2, y3)``, the layout
            returned by ``train()``.
        t: curve parameter values at which to evaluate the Bezier; must
            support element-wise arithmetic (e.g. a numpy array).

    Blocks until the user presses Enter, then closes the figure.
    """
    x = ps[:, 0]
    y = ps[:, 1]
    x0, x1, x2, x3, y0, y1, y2, y3 = control_points
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.plot(x, y, color='m', linestyle='', marker='.')
        bezier_x = (1-t)*((1-t)*((1-t)*x0+t*x1)+t*((1-t)*x1+t*x2))+t*((1-t)*((1-t)*x1+t*x2)+t*((1-t)*x2+t*x3))
        bezier_y = (1-t)*((1-t)*((1-t)*y0+t*y1)+t*((1-t)*y1+t*y2))+t*((1-t)*((1-t)*y1+t*y2)+t*((1-t)*y2+t*y3))
        ax.plot(bezier_x, bezier_y, 'g-')
        plt.draw()
        plt.pause(1)  # let the GUI event loop render before blocking on input
        # input() is the Python 3 spelling of Python 2's raw_input(); this
        # file is Python 3 (it uses math.inf), where raw_input does not exist.
        input("<Hit Enter To Close>")
    finally:
        plt.close(fig)

def Mtk(n, t, k):
    """Return the Bernstein basis polynomial ``C(n, k) * t**k * (1 - t)**(n - k)``."""
    return t**k * (1-t)**(n-k) * n_over_k(n, k)


def BezierCoeff(ts, degree=3):
    """Return the Bernstein basis matrix for a Bezier curve of the given degree.

    Row ``i`` is ``[Mtk(degree, ts[i], k) for k in 0..degree]``, so the result
    is a ``len(ts) x (degree + 1)`` nested list. Multiplying it by a
    ``(degree + 1) x 2`` control-point array evaluates the curve at *ts*.
    The default ``degree=3`` gives the cubic basis used throughout this script.
    """
    return [[Mtk(degree, t, k) for k in range(degree + 1)] for t in ts]


def bezier_fit(x, y):
    """Least-squares fit of a cubic Bezier curve to the points (x, y).

    Points are parameterised by normalised cumulative chord length. The
    control points solve ``BezierCoeff(t) @ P ~= data`` via the
    Moore-Penrose pseudo-inverse. If all points coincide (zero total chord
    length), uniform spacing is used instead, so the result is not NaN.

    Args:
        x, y: 1-D sequences of equal length with the sample coordinates.

    Returns:
        ``[x1, y1, x2, y2]``: the two INNER control points only. The fitted
        end control points are discarded by this function.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    dy = y[1:] - y[:-1]
    dx = x[1:] - x[:-1]
    dt = (dx ** 2 + dy ** 2) ** 0.5
    total = dt.sum()
    if total > 0:
        t = np.hstack(([0], dt / total)).cumsum()
    else:
        # Degenerate input: chord-length parameterisation would divide by zero.
        t = np.linspace(0.0, 1.0, len(x))

    data = np.column_stack((x, y))
    pseudoinverse = np.linalg.pinv(BezierCoeff(t))  # (n, 4) -> (4, n)
    control_points = pseudoinverse.dot(data)        # (4, n) @ (n, 2) -> (4, 2)
    return control_points[1:-1, :].flatten().tolist()
    
def bezier_fitv2(x, y):
    """Return inner control points for the straight chord from first to last point.

    The segment from ``(x[0], y[0])`` to ``(x[-1], y[-1])`` is expressed as a
    cubic Bezier whose inner control points lie 1/3 and 2/3 of the way along
    it. Intermediate samples are ignored.

    Returns:
        ``[x1, y1, x2, y2]``.
    """
    x_start, x_end = x[0], x[-1]
    y_start, y_end = y[0], y[-1]
    return [
        (2*x_start + x_end)/3.0,
        (2*y_start + y_end)/3.0,
        (x_start + 2*x_end)/3.0,
        (y_start + 2*y_end)/3.0,
    ]

def is_close_to_line(xs, ys, thres):
    """Classify a polyline as straight or not via a linear-regression residual score.

    A least-squares line y(x) is fitted. The score is the mean squared error
    between the SQUARED observed and SQUARED predicted y values, divided by the
    square of the largest (signed) difference of those squared values.

    NOTE(review): despite the historical name "rmse", this is not an RMSE. When
    every squared residual difference is zero, the normaliser is 0 and the
    score may become nan, which makes the function return 2.0.

    Args:
        xs, ys: numpy arrays of coordinates (``.reshape`` is called on them).
        thres: score above which the polyline is NOT considered a line.

    Returns:
        0.0 if the score exceeds ``thres``, otherwise 2.0.
    """
    xs_col = xs.reshape(-1, 1)
    ys_col = ys.reshape(-1, 1)
    fitted = LinearRegression().fit(xs_col, ys_col).predict(xs_col)

    observed_sq = ys_col ** 2
    fitted_sq = fitted ** 2
    score = mean_squared_error(observed_sq, fitted_sq)
    score = score / (observed_sq - fitted_sq).max() ** 2

    return 0.0 if score > thres else 2.0

def is_close_to_linev2(xs, ys, size, thres = 0.05):
    """Decide whether a polyline is close to the straight chord joining its ends.

    Each segment's slope is compared with the slope of the chord from the
    first to the last point. The summed absolute slope differences are scaled
    by the chord length divided by ``int(sqrt(size))``. Vertical steps get
    slope +/-inf and identical slopes (including equal infinities) contribute
    0, so a perfectly vertical polyline scores 0 instead of nan.

    Args:
        xs, ys: point coordinates as lists or 1-D arrays.
        size: pixel-area normaliser (the script passes ``512*512``); only its
            integer square root is used.
        thres: score below which the polyline counts as straight.

    Returns:
        0.0 if the score is below ``thres``, otherwise 3.0.
    """
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    nor_pixel = int(size ** 0.5)

    def _slopes(dx, dy):
        # dy/dx. A vertical step maps to +/-inf and a zero-length step to nan.
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.where(dx != 0, dy / dx, np.inf * np.sign(dy))

    slopes = _slopes(np.diff(xs), np.diff(ys))
    st_slope = _slopes(xs[-1] - xs[0], ys[-1] - ys[0])
    max_dis = ((ys[-1] - ys[0]) ** 2 + (xs[-1] - xs[0]) ** 2) ** 0.5

    # Treat equal slopes as zero difference so that inf - inf does not yield nan.
    with np.errstate(invalid='ignore'):
        diffs = np.where(slopes == st_slope, 0.0, np.abs(slopes - st_slope))
    score = diffs.sum() * max_dis / nor_pixel

    return 0.0 if score < thres else 3.0


def find_long_edges(points, bottoms):
    """Split a closed polygon into the two runs of edges between its bottom edges.

    Args:
        points: polygon vertices; only ``len(points)`` is used.
        bottoms: two ``(start, end)`` vertex-index pairs, e.g. from ``find_bottom``.

    Returns:
        ``(long_edge_1, long_edge_2)``: lists of ``(start, end)`` index pairs.
        ``long_edge_1`` walks forward from the end of ``bottoms[0]`` and stops
        before reaching the end of ``bottoms[1]``. ``long_edge_2`` does the
        same with the two bottoms swapped.
    """
    n_pts = len(points)
    _, first_end = bottoms[0]
    _, second_end = bottoms[1]

    def _walk(from_vertex, stop_vertex):
        edges = []
        cursor = (from_vertex + 1) % n_pts
        while cursor % n_pts != stop_vertex:
            edges.append(((cursor - 1) % n_pts, cursor % n_pts))
            cursor = (cursor + 1) % n_pts
        return edges

    return _walk(first_end, second_end), _walk(second_end, first_end)
def norm2(x, axis=None):
    """Return the Euclidean (L2) norm of *x*.

    Args:
        x: array-like of numbers (lists are accepted).
        axis: axis to reduce over; ``None`` (default) reduces over all
            elements. ``axis=0`` is honoured. The previous ``if axis:`` test
            treated it like ``None``.

    Returns:
        A scalar when ``axis`` is None, otherwise an array of norms.
    """
    x = np.asarray(x)
    return np.sqrt(np.sum(x ** 2, axis=axis))

def cos(p1, p2):
    """Return the cosine of the angle between vectors *p1* and *p2*.

    Zero-length inputs divide by zero and yield nan (numpy semantics).
    """
    dot_product = (p1 * p2).sum()
    length_product = np.sqrt(np.sum(p1 ** 2)) * np.sqrt(np.sum(p2 ** 2))
    return dot_product / length_product

def find_bottom(pts):
    """Pick the two "bottom" edges of a text polygon (the ends between its long sides).

    Args:
        pts: polygon vertices in boundary order. Presumably an (N, 2) numpy
            array, since it is passed to np.concatenate and indexed by row
            (TODO confirm for other callers). With 4 or fewer points the
            4-point branch indexes ``pts[3]``, so fewer than 4 points fail.

    Returns:
        A list of two ``(start_index, end_index)`` pairs. Each names an edge
        between consecutive vertices, with wrap-around.
    """

    if len(pts) > 4:
        # Pad with the first 3 vertices so e[i + 2] is valid for every edge.
        e = np.concatenate([pts, pts[:3]])
        candidate = []
        for i in range(1, len(pts) + 1):
            # Edge (i, i+1) is a bottom candidate when the edge entering vertex i
            # and the edge leaving vertex i+1 point in roughly opposite directions.
            v_prev = e[i] - e[i - 1]
            v_next = e[i + 2] - e[i + 1]
            if cos(v_prev, v_next) < -0.7:
                # The third element (edge length) is not used below.
                candidate.append((i % len(pts), (i + 1) % len(pts), norm2(e[i] - e[i + 1])))

        if len(candidate) != 2 or candidate[0][0] == candidate[1][1] or candidate[0][1] == candidate[1][0]:
            # Not exactly two candidates, or the two share a vertex: fall back to
            # the two edges whose midpoints are farthest apart.
            mid_list = []
            for i in range(len(pts)):
                mid_point = (e[i] + e[(i + 1) % len(pts)]) / 2
                mid_list.append((i, (i + 1) % len(pts), mid_point))

            # All ordered pairs (i, j), so each unordered pair appears twice.
            dist_list = []
            for i in range(len(pts)):
                for j in range(len(pts)):
                    s1, e1, mid1 = mid_list[i]
                    s2, e2, mid2 = mid_list[j]
                    dist = norm2(mid1 - mid2)
                    dist_list.append((s1, e1, s2, e2, dist))
            # The two largest entries are presumably (i, j) and (j, i) for the
            # farthest pair, so taking [:2] of each yields edges i and j.
            # NOTE(review): with ties (e.g. symmetric shapes) argsort may instead
            # pick entries from two different pairs.
            bottom_idx = np.argsort([dist for s1, e1, s2, e2, dist in dist_list])[-2:]
            bottoms = [dist_list[bottom_idx[0]][:2], dist_list[bottom_idx[1]][:2]]
        else:
            bottoms = [candidate[0][:2], candidate[1][:2]]

    else:
        # Quadrilateral: the pair of opposite sides with the smaller total
        # length is taken as the bottoms.
        d1 = norm2(pts[1] - pts[0]) + norm2(pts[2] - pts[3])
        d2 = norm2(pts[2] - pts[1]) + norm2(pts[0] - pts[3])
        bottoms = [(0, 1), (2, 3)] if d1 < d2 else [(1, 2), (3, 0)]
    assert len(bottoms) == 2, 'fewer than 2 bottoms'
    return bottoms

def cal_control_pts(coords):
    """Convert a text polygon into 8 cubic-Bezier control points (straight long sides).

    The two bottom edges are located with ``find_bottom`` and the two long
    sides with ``find_long_edges``. Each long side is replaced by the straight
    segment between its end vertices, expressed as a cubic Bezier whose inner
    control points lie at 1/3 and 2/3 of that segment.

    The end vertices of each long side are used, not just its first segment.
    For a 4-vertex box each side is a single segment, so the result is
    unchanged. For polygons with more vertices the whole side is spanned
    instead of only its first piece.

    Args:
        coords: sequence of (x, y) vertex pairs in boundary order. At least
            4 are required by ``find_bottom``.

    Returns:
        List of 8 ``[x, y]`` pairs: four control points for the first side,
        followed by four for the second. The two sides are listed head to
        tail, so the eight points trace the polygon outline.
    """
    poly = np.array(coords)
    bottoms = find_bottom(poly)
    edge_1, edge_2 = find_long_edges(poly, bottoms)
    id0, id1 = edge_1[0][0], edge_1[-1][1]
    id2, id3 = edge_2[0][0], edge_2[-1][1]
    (x0, y0), (x1, y1), (x2, y2), (x3, y3) = poly[[id1, id0, id3, id2]]

    def lerp(xa, ya, xb, yb, frac):
        # Point a fraction `frac` of the way from (xa, ya) to (xb, yb).
        return [frac * (xb - xa) + xa, frac * (yb - ya) + ya]

    return [
        [x0, y0],
        lerp(x0, y0, x1, y1, 1./3),
        lerp(x0, y0, x1, y1, 2./3),
        [x1, y1],
        [x2, y2],
        lerp(x3, y3, x2, y2, 2./3),
        lerp(x3, y3, x2, y2, 1./3),
        [x3, y3],
    ]

import sys


def _parse_label_file(label_path):
    """Read one annotation file and return ``(boxes, transcriptions)``.

    Each line holds comma-separated x,y coordinates followed by the
    transcription as the last field. Lines whose transcription is exactly
    '#' are skipped, as are lines whose coordinates do not parse as pairs of
    floats. Each kept polygon is converted with ``cal_control_pts`` and stored
    flattened as ``[x0, y0, x1, y1, ...]`` (16 values).
    """
    boxes, transcriptions = [], []
    with open(label_path, 'r') as fin:
        for raw_line in fin:
            fields = raw_line.strip().split(',')
            ct = fields[-1]
            if ct == '#':
                continue
            # Strip a UTF-8 BOM that may precede the first coordinate.
            fields = [item.replace('\ufeff', '') for item in fields]
            values = fields[:-1]
            try:
                coords = [(float(values[ix]), float(values[ix + 1]))
                          for ix in range(0, len(values), 2)]
            except (ValueError, IndexError):
                continue
            boxes.append(np.array(cal_control_pts(coords)).reshape(-1))
            transcriptions.append(ct)
    return boxes, transcriptions


def _fit_side(points):
    """Fit a cubic Bezier to one side given as an (n, 2) array of points.

    Returns ``(x0, x1, x2, x3, y0, y1, y2, y3)`` from ``train``. With a learning
    rate of 0.0 no optimisation runs, so the inner points are ``bezier_fit``'s.
    """
    xs, ys = points[:, 0], points[:, 1]
    return train(xs, ys, bezier_fit(xs, ys), 0.0)


def _format_line(top, bottom, text):
    """Format one output line: 8 control points (x, y interleaved) plus the text."""
    x0, x1, x2, x3, y0, y1, y2, y3 = top
    bx0, bx1, bx2, bx3, by0, by1, by2, by3 = bottom
    # The very first point is written unrounded; everything else to 2 decimals.
    rounded = [round(v, 2) for v in (x1, y1, x2, y2, x3, y3,
                                     bx0, by0, bx1, by1, bx2, by2, bx3, by3)]
    return ','.join('{}'.format(v) for v in [x0, y0] + rounded + [text]) + '\n'


def main(data_dir, out_dir):
    """Convert every ``<data_dir>/*.txt`` annotation file into Bezier format.

    For each input file, an output file with the same base name is written to
    *out_dir* (created if missing). Each output line has the form
    ``x0,y0,...,x3,y3,bx0,by0,...,bx3,by3,text``.
    """
    os.makedirs(out_dir, exist_ok=True)
    labels = sorted(glob.glob(os.path.join(data_dir, '*.txt')))
    for label in tqdm(labels):
        boxes, transcriptions = _parse_label_file(label)
        out_path = os.path.join(out_dir, os.path.basename(label))
        with open(out_path, 'w') as outgt:
            for box, text in zip(boxes, transcriptions):
                half = len(box) // 2
                top = box[:half].reshape(-1, 2)
                bottom = box[half:].reshape(-1, 2)
                outgt.write(_format_line(_fit_side(top), _fit_side(bottom), text))


if __name__ == '__main__':
    if len(sys.argv) != 3:
        sys.exit('usage: {} DATA_DIR OUT_DIR'.format(sys.argv[0]))
    main(sys.argv[1], sys.argv[2])

After you get the Bezier points, you can combine them with the original text annotations to generate COCO-format annotations. You need to add the following extra fields to each annotation:

{
                'area': h*w,
                'bbox': box,
                'category_id': cat_id,
                'id': ann_id,
                'image_id': image_id,
                'iscrowd': 0,
                'segmentation': [poly],
                'text': [text],
                'bezier_pts': [bezier_pts], # bezier points you generated for each text instance
                'rec': [rec] # text label for recognition head
 }

Then configure the data path and run `python tools/train.py --config-file *.yaml`. The model will work well.

If you want to generate annotations for curved text, you can use the script in the README; everything I mentioned above applies only to rotated boxes.

from bezier_curve_text_spotting.

Yuliang-Liu avatar Yuliang-Liu commented on July 19, 2024

All code required for training is included in this repo. Simply using train_net.py.
All the training data are also publicly available.
For now, you will need to figure out "how" yourself. We will release the instructions in the future.

from bezier_curve_text_spotting.

eyebies avatar eyebies commented on July 19, 2024

@Yuliang-Liu Thanks for sharing the code! I just wanted to check whether word_bezier.yaml needs any parameter changes for training. If I want to fine-tune from your model, what parameter changes would you recommend?
thanks!

from bezier_curve_text_spotting.

 avatar commented on July 19, 2024

@Yuliang-Liu further information on how to train from scratch is required.

from bezier_curve_text_spotting.

Yuliang-Liu avatar Yuliang-Liu commented on July 19, 2024

@saicoco Thank you.

We will release our full code in the Adet next week, including the models of CTW1500 and Total-text, the training data we used, evaluation scripts, results of detection, etc. This repo will not be maintained anymore.

Thanks for your attention.

from bezier_curve_text_spotting.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.