准备
安装facenet-pytorch及其相关包
pip install facenet-pytorch
下载预训练模型(提取码:74s4),
将下载得到的两个 vggface2_*.pt 文件放到 ~/.cache/torch/checkpoints/ 目录下。
代码
from facenet_pytorch import MTCNN, InceptionResnetV1
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as np
import PIL
import os
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# MTCNN face detector/aligner: returns a 160x160 face crop (no extra margin),
# ignoring faces smaller than 20 px.
mtcnn = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
    device=device,
)

# Face-embedding network with VGGFace2 pretrained weights, put in eval
# (inference) mode and moved to the selected device.
resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
def create_face_db(db_name, db_path, model, align_model):
    """Build a face-embedding database from a folder of labelled images.

    Every image under ``db_path`` is passed through ``align_model``. Images
    in which a face is found are embedded with ``model``. The embeddings are
    saved, together with the person's name (the sub-folder name), to an
    ``.npz`` file.

    Args:
        db_name (str): Output database file name (written with ``np.savez``,
            which appends ``.npz`` if it is missing).
        db_path (str): Image root folder. It must contain one sub-folder per
            person, named after that person, holding that person's images.
        model (torch.nn.Module): Embedding network, e.g. ``InceptionResnetV1``.
        align_model: Face detector/aligner called as
            ``align_model(img, return_prob=True)`` and returning
            ``(face_tensor_or_None, probability)``, e.g. ``MTCNN``.

    Raises:
        ValueError: If no face is detected in any image under ``db_path``.
    """
    def collate_fn(batch):
        # The DataLoader uses its default batch size of 1. Unwrap the single
        # (PIL image, label) pair instead of letting the default collate
        # function try to batch PIL images.
        return batch[0]

    dataset = datasets.ImageFolder(db_path)
    idx_to_class = {i: c for c, i in dataset.class_to_idx.items()}
    loader = DataLoader(dataset, collate_fn=collate_fn)

    aligned = []
    names = []
    for img, label in loader:
        x_aligned, prob = align_model(img, return_prob=True)
        if x_aligned is not None:
            print('Face detected with probability: {:8f}'.format(prob))
            aligned.append(x_aligned)
            names.append(idx_to_class[label])

    # torch.stack([]) would fail with an opaque RuntimeError; say what is wrong.
    if not aligned:
        raise ValueError(
            'No face detected in any image under {!r}'.format(db_path))

    # Embed on whatever device the model's weights live on; no autograd
    # graph is needed for inference.
    device = next(model.parameters()).device
    with torch.no_grad():
        embeddings = model(torch.stack(aligned).to(device)).cpu().numpy()
    np.savez(db_name, embed=embeddings, names=np.array(names))
def who(img_name, db_name, model, align_model):
    """Identify the person in an image by nearest-neighbour search in a face DB.

    Args:
        img_name (str): Path of the query image.
        db_name (str): Path of the ``.npz`` database written by
            ``create_face_db`` (arrays ``embed`` and ``names``).
        model (torch.nn.Module): The same embedding network used to build
            the database.
        align_model: Face detector/aligner called as
            ``align_model(img, return_prob=True)`` and returning
            ``(face_tensor_or_None, probability)``.

    Returns:
        tuple: ``(name, dist)``. ``name`` is the name stored for the closest
        database embedding. ``dist`` is a one-element list holding the
        Euclidean distance between that embedding and the query embedding.

    Raises:
        ValueError: If no face is detected in the query image.
    """
    # `import PIL` alone does not guarantee the Image submodule is loaded.
    from PIL import Image

    with open(img_name, 'rb') as fp:
        # convert() makes PIL read the pixels while the file is still open.
        img = Image.open(fp).convert('RGB')

    x_aligned, prob = align_model(img, return_prob=True)
    if x_aligned is None:
        raise ValueError('No face detected in {!r}'.format(img_name))
    print('Face detected with probability: {:8f}'.format(prob))

    device = next(model.parameters()).device
    with torch.no_grad():
        query = model(x_aligned.unsqueeze(0).to(device)).cpu().numpy()

    # Load both arrays and close the archive deterministically.
    with np.load(db_name) as db:
        db_embed = db['embed']
        db_names = db['names']

    # dists[i, j] = ||db_embed[i] - query[j]||. Here there is one query row.
    dists = np.linalg.norm(
        db_embed[:, np.newaxis, :] - query[np.newaxis, :, :], axis=-1)
    index, _ = np.unravel_index(np.argmin(dists), dists.shape)
    return db_names[index], list(dists[index])
# Build the face database. This only needs to be re-run when new images are
# added to the image folder.
create_face_db('db.npz', 'test_images', resnet, mtcnn)

# Look up which person from the database appears in the query image.
name, dist = who('WP_000098.jpg', 'db.npz', resnet, mtcnn)
print("it's %s, dist:%s" % (name, dist))