Comments (6)
An Example:
OUTPUT_DIR: output/align/07x32
MODEL:
META_ARCHITECTURE: "OneStage"
ONE_STAGE_HEAD: "align"
WEIGHT: "YOUR_MODEL"
FCOS_ON: True
BACKBONE:
CONV_BODY: "R-50"
NECK:
CONV_BODY: "fpn-align"
RESNETS:
BACKBONE_OUT_CHANNELS: 256
RETINANET:
USE_C5: False # FCOS uses P5 instead of C5
ALIGN:
POOLER_RESOLUTION: (7, 32)
POOLER_CANONICAL_SCALE: 160
POOLER_SCALES: (0.25, 0.125, 0.0625)
PREDICTOR: "ctc" # "ctc" or "attention"
FCOS:
CENTER_SAMPLE: True
POS_RADIUS: 1.5
LOC_LOSS_TYPE: "giou"
DATASETS:
TRAIN: ("YOUR_TRAINSET",)
TEST: ("YOUR_TESTSET",)
TEXT:
NUM_CHARS: 25
VOC_SIZE: 97
INPUT:
MIN_SIZE_RANGE_TRAIN: (640, 800)
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
FLIP_PROB_TRAIN: 0.0
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
BASE_LR: 0.01
WEIGHT_DECAY: 0.0001
STEPS: (100000, 180000)
MAX_ITER: 250000
IMS_PER_BATCH: 2
WARMUP_METHOD: "constant"
CHECKPOINT_PERIOD: 2500
TEST:
IMS_PER_BATCH: 1
@eyebies Simply change "ctc" to "attention" if you would like to fine-tune from the provided model.
from bezier_curve_text_spotting.
@deepseek You can use the following script to generate Bezier points for a rotated box. Here I added code that finds the top edge and the bottom edge in order to generate eight points for the rotated box:
# coding=utf-8
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from scipy import interpolate
from scipy.special import comb as n_over_k
import glob, os
import cv2
from skimage import data, color
from skimage.transform import rescale, resize, downscale_local_mean
import matplotlib.pyplot as plt
import math
import numpy as np
import random
# from scipy.optimize import leastsq
import torch
from torch import nn
from torch.nn import functional as F
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn import metrics
from sklearn.metrics import mean_squared_error, r2_score
from shapely.geometry import *
from PIL import Image
import time
from bresenham import bresenham
import re
from tqdm import tqdm
class Bezier(nn.Module):
    """Cubic Bezier curve with fixed end points and two trainable inner control points.

    ``ps`` is indexed as ``ps[i, 0]`` (x) and ``ps[i, 1]`` (y); its first and
    last rows are used as the fixed end points and the rows in between are the
    points the curve is fitted to. ``ctps`` supplies the initial inner control
    points in the order ``[x1, y1, x2, y2]``.
    """

    def __init__(self, ps, ctps):
        super().__init__()
        # Trainable inner control points (registered in the order x1, x2, y1, y2).
        self.x1 = nn.Parameter(torch.as_tensor(ctps[0], dtype=torch.float64))
        self.x2 = nn.Parameter(torch.as_tensor(ctps[2], dtype=torch.float64))
        self.y1 = nn.Parameter(torch.as_tensor(ctps[1], dtype=torch.float64))
        self.y2 = nn.Parameter(torch.as_tensor(ctps[3], dtype=torch.float64))
        # Fixed end points taken straight from the first and last sample.
        first, last = ps[0], ps[-1]
        self.x0, self.y0 = first[0], first[1]
        self.x3, self.y3 = last[0], last[1]
        # Samples between the end points: the targets of the fit.
        self.inner_ps = torch.as_tensor(ps[1:-1, :], dtype=torch.float64)
        # 81 evenly spaced curve parameters in [0, 1].
        self.t = torch.as_tensor(np.linspace(0, 1, 81))

    @staticmethod
    def _lerp(t, a, b):
        """Linear interpolation between a and b at parameter t."""
        return (1 - t) * a + t * b

    @classmethod
    def _cubic(cls, t, p0, p1, p2, p3):
        """Evaluate one coordinate of the cubic curve with de Casteljau's scheme."""
        left = cls._lerp(t, cls._lerp(t, p0, p1), cls._lerp(t, p1, p2))
        right = cls._lerp(t, cls._lerp(t, p1, p2), cls._lerp(t, p2, p3))
        return cls._lerp(t, left, right)

    def forward(self):
        """Return the summed distance from each inner sample to its nearest curve sample."""
        x0, x1, x2, x3, y0, y1, y2, y3 = self.control_points()
        curve = torch.stack(
            (self._cubic(self.t, x0, x1, x2, x3),
             self._cubic(self.t, y0, y1, y2, y3)),
            dim=1,
        )
        # Pairwise offsets: axis 0 indexes inner samples, axis 1 curve samples.
        offsets = curve.unsqueeze(0) - self.inner_ps.unsqueeze(1)
        distances = (offsets ** 2).sum(dim=2).sqrt()
        nearest = distances.min(dim=1)[0]
        return nearest.sum()

    def control_points(self):
        """Return (x0, x1, x2, x3, y0, y1, y2, y3); inner points are Parameters."""
        return self.x0, self.x1, self.x2, self.x3, self.y0, self.y1, self.y2, self.y3

    def control_points_f(self):
        """Return the same tuple as control_points() with inner points as Python floats."""
        return (
            self.x0, self.x1.item(), self.x2.item(), self.x3,
            self.y0, self.y1.item(), self.y2.item(), self.y3,
        )
def train(x, y, ctps, lr):
    """Fit the inner control points of a cubic Bezier curve to the points (x, y).

    Args:
        x, y: coordinates of the sample points; the first and last samples
            become the fixed end points of the curve.
        ctps: initial inner control points ``[x1, y1, x2, y2]``.
        lr: SGD learning rate. ``0.0`` skips optimisation entirely.

    Returns:
        Tuple ``(x0, x1, x2, x3, y0, y1, y2, y3)`` of control point coordinates.
        If the loss becomes NaN, the control points from before optimisation
        are returned.
    """
    samples = np.vstack((np.array(x), np.array(y))).transpose()
    bezier = Bezier(samples, ctps)
    optimizer = torch.optim.SGD(bezier.parameters(), lr=lr)

    # Snapshot taken before any update; used as the fallback on divergence.
    fallback_pts = bezier.control_points_f()

    if lr == 0.0:
        return bezier.control_points_f()

    for step in range(1000):
        loss = bezier()
        if torch.isnan(loss):
            return fallback_pts
        # Halve the learning rate at steps 400 and 800.
        if step in (400, 800):
            optimizer.param_groups[0]['lr'] *= 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return bezier.control_points_f()
def draw(ps, control_points, t):
    """Plot sample points together with a cubic Bezier curve and wait for Enter.

    Args:
        ps: array indexed as ``ps[:, 0]`` (x) and ``ps[:, 1]`` (y).
        control_points: ``(x0, x1, x2, x3, y0, y1, y2, y3)``.
        t: curve parameter value(s) at which the curve is evaluated.
    """
    x = ps[:, 0]
    y = ps[:, 1]
    x0, x1, x2, x3, y0, y1, y2, y3 = control_points
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(x, y, color='m', linestyle='', marker='.')
    # de Casteljau evaluation of the cubic curve, one coordinate at a time.
    bezier_x = (1-t)*((1-t)*((1-t)*x0+t*x1)+t*((1-t)*x1+t*x2))+t*((1-t)*((1-t)*x1+t*x2)+t*((1-t)*x2+t*x3))
    bezier_y = (1-t)*((1-t)*((1-t)*y0+t*y1)+t*((1-t)*y1+t*y2))+t*((1-t)*((1-t)*y1+t*y2)+t*((1-t)*y2+t*y3))
    ax.plot(bezier_x, bezier_y, 'g-')
    plt.draw()
    plt.pause(1)  # give the GUI event loop time to render the figure
    # raw_input() only exists in Python 2; the rest of this file is Python 3.
    input("<Hit Enter To Close>")
    plt.close(fig)
def Mtk(n, t, k):
    """Return the k-th Bernstein basis polynomial of degree n evaluated at t."""
    return t**k * (1-t)**(n-k) * n_over_k(n,k)


def BezierCoeff(ts):
    """Return one row of the four cubic Bernstein weights for every t in ts."""
    rows = []
    for t in ts:
        rows.append([Mtk(3, t, k) for k in range(4)])
    return rows
def bezier_fit(x, y):
    """Least-squares fit of a cubic Bezier curve to the points (x, y).

    The curve parameter of each sample is its normalised cumulative chord
    length. Returns the two inner control points flattened as
    ``[x1, y1, x2, y2]``.
    """
    step_x = x[1:] - x[:-1]
    step_y = y[1:] - y[:-1]
    chord = (step_x ** 2 + step_y ** 2)**0.5
    # Cumulative fraction of total chord length, starting at 0.
    t = np.hstack(([0], chord/chord.sum())).cumsum()

    samples = np.column_stack((x, y))
    # Pseudo-inverse of the (n_samples, 4) Bernstein matrix -> (4, n_samples).
    pseudo_inverse = np.linalg.pinv(BezierCoeff(t))
    fitted = pseudo_inverse.dot(samples)  # (4, 2): all four control points
    return fitted[1:-1, :].flatten().tolist()
def bezier_fitv2(x, y):
    """Return inner control points placed at 1/3 and 2/3 of the end-to-end chord.

    Only the first and last samples are used. The result is
    ``[x1, y1, x2, y2]``.
    """
    x_first, x_last = x[0], x[-1]
    y_first, y_last = y[0], y[-1]
    return [
        (2*x_first + x_last)/3.0,
        (2*y_first + y_last)/3.0,
        (x_first + 2*x_last)/3.0,
        (y_first + 2*y_last)/3.0,
    ]
def is_close_to_line(xs, ys, thres):
    """Score how far (xs, ys) deviates from a fitted straight line.

    A linear regression of ys on xs is fitted. The mean squared error between
    the *squared* targets and the *squared* predictions is then divided by the
    square of the largest difference between those squared values.

    Returns:
        0.0 if the normalised error exceeds ``thres``, otherwise 2.0.
    """
    x_col = xs.reshape(-1, 1)
    y_col = ys.reshape(-1, 1)

    model = LinearRegression()
    model.fit(x_col, y_col)
    predicted = model.predict(x_col)

    y_sq = y_col**2
    predicted_sq = predicted**2
    error = mean_squared_error(y_sq, predicted_sq)
    error = error/(y_sq - predicted_sq).max()**2
    return 0.0 if error > thres else 2.0
def is_close_to_linev2(xs, ys, size, thres = 0.05):
    """Decide whether the polyline (xs, ys) is close to a straight line.

    The slope of every consecutive segment is compared with the slope of the
    chord joining the first and last point. The summed absolute slope
    differences are scaled by the chord length and divided by
    ``int(size ** 0.5)``.

    Args:
        xs, ys: point coordinates (any sequence or array of numbers).
        size: normalisation area; its integer square root is the divisor.
        thres: score below which the polyline counts as straight.

    Returns:
        0.0 if the score is below ``thres``, otherwise 3.0. The calling script
        stores this value in a variable named ``learning_rate``.

    Note:
        A perfectly vertical polyline produces ``inf - inf = NaN`` slope
        differences. NaN never compares below the threshold, so 3.0 is
        returned in that case.
    """
    # Explicit conversion: the arithmetic below must not rely on the inputs
    # already being NumPy arrays.
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    nor_pixel = int(size**0.5)

    def _slope(dx, dy):
        # A vertical step has no finite slope; use a signed infinity instead
        # of dividing by zero.
        return dy / dx if dx != 0.0 else math.inf * np.sign(dy)

    slopes = np.array(
        [_slope(dx, dy) for dx, dy in zip(np.diff(xs), np.diff(ys))],
        dtype=float,
    )
    chord_dx = xs[-1] - xs[0]
    chord_dy = ys[-1] - ys[0]
    st_slope = _slope(chord_dx, chord_dy)
    max_dis = (chord_dy**2 + chord_dx**2)**(0.5)

    # inf - inf is expected for vertical runs; keep the NaN but silence the warning.
    with np.errstate(invalid='ignore'):
        diffs = np.abs(slopes - st_slope)
    score = diffs.sum() * max_dis/nor_pixel
    return 0.0 if score < thres else 3.0
def find_long_edges(points, bottoms):
    """Collect the edges that connect the two bottom edges of a polygon.

    Args:
        points: polygon vertices; only ``len(points)`` is used.
        bottoms: two ``(start, end)`` vertex-index pairs.

    Returns:
        Two lists of ``(start, end)`` index pairs: the edges walked from the
        end of the first bottom up to the end of the second, and the edges
        walked from the end of the second bottom up to the end of the first.
    """
    _, b1_end = bottoms[0]
    _, b2_end = bottoms[1]
    n_pts = len(points)

    def _walk(after, until):
        # Step around the polygon, recording each edge (previous, current)
        # until the current vertex is `until`.
        edges = []
        idx = (after + 1) % n_pts
        while idx % n_pts != until:
            edges.append(((idx - 1) % n_pts, idx % n_pts))
            idx = (idx + 1) % n_pts
        return edges

    long_edge_1 = _walk(b1_end, b2_end)
    long_edge_2 = _walk(b2_end, b1_end)
    return long_edge_1, long_edge_2
def norm2(x, axis=None):
    """Return the Euclidean (L2) norm of ``x``.

    Args:
        x: array-like of numbers.
        axis: axis along which to take the norm; ``None`` (default) reduces
            over all elements.

    Returns:
        A scalar when ``axis`` is None, otherwise an array of norms.
    """
    # Pass `axis` straight through: a truthiness test would silently treat
    # axis=0 as "no axis" and return the norm of the whole array.
    return np.sqrt(np.sum(np.asarray(x) ** 2, axis=axis))
def cos(p1, p2):
    """Return the cosine of the angle between the vectors p1 and p2."""
    dot_product = (p1 * p2).sum()
    magnitudes = norm2(p1) * norm2(p2)
    return dot_product / magnitudes
def find_bottom(pts):
    """Pick the two "bottom" (short, end-cap) edges of a polygon.

    Args:
        pts: array of polygon vertices, one point per row.

    Returns:
        A list of two ``(start, end)`` vertex-index pairs.

    Raises:
        AssertionError: if the selection does not contain exactly two edges.
    """
    if len(pts) > 4:
        # Append the first three vertices so e[i - 1] .. e[i + 2] can be read
        # without wrapping indices by hand.
        e = np.concatenate([pts, pts[:3]])
        candidate = []
        for i in range(1, len(pts) + 1):
            # Directions of the edges just before and just after edge (i, i + 1).
            v_prev = e[i] - e[i - 1]
            v_next = e[i + 2] - e[i + 1]
            # Neighbouring edges pointing in nearly opposite directions mark
            # edge (i, i + 1) as a candidate; its length is stored alongside.
            if cos(v_prev, v_next) < -0.7:
                candidate.append((i % len(pts), (i + 1) % len(pts), norm2(e[i] - e[i + 1])))
        if len(candidate) != 2 or candidate[0][0] == candidate[1][1] or candidate[0][1] == candidate[1][0]:
            # Fallback when there are not exactly two candidates, or the two
            # candidates share a vertex: choose edges by midpoint distance.
            mid_list = []
            for i in range(len(pts)):
                mid_point = (e[i] + e[(i + 1) % len(pts)]) / 2
                mid_list.append((i, (i + 1) % len(pts), mid_point))
            # Distance between the midpoints of every ordered pair of edges.
            dist_list = []
            for i in range(len(pts)):
                for j in range(len(pts)):
                    s1, e1, mid1 = mid_list[i]
                    s2, e2, mid2 = mid_list[j]
                    dist = norm2(mid1 - mid2)
                    dist_list.append((s1, e1, s2, e2, dist))
            # Take the two entries with the largest distance and keep the
            # first edge of each.
            # NOTE(review): dist_list holds both (i, j) and (j, i), so the top
            # two entries are presumably the same pair in both orders --
            # confirm how ties are resolved by argsort.
            bottom_idx = np.argsort([dist for s1, e1, s2, e2, dist in dist_list])[-2:]
            bottoms = [dist_list[bottom_idx[0]][:2], dist_list[bottom_idx[1]][:2]]
        else:
            bottoms = [candidate[0][:2], candidate[1][:2]]
    else:
        # Reads pts[0] .. pts[3], i.e. assumes a quadrilateral: compare the
        # summed lengths of the two pairs of opposite edges and pick the
        # shorter pair.
        d1 = norm2(pts[1] - pts[0]) + norm2(pts[2] - pts[3])
        d2 = norm2(pts[2] - pts[1]) + norm2(pts[0] - pts[3])
        bottoms = [(0, 1), (2, 3)] if d1 < d2 else [(1, 2), (3, 0)]
    assert len(bottoms) == 2, 'fewer than 2 bottoms'
    return bottoms
def cal_control_pts(coords):
    """Turn a polygon into eight points: four along each of its two long edges.

    The first edge of each long side (as returned by ``find_long_edges``) is
    split into thirds, giving its two end points plus two evenly spaced
    interior points.

    Returns:
        A list of eight ``[x, y]`` pairs: four points along the first long
        edge followed by four points along the second.
    """
    polygon = np.array(coords)
    bottoms = find_bottom(polygon)
    side_a, side_b = find_long_edges(polygon, bottoms)
    a_start, a_end = side_a[0]
    b_start, b_end = side_b[0]
    corners = polygon[[a_end, a_start, b_end, b_start]]
    x0, y0 = corners[0]
    x1, y1 = corners[1]
    x2, y2 = corners[2]
    x3, y3 = corners[3]

    def _between(frac, start, end):
        # Coordinate that lies `frac` of the way from start to end.
        return frac * (end - start) + start

    one_third, two_thirds = 1./3, 2./3
    return [
        [x0, y0],
        [_between(one_third, x0, x1), _between(one_third, y0, y1)],
        [_between(two_thirds, x0, x1), _between(two_thirds, y0, y1)],
        [x1, y1],
        [x2, y2],
        # The second edge is interpolated from corner 3 towards corner 2, so
        # the 2/3 point comes first when listing from corner 2 to corner 3.
        [_between(two_thirds, x3, x2), _between(two_thirds, y3, y2)],
        [_between(one_third, x3, x2), _between(one_third, y3, y2)],
        [x3, y3],
    ]
import sys

# Usage: python <this_script> <data_dir> <out_dir>
#   data_dir: directory with one "<name>.txt" label file per image; each line
#             is "x1,y1,x2,y2,...,text".
#   out_dir:  directory receiving one label file per input, each line holding
#             16 Bezier control-point coordinates followed by the text.
data_dir = sys.argv[1]
out_dir = sys.argv[2]

# Set to True to overlay the fitted curves on "<name>.jpg" and save the figure
# under vis_results/. Off by default: drawing on the global pyplot figure for
# every label without ever clearing it grows memory without bound.
VISUALIZE = False

os.makedirs(out_dir, exist_ok=True)

labels = glob.glob('{}/*.txt'.format(data_dir))
labels.sort()
for label in tqdm(labels):
    imgdir = label.replace('.txt', '.jpg')
    data = []
    cts = []

    with open(label, 'r') as fin:
        lines = fin.readlines()

    for line in lines:
        fields = line.strip().split(',')
        ct = fields[-1]
        if ct == '#':
            continue
        # Strip a byte-order mark that may prefix the first field of the file.
        fields = [item.replace('\ufeff', '') for item in fields]
        values = fields[:-1]
        try:
            coords = [(float(values[ix]), float(values[ix + 1]))
                      for ix in range(0, len(values), 2)]
        except (ValueError, IndexError):
            # Non-numeric coordinate or an odd number of values: skip the line.
            continue
        if len(coords) < 4:
            # Blank or truncated line: not enough vertices to form a box.
            continue
        coords = cal_control_pts(coords)
        data.append(np.array(coords).reshape((-1)))
        cts.append(ct)

    # os.path.basename handles whatever separator glob returned on this platform.
    with open(os.path.join(out_dir, os.path.basename(label)), 'w') as outgt:
        for ddata, ct in zip(data, cts):
            lh = len(ddata)
            assert lh % 4 == 0
            lhc2 = lh // 2
            lhc4 = lh // 4
            # First half of the points is the top curve, second half the bottom.
            curve_data_top = ddata[0:lhc2].reshape(lhc4, 2)
            curve_data_bottom = ddata[lhc2:].reshape(lhc4, 2)

            x_data = curve_data_top[:, 0]
            y_data = curve_data_top[:, 1]
            init_control_points = bezier_fit(x_data, y_data)
            # lr=0.0 makes train() skip optimisation and return the end points
            # together with the least-squares inner control points.
            x0, x1, x2, x3, y0, y1, y2, y3 = train(x_data, y_data, init_control_points, 0.0)

            x_data_b = curve_data_bottom[:, 0]
            y_data_b = curve_data_bottom[:, 1]
            init_control_points_b = bezier_fit(x_data_b, y_data_b)
            x0_b, x1_b, x2_b, x3_b, y0_b, y1_b, y2_b, y3_b = train(x_data_b, y_data_b, init_control_points_b, 0.0)

            if VISUALIZE:
                control_points = np.array([[x0, y0], [x1, y1], [x2, y2], [x3, y3]])
                control_points_b = np.array([[x0_b, y0_b], [x1_b, y1_b], [x2_b, y2_b], [x3_b, y3_b]])
                basis = np.array(BezierCoeff(np.linspace(0, 1, 81)))
                Bezier_top = basis.dot(control_points)
                Bezier_bottom = basis.dot(control_points_b)
                left_vertex_x = [curve_data_top[0, 0], curve_data_bottom[lhc4 - 1, 0]]
                left_vertex_y = [curve_data_top[0, 1], curve_data_bottom[lhc4 - 1, 1]]
                right_vertex_x = [curve_data_top[lhc4 - 1, 0], curve_data_bottom[0, 0]]
                right_vertex_y = [curve_data_top[lhc4 - 1, 1], curve_data_bottom[0, 1]]
                plt.plot(Bezier_top[:, 0], Bezier_top[:, 1], 'g-', label='fit', linewidth=1.0)
                plt.plot(Bezier_bottom[:, 0], Bezier_bottom[:, 1], 'g-', label='fit', linewidth=1.0)
                plt.plot(control_points[:, 0], control_points[:, 1], 'r.:', fillstyle='none', linewidth=1.0)
                plt.plot(control_points_b[:, 0], control_points_b[:, 1], 'r.:', fillstyle='none', linewidth=1.0)
                plt.plot(left_vertex_x, left_vertex_y, 'g-', linewidth=1.0)
                plt.plot(right_vertex_x, right_vertex_y, 'g-', linewidth=1.0)

            # x0, y0 are written unrounded, exactly as in the original output format.
            outstr = '{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}\n'.format(
                x0, y0,
                round(x1, 2), round(y1, 2),
                round(x2, 2), round(y2, 2),
                round(x3, 2), round(y3, 2),
                round(x0_b, 2), round(y0_b, 2),
                round(x1_b, 2), round(y1_b, 2),
                round(x2_b, 2), round(y2_b, 2),
                round(x3_b, 2), round(y3_b, 2),
                ct)
            outgt.write(outstr)

    if VISUALIZE:
        plt.imshow(plt.imread(imgdir))
        plt.axis('off')
        os.makedirs('vis_results', exist_ok=True)
        plt.savefig('vis_results/' + os.path.basename(imgdir), bbox_inches='tight', dpi=400)
        # Clear the figure so artists do not pile up across images.
        plt.clf()
After you get the Bezier points, you can combine them with the original text annotations to generate COCO-format annotations. You need to add some extra fields to each annotation:
{
'area': h*w,
'bbox': box,
'category_id': cat_id,
'id': ann_id,
'image_id': image_id,
'iscrowd': 0,
'segmentation': [poly],
'text': [text],
'bezier_pts': [bezier_pts], # bezier points you generated for each text instance
'rec': [rec] # text label for recognition head
}
Then you just configure the data path and run `python tools/train.py --config-file *.yaml`. The model will work well.
If you want to generate annotations for curved text, you can use the script in the README; everything I mentioned above applies only to rotated boxes.
from bezier_curve_text_spotting.
All code required for training is included in this repo. Simply use train_net.py.
All the training data are also publicly available.
You will need to figure out "how" at this moment. We will release the instruction in the future.
from bezier_curve_text_spotting.
@Yuliang-Liu Thanks for sharing the code! just wanted to check if the word_bezier.yaml needs any parameter changes for training? If I want to fine-tune from your model, what parameter changes would you recommend?
thanks!
from bezier_curve_text_spotting.
@Yuliang-Liu further information on how to train from scratch is required.
from bezier_curve_text_spotting.
@saicoco Thank you.
We will release our full code in the Adet next week, including the models of CTW1500 and Total-text, the training data we used, evaluation scripts, results of detection, etc. This repo will not be maintained anymore.
Thanks for your attention.
from bezier_curve_text_spotting.
Related Issues (20)
- Public total_train.json HOT 1
- When I run vis_bezier.py, The following error appeared. HOT 5
- 环境需求 HOT 3
- 框架及中文模型训练疑问 HOT 2
- 关于公式3
- BezierAlign can not use HOT 4
- How to get the string of text (label) inside datasets in ABC-Net?
- how to use the lexicon to obtain the results of text spotting? Thank you.
- 关于paper中figure5 (a)方法 HOT 2
- 作者你好,可以给出环境需求吗,最重要的是torch版本
- Can I fine-tune the model and trained it in different language
- 替换backbone为R18后文字识别全为0
- train eror no AttributeError DARTS_ON in confige file
- 请问为何对于ABCNet中的FCOS检测头不使用center-ness呢?
- 朋友,邮箱分享一下,有问题要咨询
- 此处的图片为什么取这么大呢?
- icdar2015 HOT 1
- Trained Model for evaluation HOT 2
- results ? HOT 1
- Can bezier model recognize Chinese texts end2end? HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from bezier_curve_text_spotting.