forked from imran1289-ah/EmotionDetector
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest-model-individual.py
39 lines (32 loc) · 1.22 KB
/
test-model-individual.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from PIL import Image
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets
from torchvision.datasets import ImageFolder
from cnn_model import CNN
from cnn_model2 import CNNVariant2
from cnn_model3 import CNNVariant3
# Preprocessing pipeline applied to the input image before inference.
# Steps run in list order: convert the image to a tensor, normalize each of
# the three channels with mean 0.5 / std 0.5, then resize to 256x256.
# NOTE(review): presumably this mirrors the transform used at training time;
# confirm before changing the order or the constants.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    transforms.Resize(size=(256, 256)),
]
transform = transforms.Compose(preprocessing_steps)
# Loading the trained model
model = CNNVariant3()
# map_location="cpu" lets a checkpoint that was saved from a GPU run load on a
# CPU-only machine; nothing in the visible script moves the model to a GPU.
checkpoint = torch.load(
    "emotion_classifier_model_cnn_variant3.pth", map_location="cpu"
)
# strict=False tolerates parameter-name mismatches between the checkpoint and
# the model. Report any mismatch instead of ignoring it silently, because
# layers that were not loaded keep their random initial weights.
load_report = model.load_state_dict(checkpoint, strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(
        "warning: checkpoint does not fully match CNNVariant3 "
        f"(missing keys: {list(load_report.missing_keys)}, "
        f"unexpected keys: {list(load_report.unexpected_keys)})"
    )
# Loading custom image and transforming it
IMAGE_PATH = r"dataset\Happy\00a7112212c133de419d2c89fd8be75729b041400d6458c3ed8db29a.jpg"
# Open inside a context manager so the file handle is released promptly.
# convert("RGB") guarantees exactly three channels: grayscale, palette or RGBA
# files would otherwise fail in the 3-channel Normalize step of `transform`.
with Image.open(IMAGE_PATH) as opened_image:
    emotion_image = opened_image.convert("RGB")
# unsqueeze(0) adds a leading batch dimension of size 1 for the model call.
tensor_image = transform(emotion_image).unsqueeze(0)
# Predict the image by passing it to the model
model.eval()  # switch layers such as dropout/batch-norm to inference mode
with torch.no_grad():
    # Gradients are disabled here, so the raw outputs can be used directly
    # (no need for the legacy `.data` accessor).
    outputs = model(tensor_image)
    # Index of the highest-scoring class along the class dimension.
    _, predicted = torch.max(outputs, 1)

# Class names indexed by the model's output position.
# NOTE(review): order (and the "Suprised" spelling) presumably mirrors the
# training dataset's class folders; confirm before editing.
labels = ["Focused", "Happy", "Neutral", "Suprised"]
predictedEmotion = labels[predicted.item()]

# Keep only the file name; normalizing the separators first makes this work
# for both backslash- and forward-slash-separated paths.
imageName = IMAGE_PATH.replace("\\", "/").rsplit("/", 1)[-1]

# Print the prediction
print(f"the emotion was found to be {predictedEmotion} for image {imageName}")