php中文网 | cnphp.com

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
查看: 575|回复: 0

基于Python的笑脸识别

[复制链接]

3150

主题

3160

帖子

1万

积分

管理员

Rank: 9Rank: 9Rank: 9

UID
1
威望
0
积分
7976
贡献
0
注册时间
2021-4-14
最后登录
2024-11-24
在线时间
763 小时
QQ
发表于 2022-6-15 22:15:00 | 显示全部楼层 |阅读模式
[mw_shl_code=python,true]from train_model.model import FaceCNN
from data_set.FaceData import FaceDataset
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import cv2

# 验证模型在验证集上的正确率
def validate(model, dataset, batch_size):
    val_loader = data.DataLoader(dataset, batch_size)
    result, num = 0.0, 0
    for images, labels in val_loader:
        pred = model.forward(images)
        pred = np.argmax(pred.data.numpy(), axis=1)
        labels = labels.data.numpy()
        result += np.sum((pred == labels))
        num += len(images)
    acc = result / num
    return acc

def train(train_dataset, val_dataset, batch_size, epochs, learning_rate, wt_decay):
    # 载入数据并分割batch
    train_loader = data.DataLoader(train_dataset, batch_size)
    # 构建模型
    model = FaceCNN()
    # 损失函数
    loss_function = nn.CrossEntropyLoss()
    # 优化器
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=wt_decay)
    # 学习率衰减
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
    # 逐轮训练
    for epoch in range(epochs):
        # 记录损失值
        loss_rate = 0
        # scheduler.step() # 学习率衰减
        model.train()  # 模型训练
        for images, labels in train_loader:
            # 梯度清零
            optimizer.zero_grad()
            # 前向传播
            output = model.forward(images)
            # 误差计算
            loss_rate = loss_function(output, labels)
            # 误差的反向传播
            loss_rate.backward()
            # 更新参数
            optimizer.step()

        # 打印每轮的损失
        print('After {} epochs , the loss_rate is : '.format(epoch + 1), loss_rate.item())
        model.eval()  # 模型评估
        acc_train = validate(model, train_dataset, batch_size)
        acc_val = validate(model, val_dataset, batch_size)
        print('After {} epochs , the acc_train is : '.format(epoch + 1), acc_train)
        print('After {} epochs , the acc_val is : '.format(epoch + 1), acc_val)
        if epoch % 5 == 0:
            # model.eval()  # 模型评估
            # acc_train = validate(model, train_dataset, batch_size)
            # acc_val = validate(model, val_dataset, batch_size)
            # print('After {} epochs , the acc_train is : '.format(epoch + 1), acc_train)
            # print('After {} epochs , the acc_val is : '.format(epoch + 1), acc_val)
            torch.save(model.state_dict(), 'C:\\Users\\bhj\\Desktop\\smile_train\\moudles\\model_net%d.pth'%epoch)
    return model

def main():
    # 数据集实例化(创建数据集)
    data_train_path = "C:\\Users\\bhj\\Desktop\\smile_train\\datasets\\train_data"
    data_val_path = "C:\\Users\\bhj\\Desktop\\smile_train\\datasets\\val_data"
    train_csv = "C:\\Users\\bhj\\Desktop\\smile_train\\train_data.csv"
    val_csv = "C:\\Users\\bhj\\Desktop\\smile_train\\val_data.csv"
    train_dataset = FaceDataset(data_train_path, train_csv)
    val_dataset = FaceDataset(data_val_path, val_csv)
    # 超参数可自行指定
    model = train(train_dataset, val_dataset, batch_size=32, epochs=100, learning_rate=0.01, wt_decay=0)
    # 保存模型
    torch.save(model.state_dict(), 'model_net_result.pth')


if __name__ == '__main__':
    main()[/mw_shl_code]

回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|php中文网 | cnphp.com ( 赣ICP备2021002321号-2 )

GMT+8, 2024-11-24 18:15 , Processed in 0.323877 second(s), 36 queries , Gzip On.

Powered by Discuz! X3.4 Licensed

Copyright © 2001-2020, Tencent Cloud.

申明:本站所有资源皆搜集自网络,相关版权归版权持有人所有,如有侵权,请电邮(fiorkn@foxmail.com)告之,本站会尽快删除。

快速回复 返回顶部 返回列表