face_swap.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. #! /usr/bin/env python
  2. import cv2
  3. import numpy as np
  4. import scipy.spatial as spatial
  5. import logging
  6. ## 3D Transform
  7. def bilinear_interpolate(img, coords):
  8. """ Interpolates over every image channel
  9. http://en.wikipedia.org/wiki/Bilinear_interpolation
  10. :param img: max 3 channel image
  11. :param coords: 2 x _m_ array. 1st row = xcoords, 2nd row = ycoords
  12. :returns: array of interpolated pixels with same shape as coords
  13. """
  14. int_coords = np.int32(coords)
  15. x0, y0 = int_coords
  16. dx, dy = coords - int_coords
  17. # 4 Neighour pixels
  18. q11 = img[y0, x0]
  19. q21 = img[y0, x0 + 1]
  20. q12 = img[y0 + 1, x0]
  21. q22 = img[y0 + 1, x0 + 1]
  22. btm = q21.T * dx + q11.T * (1 - dx)
  23. top = q22.T * dx + q12.T * (1 - dx)
  24. inter_pixel = top * dy + btm * (1 - dy)
  25. return inter_pixel.T
  26. def grid_coordinates(points):
  27. """ x,y grid coordinates within the ROI of supplied points
  28. :param points: points to generate grid coordinates
  29. :returns: array of (x, y) coordinates
  30. """
  31. xmin = np.min(points[:, 0])
  32. xmax = np.max(points[:, 0]) + 1
  33. ymin = np.min(points[:, 1])
  34. ymax = np.max(points[:, 1]) + 1
  35. return np.asarray([(x, y) for y in range(ymin, ymax)
  36. for x in range(xmin, xmax)], np.uint32)
  37. def process_warp(src_img, result_img, tri_affines, dst_points, delaunay):
  38. """
  39. Warp each triangle from the src_image only within the
  40. ROI of the destination image (points in dst_points).
  41. """
  42. roi_coords = grid_coordinates(dst_points)
  43. # indices to vertices. -1 if pixel is not in any triangle
  44. roi_tri_indices = delaunay.find_simplex(roi_coords)
  45. for simplex_index in range(len(delaunay.simplices)):
  46. coords = roi_coords[roi_tri_indices == simplex_index]
  47. num_coords = len(coords)
  48. out_coords = np.dot(tri_affines[simplex_index],
  49. np.vstack((coords.T, np.ones(num_coords))))
  50. x, y = coords.T
  51. result_img[y, x] = bilinear_interpolate(src_img, out_coords)
  52. return None
  53. def triangular_affine_matrices(vertices, src_points, dst_points):
  54. """
  55. Calculate the affine transformation matrix for each
  56. triangle (x,y) vertex from dst_points to src_points
  57. :param vertices: array of triplet indices to corners of triangle
  58. :param src_points: array of [x, y] points to landmarks for source image
  59. :param dst_points: array of [x, y] points to landmarks for destination image
  60. :returns: 2 x 3 affine matrix transformation for a triangle
  61. """
  62. ones = [1, 1, 1]
  63. for tri_indices in vertices:
  64. src_tri = np.vstack((src_points[tri_indices, :].T, ones))
  65. dst_tri = np.vstack((dst_points[tri_indices, :].T, ones))
  66. mat = np.dot(src_tri, np.linalg.inv(dst_tri))[:2, :]
  67. yield mat
  68. def warp_image_3d(src_img, src_points, dst_points, dst_shape, dtype=np.uint8):
  69. rows, cols = dst_shape[:2]
  70. result_img = np.zeros((rows, cols, 3), dtype=dtype)
  71. delaunay = spatial.Delaunay(dst_points)
  72. tri_affines = np.asarray(list(triangular_affine_matrices(
  73. delaunay.simplices, src_points, dst_points)))
  74. process_warp(src_img, result_img, tri_affines, dst_points, delaunay)
  75. return result_img
  76. ## 2D Transform
  77. def transformation_from_points(points1, points2):
  78. points1 = points1.astype(np.float64)
  79. points2 = points2.astype(np.float64)
  80. c1 = np.mean(points1, axis=0)
  81. c2 = np.mean(points2, axis=0)
  82. points1 -= c1
  83. points2 -= c2
  84. s1 = np.std(points1)
  85. s2 = np.std(points2)
  86. points1 /= s1
  87. points2 /= s2
  88. U, S, Vt = np.linalg.svd(np.dot(points1.T, points2))
  89. R = (np.dot(U, Vt)).T
  90. return np.vstack([np.hstack([s2 / s1 * R,
  91. (c2.T - np.dot(s2 / s1 * R, c1.T))[:, np.newaxis]]),
  92. np.array([[0., 0., 1.]])])
  93. def warp_image_2d(im, M, dshape):
  94. output_im = np.zeros(dshape, dtype=im.dtype)
  95. cv2.warpAffine(im,
  96. M[:2],
  97. (dshape[1], dshape[0]),
  98. dst=output_im,
  99. borderMode=cv2.BORDER_TRANSPARENT,
  100. flags=cv2.WARP_INVERSE_MAP)
  101. return output_im
  102. ## Generate Mask
  103. def mask_from_points(size, points,erode_flag=1):
  104. radius = 10 # kernel size
  105. kernel = np.ones((radius, radius), np.uint8)
  106. mask = np.zeros(size, np.uint8)
  107. cv2.fillConvexPoly(mask, cv2.convexHull(points), 255)
  108. if erode_flag:
  109. mask = cv2.erode(mask, kernel,iterations=1)
  110. return mask
  111. ## Color Correction
  112. def correct_colours(im1, im2, landmarks1):
  113. COLOUR_CORRECT_BLUR_FRAC = 0.75
  114. LEFT_EYE_POINTS = list(range(42, 48))
  115. RIGHT_EYE_POINTS = list(range(36, 42))
  116. blur_amount = COLOUR_CORRECT_BLUR_FRAC * np.linalg.norm(
  117. np.mean(landmarks1[LEFT_EYE_POINTS], axis=0) -
  118. np.mean(landmarks1[RIGHT_EYE_POINTS], axis=0))
  119. blur_amount = int(blur_amount)
  120. if blur_amount % 2 == 0:
  121. blur_amount += 1
  122. im1_blur = cv2.GaussianBlur(im1, (blur_amount, blur_amount), 0)
  123. im2_blur = cv2.GaussianBlur(im2, (blur_amount, blur_amount), 0)
  124. # Avoid divide-by-zero errors.
  125. im2_blur = im2_blur.astype(int)
  126. im2_blur += 128*(im2_blur <= 1)
  127. result = im2.astype(np.float64) * im1_blur.astype(np.float64) / im2_blur.astype(np.float64)
  128. result = np.clip(result, 0, 255).astype(np.uint8)
  129. return result
  130. ## Copy-and-paste
  131. def apply_mask(img, mask):
  132. """ Apply mask to supplied image
  133. :param img: max 3 channel image
  134. :param mask: [0-255] values in mask
  135. :returns: new image with mask applied
  136. """
  137. masked_img=cv2.bitwise_and(img,img,mask=mask)
  138. return masked_img
  139. ## Alpha blending
  140. def alpha_feathering(src_img, dest_img, img_mask, blur_radius=15):
  141. mask = cv2.blur(img_mask, (blur_radius, blur_radius))
  142. mask = mask / 255.0
  143. result_img = np.empty(src_img.shape, np.uint8)
  144. for i in range(3):
  145. result_img[..., i] = src_img[..., i] * mask + dest_img[..., i] * (1-mask)
  146. return result_img
  147. def check_points(img,points):
  148. # Todo: I just consider one situation.
  149. if points[8,1]>img.shape[0]:
  150. logging.error("Jaw part out of image")
  151. else:
  152. return True
  153. return False
  154. def face_swap(src_face, dst_face, src_points, dst_points, dst_shape, dst_img, args, end=48):
  155. h, w = dst_face.shape[:2]
  156. ## 3d warp
  157. warped_src_face = warp_image_3d(src_face, src_points[:end], dst_points[:end], (h, w))
  158. ## Mask for blending
  159. mask = mask_from_points((h, w), dst_points)
  160. mask_src = np.mean(warped_src_face, axis=2) > 0
  161. mask = np.asarray(mask * mask_src, dtype=np.uint8)
  162. ## Correct color
  163. if args.correct_color:
  164. warped_src_face = apply_mask(warped_src_face, mask)
  165. dst_face_masked = apply_mask(dst_face, mask)
  166. warped_src_face = correct_colours(dst_face_masked, warped_src_face, dst_points)
  167. ## 2d warp
  168. if args.warp_2d:
  169. unwarped_src_face = warp_image_3d(warped_src_face, dst_points[:end], src_points[:end], src_face.shape[:2])
  170. warped_src_face = warp_image_2d(unwarped_src_face, transformation_from_points(dst_points, src_points),
  171. (h, w, 3))
  172. mask = mask_from_points((h, w), dst_points)
  173. mask_src = np.mean(warped_src_face, axis=2) > 0
  174. mask = np.asarray(mask * mask_src, dtype=np.uint8)
  175. ## Shrink the mask
  176. kernel = np.ones((10, 10), np.uint8)
  177. mask = cv2.erode(mask, kernel, iterations=1)
  178. ##Poisson Blending
  179. r = cv2.boundingRect(mask)
  180. center = ((r[0] + int(r[2] / 2), r[1] + int(r[3] / 2)))
  181. output = cv2.seamlessClone(warped_src_face, dst_face, mask, center, cv2.NORMAL_CLONE)
  182. x, y, w, h = dst_shape
  183. dst_img_cp = dst_img.copy()
  184. dst_img_cp[y:y + h, x:x + w] = output
  185. return dst_img_cp