
Comments (3)

juneweng commented on May 22, 2024

In addition, is Test_acc the mean of the 10 folds' best_Test_acc?

from facial-expression-recognition.pytorch.

WuJie1010 commented on May 22, 2024

Maybe you can try again, or retrain the folds with low accuracy separately.
Yes! Test_acc is the mean of the 10 folds' best_Test_acc.
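Put differently, each fold contributes the highest Test_acc it reached over its epochs, and the reported number is the average of those values. A minimal sketch, assuming the per-epoch test accuracies of every fold have been collected into lists (the helper names `best_per_fold` and `cross_validated_test_acc` are hypothetical, not from the repository):

```python
from statistics import mean


def best_per_fold(test_acc_history):
    """test_acc_history[k] holds fold k's Test_acc after every epoch;
    keep only the best value reached in each fold (its best_Test_acc)."""
    return [max(history) for history in test_acc_history]


def cross_validated_test_acc(test_acc_history):
    """Reported Test_acc: mean over folds of each fold's best_Test_acc."""
    if not test_acc_history:
        raise ValueError("need the history of at least one fold")
    return mean(best_per_fold(test_acc_history))
```

For 10-fold cross-validation, `test_acc_history` would contain ten inner lists, one per fold.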


yuhao910716 commented on May 22, 2024

```python
import os
import time

# Work around the OpenMP "duplicate library" crash (OMP: Error #15) seen on some setups;
# must be set before torch/cv2 load their OpenMP runtime
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

import transforms as transforms  # the repo's own transforms package (provides TenCrop)
from models import *             # provides VGG

class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']

# path is the folder you created and label is the class you defined. Since my code tests
# the contempt expression, I chose 6, following the index in class_names; to change it,
# edit line 63 of the code.
# Note: in the class_names list above, index 6 is 'Neutral', so make sure class_names
# and label match the classes your model was trained on.
path = 'CK+48/contempt/'
label = 6
# path = 'CK+48/anger/'
# path = 'CK+48/surprise/'
# path = 'CK+48/happy/'
# path = 'test_imgs/'
# Create the list to store the data and label information
# path = '1_test/'

# ==========================================================================
net = VGG('VGG19')
# Data loading
# (on PyTorch >= 2.6, loading a pickled model object may need weights_only=False)
checkpoint = torch.load(os.path.join('trained_model_pt', 'PrivateTest_model.t7'), map_location='cpu')
# The checkpoint may hold either a state_dict or the whole model object
state = checkpoint['net']
if isinstance(state, torch.nn.Module):
    state = state.state_dict()
net.load_state_dict(state)
net.eval()

if __name__ == '__main__':
    print('==> Preparing data..')
    cut_size = 44
    # Turn each 48x48 image into 44x44 crops (four corners and the center, plus their
    # horizontal flips, i.e. 10 crops with TenCrop), then test on all of them
    transform_test = transforms.Compose([
        transforms.TenCrop(cut_size),
        transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
    ])

    all_num = 0
    true_num = 0
    print('start', time.time())
    for file in os.listdir(path):
        raw_img = cv2.imread(os.path.join(path, file), cv2.IMREAD_GRAYSCALE)
        if raw_img is None:  # skip files OpenCV cannot read
            continue
        raw_img = cv2.resize(raw_img, (48, 48), interpolation=cv2.INTER_CUBIC)
        # Replicate the gray channel into the 3-channel input the model expects
        img = raw_img[:, :, np.newaxis]
        img = np.concatenate((img, img, img), axis=2)
        img = Image.fromarray(img)
        inputs = transform_test(img)  # shape (ncrops, c, h, w)
        ncrops, c, h, w = inputs.shape
        with torch.no_grad():  # replaces the removed Variable(..., volatile=True)
            outputs = net(inputs)
        outputs_avg = outputs.view(1, ncrops, -1).mean(1)  # average over crops
        score = F.softmax(outputs_avg, dim=1)
        _, predicted = torch.max(outputs_avg, 1)

        if predicted.item() == label:
            true_num += 1
        all_num += 1
        if all_num > 1000:
            break

        # ============================================================
        # Comment out everything below this line to see the result
        plt.rcParams['figure.figsize'] = (13.5, 5.5)
        axes = plt.subplot(1, 3, 1)
        array_05 = inputs[0].numpy().transpose(1, 2, 0)  # first crop, CHW -> HWC
        plt.imshow(array_05)
        plt.xlabel('Input Image', fontsize=16)
        axes.set_xticks([])
        axes.set_yticks([])
        plt.tight_layout()

        plt.subplots_adjust(left=0.05, bottom=0.2, right=0.95, top=0.9, hspace=0.02, wspace=0.3)
        plt.subplot(1, 3, 2)
        ind = 0.1 + 0.6 * np.arange(len(class_names))  # the x locations for the bars
        list_data = score.numpy()
        width = 0.4  # the width of the bars
        color_list = ['red', 'orangered', 'darkorange', 'limegreen', 'darkgreen', 'royalblue', 'navy']
        plt.bar(ind, list_data[0], width, color=color_list)
        plt.title("Classification results ", fontsize=20)
        plt.xlabel(" Expression Category ", fontsize=16)
        plt.ylabel(" Classification Score ", fontsize=16)
        plt.xticks(ind, class_names, rotation=45, fontsize=14)
        # Uncomment this to save the corresponding images (save before show())
        # plt.savefig(os.path.join('images/results', '{}.png'.format(os.path.splitext(file)[0])))
        plt.show()
        plt.close()

    print("Number classified correctly:", true_num)
    print("Total number classified:", all_num)
```
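The script scores each image by averaging the network's outputs over the TenCrop crops and then taking a softmax and an argmax, and its comment describes cutting each 48x48 image into 44x44 crops. A stdlib-only sketch of those steps, assuming the per-crop logits have already been computed; all function names are hypothetical, and the crop order follows torchvision's `five_crop` convention (top-left, top-right, bottom-left, bottom-right, center), which `TenCrop` extends with the horizontally flipped copies:

```python
import math


def five_crop_boxes(size, crop):
    """(left, top, right, bottom) boxes of the four corner crops and the
    center crop of a size x size image."""
    far = size - crop               # offset of the right/bottom crops
    mid = int(round(far / 2.0))     # offset of the center crop
    return [
        (0, 0, crop, crop),                  # top-left
        (far, 0, size, crop),                # top-right
        (0, far, crop, size),                # bottom-left
        (far, far, size, size),              # bottom-right
        (mid, mid, mid + crop, mid + crop),  # center
    ]


def average_over_crops(crop_logits):
    """Element-wise mean of the per-crop logit vectors (one row per crop)."""
    n = len(crop_logits)
    return [sum(column) / n for column in zip(*crop_logits)]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(crop_logits):
    """Return (predicted class index, class scores) for one image."""
    avg = average_over_crops(crop_logits)
    predicted = max(range(len(avg)), key=avg.__getitem__)
    return predicted, softmax(avg)


def accuracy(predictions, label):
    """true_num / all_num when every image in the folder has the same label."""
    true_num = sum(1 for p in predictions if p == label)
    return true_num / len(predictions)
```

Note that with a crop size equal to the image size (48 on a 48x48 image), `five_crop_boxes` returns the full image five times, so the ten crops differ only by the horizontal flip.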

