Mediapipe 实现姿态识别

参考代码:

  1. 【编程奇妙夜】AI健身+三维人体姿态估计(附Mediapipe代码复现)
  2. Google云盘

8 个类

关于这 8 个类,原文中并没有具体的解释,所以基本照搬了源码,只对因 bug 而无法运行的地方进行了修改。

BootstrapHelper

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import cv2
import csv
import numpy as np
import os
import sys
import tqdm

from mediapipe.python.solutions import drawing_utils as mp_drawing
from mediapipe.python.solutions import pose as mp_pose
from PIL import ImageDraw
from PIL import Image
from Function.Show import show_image


class BootstrapHelper(object):
    """Helps to bootstrap images and filter pose samples for classification."""

    def __init__(self,
                 images_in_folder,
                 images_out_folder,
                 csvs_out_folder):
        """Remembers the working folders and discovers the pose classes.

        Args:
          images_in_folder: Folder holding one sub-folder of images per pose class.
          images_out_folder: Folder that receives the annotated images.
          csvs_out_folder: Folder that receives one landmarks CSV per pose class.
        """
        self._images_in_folder = images_in_folder
        self._images_out_folder = images_out_folder
        self._csvs_out_folder = csvs_out_folder

        # Every non-hidden entry of the input folder is treated as a pose class.
        self._pose_class_names = sorted(
            n for n in os.listdir(self._images_in_folder) if not n.startswith('.'))

    def bootstrap(self, per_pose_class_limit=None):
        """Bootstraps images in a given folder.

        Required image in folder (same use for image out folder):
          pushups_up/
            image_001.jpg
            image_002.jpg
            ...
          pushups_down/
            image_001.jpg
            image_002.jpg
            ...
          ...

        Produced CSVs out folder:
          pushups_up.csv
          pushups_down.csv

        Produced CSV structure with pose 3D landmarks:
          sample_00001,x1,y1,z1,x2,y2,z2,....
          sample_00002,x1,y1,z1,x2,y2,z2,....

        Args:
          per_pose_class_limit: If not None, only the first N images (in sorted
            name order) of every pose class are processed.
        """
        # Create output folder for CSVs.
        os.makedirs(self._csvs_out_folder, exist_ok=True)

        for pose_class_name in self._pose_class_names:
            print('Bootstrapping ', pose_class_name, file=sys.stderr)

            # Paths for the pose class.
            images_in_folder = os.path.join(self._images_in_folder, pose_class_name)
            images_out_folder = os.path.join(self._images_out_folder, pose_class_name)
            csv_out_path = os.path.join(self._csvs_out_folder, pose_class_name + '.csv')
            os.makedirs(images_out_folder, exist_ok=True)

            # newline='' is what the csv module expects; without it every row is
            # followed by a blank line on Windows.
            with open(csv_out_path, 'w', newline='') as csv_out_file:
                csv_out_writer = csv.writer(
                    csv_out_file, delimiter=',', quoting=csv.QUOTE_MINIMAL)

                # Get list of images.
                image_names = sorted(
                    n for n in os.listdir(images_in_folder) if not n.startswith('.'))
                if per_pose_class_limit is not None:
                    image_names = image_names[:per_pose_class_limit]

                # Bootstrap every image.
                for image_name in tqdm.tqdm(image_names):
                    # Load image. cv2.imread returns None (no exception) when the
                    # file cannot be decoded, so such files are skipped explicitly.
                    image_in_path = os.path.join(images_in_folder, image_name)
                    input_frame = cv2.imread(image_in_path)
                    if input_frame is None:
                        print('Skipping unreadable image: ', image_in_path, file=sys.stderr)
                        continue
                    input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)

                    # Initialize fresh pose tracker and run it.
                    # (`upper_body_only=False` was dropped from the call because
                    # the author's MediaPipe version rejected that argument.)
                    with mp_pose.Pose() as pose_tracker:
                        result = pose_tracker.process(image=input_frame)
                        pose_landmarks = result.pose_landmarks

                    # Save image with pose prediction (if pose was detected).
                    output_frame = input_frame.copy()
                    if pose_landmarks is not None:
                        mp_drawing.draw_landmarks(
                            image=output_frame,
                            landmark_list=pose_landmarks,
                            connections=mp_pose.POSE_CONNECTIONS)
                    output_frame = cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR)
                    cv2.imwrite(os.path.join(images_out_folder, image_name), output_frame)

                    # Save landmarks if pose was detected.
                    if pose_landmarks is not None:
                        # Scale normalized landmarks to pixels (z uses the width).
                        frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
                        pose_landmarks = np.array(
                            [[lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                             for lmk in pose_landmarks.landmark],
                            dtype=np.float32)
                        assert pose_landmarks.shape == (33, 3), \
                            'Unexpected landmarks shape: {}'.format(pose_landmarks.shape)
                        # `np.str` was removed in NumPy 1.24; the builtin `str` is
                        # what it aliased.
                        csv_out_writer.writerow(
                            [image_name] + pose_landmarks.flatten().astype(str).tolist())

                    # NOTE: the earlier version also rendered an XZ projection here
                    # and concatenated it to `output_frame`, but the result was never
                    # written or returned, so that dead work has been dropped.

    def _draw_xz_projection(self, output_frame, pose_landmarks, r=0.5, color='red'):
        """Renders the pose skeleton projected on the XZ plane.

        Args:
          output_frame: Image array; only its height and width are used.
          pose_landmarks: Array indexable by landmark id yielding (x, y, z), or
            None when no pose was detected.
          r: Joint radius, expressed in percent of the frame width.
          color: Color used for joints and connections.

        Returns:
          NumPy array of a white RGB image of the frame size with the projection
          drawn on it (blank when `pose_landmarks` is None).
        """
        frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
        img = Image.new('RGB', (frame_width, frame_height), color='white')

        if pose_landmarks is None:
            return np.asarray(img)

        # Scale radius according to the image width.
        r *= frame_width * 0.01

        draw = ImageDraw.Draw(img)
        for idx_1, idx_2 in mp_pose.POSE_CONNECTIONS:
            # Flip Z and shift it by half of the frame height.
            x1, _, z1 = pose_landmarks[idx_1] * [1, 1, -1] + [0, 0, frame_height * 0.5]
            x2, _, z2 = pose_landmarks[idx_2] * [1, 1, -1] + [0, 0, frame_height * 0.5]

            draw.ellipse([x1 - r, z1 - r, x1 + r, z1 + r], fill=color)
            draw.ellipse([x2 - r, z2 - r, x2 + r, z2 + r], fill=color)
            draw.line([x1, z1, x2, z2], width=int(r), fill=color)

        return np.asarray(img)

    def align_images_and_csvs(self, print_removed_items=False):
        """Makes sure that image folders and CSVs have the same samples.

        Leaves only the intersection of samples in both image folders and CSVs.

        Args:
          print_removed_items: Whether to print every removed CSV row / image.
        """
        for pose_class_name in self._pose_class_names:
            # Paths for the pose class.
            images_out_folder = os.path.join(self._images_out_folder, pose_class_name)
            csv_out_path = os.path.join(self._csvs_out_folder, pose_class_name + '.csv')

            # Read CSV into memory. Blank lines are parsed as empty rows and would
            # break the `row[0]` lookup below, so they are dropped here.
            with open(csv_out_path, newline='') as csv_out_file:
                rows = [row for row in csv.reader(csv_out_file, delimiter=',') if row]

            # Image names left in CSV (a set keeps the membership test cheap).
            image_names_in_csv = set()

            # Re-write the CSV removing lines without corresponding images.
            with open(csv_out_path, 'w', newline='') as csv_out_file:
                csv_out_writer = csv.writer(
                    csv_out_file, delimiter=',', quoting=csv.QUOTE_MINIMAL)
                for row in rows:
                    image_name = row[0]
                    image_path = os.path.join(images_out_folder, image_name)
                    if os.path.exists(image_path):
                        image_names_in_csv.add(image_name)
                        csv_out_writer.writerow(row)
                    elif print_removed_items:
                        print('Removed image from CSV: ', image_path)

            # Remove images without corresponding line in CSV.
            for image_name in os.listdir(images_out_folder):
                if image_name not in image_names_in_csv:
                    image_path = os.path.join(images_out_folder, image_name)
                    os.remove(image_path)
                    if print_removed_items:
                        print('Removed image from folder: ', image_path)

    def analyze_outliers(self, outliers):
        """Prints and shows every outlier sample.

        If a sample is classified differently than its original class, it should
        either be deleted or more similar samples should be added.

        Args:
          outliers: Iterable of objects exposing `sample.class_name`,
            `sample.name`, `detected_class` and `all_classes`.
        """
        for outlier in outliers:
            image_path = os.path.join(
                self._images_out_folder, outlier.sample.class_name, outlier.sample.name)

            print('Outlier')
            print(' sample path = ', image_path)
            print(' sample class = ', outlier.sample.class_name)
            print(' detected class = ', outlier.detected_class)
            print(' all classes = ', outlier.all_classes)

            img = cv2.imread(image_path)
            if img is None:
                # The image is missing or unreadable; nothing to display.
                print('Could not read image: ', image_path, file=sys.stderr)
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            show_image(img)

    def remove_outliers(self, outliers):
        """Removes outliers from the image folders."""
        for outlier in outliers:
            image_path = os.path.join(
                self._images_out_folder, outlier.sample.class_name, outlier.sample.name)
            os.remove(image_path)

    def print_images_in_statistics(self):
        """Prints statistics from the input image folder."""
        self._print_images_statistics(self._images_in_folder, self._pose_class_names)

    def print_images_out_statistics(self):
        """Prints statistics from the output image folder."""
        self._print_images_statistics(self._images_out_folder, self._pose_class_names)

    def _print_images_statistics(self, images_folder, pose_class_names):
        """Prints the number of non-hidden files in every pose class folder."""
        print('Number of images per pose class:')
        for pose_class_name in pose_class_names:
            n_images = len([
                n for n in os.listdir(os.path.join(images_folder, pose_class_name))
                if not n.startswith('.')])
            print(' {}: {}'.format(pose_class_name, n_images))

对源代码进行了一处修改(新版 MediaPipe 的 `Pose` 不再支持 `upper_body_only` 参数,继续传入会报错):

1
2
3
4
# 将
with mp_pose.Pose(upper_body_only=False) as pose_tracker:
# 修改成了
with mp_pose.Pose() as pose_tracker:

EMADictSmoothing

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class EMADictSmoothing(object):
    """Smooths pose classification with an exponential moving average."""

    def __init__(self, window_size=10, alpha=0.2):
        """Creates the smoother.

        Args:
          window_size: Maximum number of most recent classifications kept.
          alpha: Decay rate; a classification that is `k` frames old is weighted
            by `(1 - alpha) ** k`.
        """
        self._window_size = window_size
        self._alpha = alpha

        # Most recent classification first.
        self._data_in_window = []

    def __call__(self, data):
        """Smooths given pose classification.

        Smoothing is done by computing Exponential Moving Average for every pose
        class observed in the given time window. Missed pose classes are replaced
        with 0.

        Args:
          data: Dictionary with pose classification. Sample:
              {
                'pushups_down': 8,
                'pushups_up': 2,
              }

        Result:
          Dictionary in the same format but with smoothed and float instead of
          integer values. Sample:
              {
                'pushups_down': 8.3,
                'pushups_up': 1.7,
              }
        """
        # Add new data to the beginning of the window for simpler code.
        self._data_in_window.insert(0, data)
        self._data_in_window = self._data_in_window[:self._window_size]

        # Get all keys seen anywhere in the window.
        keys = {key for frame in self._data_in_window for key in frame}

        # Get smoothed values.
        smoothed_data = {}
        for key in keys:
            factor = 1.0
            top_sum = 0.0
            bottom_sum = 0.0
            for frame in self._data_in_window:
                top_sum += factor * frame.get(key, 0.0)
                bottom_sum += factor

                # Older frames contribute geometrically less.
                factor *= (1.0 - self._alpha)

            smoothed_data[key] = top_sum / bottom_sum

        return smoothed_data

FullBodyPoseEmbedder

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import numpy as np


class FullBodyPoseEmbedder(object):
    """Converts 3D pose landmarks into 3D embedding."""

    def __init__(self, torso_size_multiplier=2.5):
        """Creates the embedder.

        Args:
          torso_size_multiplier: Multiplier applied to the torso size to get the
            minimal body size used for scale normalization.
        """
        # Multiplier to apply to the torso to get minimal body size.
        self._torso_size_multiplier = torso_size_multiplier

        # Names of the landmarks as they appear in the prediction.
        self._landmark_names = [
            'nose',
            'left_eye_inner', 'left_eye', 'left_eye_outer',
            'right_eye_inner', 'right_eye', 'right_eye_outer',
            'left_ear', 'right_ear',
            'mouth_left', 'mouth_right',
            'left_shoulder', 'right_shoulder',
            'left_elbow', 'right_elbow',
            'left_wrist', 'right_wrist',
            'left_pinky_1', 'right_pinky_1',
            'left_index_1', 'right_index_1',
            'left_thumb_2', 'right_thumb_2',
            'left_hip', 'right_hip',
            'left_knee', 'right_knee',
            'left_ankle', 'right_ankle',
            'left_heel', 'right_heel',
            'left_foot_index', 'right_foot_index',
        ]

        # Name -> row index, so lookups do not rescan the list every time.
        self._landmark_index = {
            name: idx for idx, name in enumerate(self._landmark_names)}

    def __call__(self, landmarks):
        """Normalizes pose landmarks and converts them to an embedding.

        Args:
          landmarks - NumPy array with 3D landmarks of shape (N, 3). The input is
            not modified.

        Result:
          Numpy array with pose embedding of shape (M, 3) where `M` is the number of
          pairwise distances defined in `_get_pose_distance_embedding`.
        """
        assert landmarks.shape[0] == len(self._landmark_names), \
            'Unexpected number of landmarks: {}'.format(landmarks.shape[0])

        # Normalize landmarks (works on a private copy).
        landmarks = self._normalize_pose_landmarks(landmarks)

        # Get embedding.
        return self._get_pose_distance_embedding(landmarks)

    def _normalize_pose_landmarks(self, landmarks):
        """Normalizes landmarks translation and scale."""
        # Copy so the caller's array is untouched. Floating inputs keep their
        # dtype; integer inputs are promoted to float because the in-place
        # subtraction/division below cannot be stored in an integer array.
        landmarks = np.array(
            landmarks, dtype=np.result_type(landmarks.dtype, np.float32))

        # Normalize translation.
        pose_center = self._get_pose_center(landmarks)
        landmarks -= pose_center

        # Normalize scale.
        pose_size = self._get_pose_size(landmarks, self._torso_size_multiplier)
        landmarks /= pose_size
        # Multiplication by 100 is not required, but makes it easier to debug.
        landmarks *= 100

        return landmarks

    def _get_pose_center(self, landmarks):
        """Calculates pose center as point between hips."""
        left_hip = landmarks[self._landmark_index['left_hip']]
        right_hip = landmarks[self._landmark_index['right_hip']]
        return (left_hip + right_hip) * 0.5

    def _get_pose_size(self, landmarks, torso_size_multiplier):
        """Calculates pose size.

        It is the maximum of two values:
          * Torso size multiplied by `torso_size_multiplier`
          * Maximum distance from pose center to any pose landmark
        """
        # This approach uses only 2D landmarks to compute pose size.
        landmarks = landmarks[:, :2]

        # Hips center (also the pose center).
        hips = self._get_pose_center(landmarks)

        # Shoulders center.
        left_shoulder = landmarks[self._landmark_index['left_shoulder']]
        right_shoulder = landmarks[self._landmark_index['right_shoulder']]
        shoulders = (left_shoulder + right_shoulder) * 0.5

        # Torso size as the minimum body size.
        torso_size = np.linalg.norm(shoulders - hips)

        # Max dist to pose center.
        max_dist = np.max(np.linalg.norm(landmarks - hips, axis=1))

        return max(torso_size * torso_size_multiplier, max_dist)

    def _get_pose_distance_embedding(self, landmarks):
        """Converts pose landmarks into 3D embedding.

        We use several pairwise 3D distances to form pose embedding. All distances
        are signed per-axis differences. We use different types of pairs to cover
        different pose classes. Feel free to remove some or add new.

        Args:
          landmarks - NumPy array with 3D landmarks of shape (N, 3).

        Result:
          Numpy array with pose embedding of shape (M, 3) where `M` is the number of
          pairwise distances.
        """
        embedding = np.array([
            # One joint.

            self._get_distance(
                self._get_average_by_names(landmarks, 'left_hip', 'right_hip'),
                self._get_average_by_names(landmarks, 'left_shoulder', 'right_shoulder')),

            self._get_distance_by_names(landmarks, 'left_shoulder', 'left_elbow'),
            self._get_distance_by_names(landmarks, 'right_shoulder', 'right_elbow'),

            self._get_distance_by_names(landmarks, 'left_elbow', 'left_wrist'),
            self._get_distance_by_names(landmarks, 'right_elbow', 'right_wrist'),

            self._get_distance_by_names(landmarks, 'left_hip', 'left_knee'),
            self._get_distance_by_names(landmarks, 'right_hip', 'right_knee'),

            self._get_distance_by_names(landmarks, 'left_knee', 'left_ankle'),
            self._get_distance_by_names(landmarks, 'right_knee', 'right_ankle'),

            # Two joints.

            self._get_distance_by_names(landmarks, 'left_shoulder', 'left_wrist'),
            self._get_distance_by_names(landmarks, 'right_shoulder', 'right_wrist'),

            self._get_distance_by_names(landmarks, 'left_hip', 'left_ankle'),
            self._get_distance_by_names(landmarks, 'right_hip', 'right_ankle'),

            # Four joints.

            self._get_distance_by_names(landmarks, 'left_hip', 'left_wrist'),
            self._get_distance_by_names(landmarks, 'right_hip', 'right_wrist'),

            # Five joints.

            self._get_distance_by_names(landmarks, 'left_shoulder', 'left_ankle'),
            self._get_distance_by_names(landmarks, 'right_shoulder', 'right_ankle'),

            # NOTE(review): the hip -> wrist pairs repeat the "Four joints" rows
            # above. They are kept so the embedding shape and the distances
            # derived from it stay unchanged.
            self._get_distance_by_names(landmarks, 'left_hip', 'left_wrist'),
            self._get_distance_by_names(landmarks, 'right_hip', 'right_wrist'),

            # Cross body.

            self._get_distance_by_names(landmarks, 'left_elbow', 'right_elbow'),
            self._get_distance_by_names(landmarks, 'left_knee', 'right_knee'),

            self._get_distance_by_names(landmarks, 'left_wrist', 'right_wrist'),
            self._get_distance_by_names(landmarks, 'left_ankle', 'right_ankle'),
        ])

        return embedding

    def _get_average_by_names(self, landmarks, name_from, name_to):
        """Returns the midpoint of two named landmarks."""
        lmk_from = landmarks[self._landmark_index[name_from]]
        lmk_to = landmarks[self._landmark_index[name_to]]
        return (lmk_from + lmk_to) * 0.5

    def _get_distance_by_names(self, landmarks, name_from, name_to):
        """Returns the signed per-axis difference between two named landmarks."""
        lmk_from = landmarks[self._landmark_index[name_from]]
        lmk_to = landmarks[self._landmark_index[name_to]]
        return self._get_distance(lmk_from, lmk_to)

    def _get_distance(self, lmk_from, lmk_to):
        """Returns the vector pointing from `lmk_from` to `lmk_to`."""
        return lmk_to - lmk_from

PoseClassificationVisualizer

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import io
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw
import requests
from matplotlib import pyplot as plt


class PoseClassificationVisualizer(object):
    """Keeps track of classifications for every frame and renders them."""

    def __init__(self,
                 class_name,
                 plot_location_x=0.05,
                 plot_location_y=0.05,
                 plot_max_width=0.4,
                 plot_max_height=0.4,
                 plot_figsize=(9, 4),
                 plot_x_max=None,
                 plot_y_max=None,
                 counter_location_x=0.00,
                 counter_location_y=0.50,
                 counter_font_path='https://github.com/googlefonts/roboto/blob/main/src/hinted/Roboto-Regular.ttf?raw=true',
                 counter_font_color='red',
                 counter_font_size=0.15):
        """Stores the rendering configuration.

        Locations and sizes are fractions of the output frame width/height.

        Args:
          class_name: Pose class whose confidence is plotted.
          plot_location_x, plot_location_y: Top-left corner of the plot.
          plot_max_width, plot_max_height: Maximum size of the plot.
          plot_figsize: Matplotlib figure size of the plot.
          plot_x_max, plot_y_max: Optional upper axis limits of the plot.
          counter_location_x, counter_location_y: Position of the counter text.
          counter_font_path: URL (http/https) or local path of a TrueType font.
          counter_font_color: Color of the counter text.
          counter_font_size: Counter font size as a fraction of frame height.
        """
        self._class_name = class_name
        self._plot_location_x = plot_location_x
        self._plot_location_y = plot_location_y
        self._plot_max_width = plot_max_width
        self._plot_max_height = plot_max_height
        self._plot_figsize = plot_figsize
        self._plot_x_max = plot_x_max
        self._plot_y_max = plot_y_max
        self._counter_location_x = counter_location_x
        self._counter_location_y = counter_location_y
        self._counter_font_path = counter_font_path
        self._counter_font_color = counter_font_color
        self._counter_font_size = counter_font_size

        # Loaded lazily on the first rendered frame and reused afterwards.
        self._counter_font = None

        self._pose_classification_history = []
        self._pose_classification_filtered_history = []

    def __call__(self,
                 frame,
                 pose_classification,
                 pose_classification_filtered,
                 repetitions_count):
        """Renders pose classification and counter until given frame.

        Args:
          frame: Image array accepted by `PIL.Image.fromarray`.
          pose_classification: Raw classification of the frame (dict or None).
          pose_classification_filtered: Smoothed classification (dict or None).
          repetitions_count: Value rendered as the counter text.

        Returns:
          PIL image of the frame with the history plot and the counter drawn.
        """
        # Extend classification history.
        self._pose_classification_history.append(pose_classification)
        self._pose_classification_filtered_history.append(pose_classification_filtered)

        # Output frame with classification plot and counter.
        output_img = Image.fromarray(frame)
        output_width, output_height = output_img.size

        # Draw the plot. The resampling argument (formerly `Image.ANTIALIAS`,
        # which the author's Pillow version no longer accepts) is omitted, so
        # Pillow's default filter is used.
        img = self._plot_classification_history(output_width, output_height)
        img.thumbnail((int(output_width * self._plot_max_width),
                       int(output_height * self._plot_max_height)))
        output_img.paste(img,
                         (int(output_width * self._plot_location_x),
                          int(output_height * self._plot_location_y)))

        # Draw the count.
        output_img_draw = ImageDraw.Draw(output_img)
        if self._counter_font is None:
            font_size = int(output_height * self._counter_font_size)
            self._counter_font = self._load_counter_font(font_size)
        output_img_draw.text((output_width * self._counter_location_x,
                              output_height * self._counter_location_y),
                             str(repetitions_count),
                             font=self._counter_font,
                             fill=self._counter_font_color)

        return output_img

    def _load_counter_font(self, font_size):
        """Loads the counter font from a URL or from a local file."""
        font_path = self._counter_font_path
        if not str(font_path).lower().startswith(('http://', 'https://')):
            # Anything that is not a URL is handed to Pillow as a font file.
            return ImageFont.truetype(font_path, size=font_size)

        # A timeout keeps the video loop from hanging forever on a dead
        # connection; raise_for_status surfaces HTTP errors instead of feeding
        # an error page to the font parser.
        font_request = requests.get(font_path, allow_redirects=True, timeout=10)
        font_request.raise_for_status()
        return ImageFont.truetype(io.BytesIO(font_request.content), size=font_size)

    def _plot_classification_history(self, output_width, output_height):
        """Plots raw and filtered confidence history and returns it as an image."""
        fig = plt.figure(figsize=self._plot_figsize)

        for classification_history in [self._pose_classification_history,
                                       self._pose_classification_filtered_history]:
            y = []
            for classification in classification_history:
                if classification is None:
                    y.append(None)
                elif self._class_name in classification:
                    y.append(classification[self._class_name])
                else:
                    y.append(0)
            plt.plot(y, linewidth=7)

        plt.grid(axis='y', alpha=0.75)
        plt.xlabel('Frame')
        plt.ylabel('Confidence')
        plt.title('Classification history for `{}`'.format(self._class_name))
        # No legend is drawn: the plotted lines carry no labels.

        if self._plot_y_max is not None:
            plt.ylim(top=self._plot_y_max)
        if self._plot_x_max is not None:
            plt.xlim(right=self._plot_x_max)

        # Convert plot to image. The dpi is chosen so the rendered figure fits
        # the maximum plot size on the output frame.
        buf = io.BytesIO()
        dpi = min(
            output_width * self._plot_max_width / float(self._plot_figsize[0]),
            output_height * self._plot_max_height / float(self._plot_figsize[1]))
        fig.savefig(buf, dpi=dpi)
        buf.seek(0)
        img = Image.open(buf)
        # Close this figure explicitly rather than "whatever is current".
        plt.close(fig)

        return img

对源代码进行了两处修改:

1
2
3
4
5
6
7
8
9
10
# 修改一
# 将
img.thumbnail((int(output_width * self._plot_max_width), int(output_height * self._plot_max_height)), Image.ANTIALIAS)
# 修改为了
img.thumbnail((int(output_width * self._plot_max_width), int(output_height * self._plot_max_height)))

# 修改二
# 注释掉了
# plt.legend(loc='upper right')
# 否则会报错

PoseClassifier

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import numpy as np
import csv
from Class.PoseSampleClass import PoseSample
from Class.PoseSampleOutlierClass import PoseSampleOutlier


class PoseClassifier(object):
    """Classifies pose landmarks."""

    def __init__(self,
                 pose_samples_folder,
                 pose_embedder,
                 file_extension='csv',
                 file_separator=',',
                 n_landmarks=33,
                 n_dimensions=3,
                 top_n_by_max_distance=30,
                 top_n_by_mean_distance=10,
                 axes_weights=(1., 1., 0.2)):
        """Loads the pose sample database.

        Args:
          pose_samples_folder: Folder with one samples file per pose class.
          pose_embedder: Callable mapping a landmarks array to an embedding.
          file_extension: Extension of the sample files (with or without dot).
          file_separator: Column delimiter of the sample files.
          n_landmarks: Number of landmarks per sample.
          n_dimensions: Number of coordinates per landmark.
          top_n_by_max_distance: Samples kept after the max-distance filter.
          top_n_by_mean_distance: Samples kept after the mean-distance filter.
          axes_weights: Per-axis weights applied to embedding differences.
        """
        self._pose_embedder = pose_embedder
        self._n_landmarks = n_landmarks
        self._n_dimensions = n_dimensions
        self._top_n_by_max_distance = top_n_by_max_distance
        self._top_n_by_mean_distance = top_n_by_mean_distance
        self._axes_weights = axes_weights

        self._pose_samples = self._load_pose_samples(pose_samples_folder,
                                                     file_extension,
                                                     file_separator,
                                                     n_landmarks,
                                                     n_dimensions,
                                                     pose_embedder)

    def _load_pose_samples(self,
                           pose_samples_folder,
                           file_extension,
                           file_separator,
                           n_landmarks,
                           n_dimensions,
                           pose_embedder):
        """Loads pose samples from a given folder.

        Required folder structure:
          neutral_standing.csv
          pushups_down.csv
          pushups_up.csv
          squats_down.csv
          ...

        Required CSV structure:
          sample_00001,x1,y1,z1,x2,y2,z2,....
          sample_00002,x1,y1,z1,x2,y2,z2,....
          ...

        Returns:
          List of `PoseSample`, one per CSV row.
        """
        # Each file in the folder represents one pose class. Matching on the
        # dotted suffix avoids picking up names that merely end with the
        # extension letters; sorting makes the sample order reproducible.
        suffix = '.' + file_extension.lstrip('.')
        file_names = sorted(
            name for name in os.listdir(pose_samples_folder) if name.endswith(suffix))

        pose_samples = []
        for file_name in file_names:
            # Use file name as pose class name.
            class_name = file_name[:-len(suffix)]

            # Parse CSV.
            with open(os.path.join(pose_samples_folder, file_name), newline='') as csv_file:
                csv_reader = csv.reader(csv_file, delimiter=file_separator)
                for row in csv_reader:
                    if not row:
                        # Blank line (e.g. a CSV written without newline='').
                        continue
                    assert len(row) == n_landmarks * n_dimensions + 1, \
                        'Wrong number of values: {}'.format(len(row))
                    landmarks = np.array(row[1:], np.float32).reshape(
                        [n_landmarks, n_dimensions])
                    pose_samples.append(PoseSample(
                        name=row[0],
                        landmarks=landmarks,
                        class_name=class_name,
                        embedding=pose_embedder(landmarks),
                    ))

        return pose_samples

    def find_pose_sample_outliers(self):
        """Classifies each sample against the entire database.

        Returns:
          List of `PoseSampleOutlier` for samples whose own class is not the
          single most frequent class among their nearest samples.
        """
        outliers = []
        for sample in self._pose_samples:
            # Find nearest poses for the target one.
            pose_landmarks = sample.landmarks.copy()
            pose_classification = self(pose_landmarks)

            # Classes sharing the top count. An empty classification yields no
            # winners (and therefore an outlier) instead of failing in max().
            top_count = max(pose_classification.values(), default=None)
            class_names = [class_name
                           for class_name, count in pose_classification.items()
                           if count == top_count]

            # Sample is an outlier if nearest poses have different class or more
            # than one pose class is detected as nearest.
            if sample.class_name not in class_names or len(class_names) != 1:
                outliers.append(PoseSampleOutlier(sample, class_names, pose_classification))

        return outliers

    def __call__(self, pose_landmarks):
        """Classifies given pose.

        Classification is done in two stages:
          * First we pick top-N samples by MAX distance. It allows to remove samples
            that are almost the same as given pose, but has few joints bent in the
            other direction.
          * Then we pick top-N samples by MEAN distance. After outliers are removed
            on a previous step, we can pick samples that are closest on average.

        Args:
          pose_landmarks: NumPy array with 3D landmarks of shape (N, 3).

        Returns:
          Dictionary with count of nearest pose samples from the database. Sample:
            {
              'pushups_down': 8,
              'pushups_up': 2,
            }
        """
        # Check that provided and target poses have the same shape.
        assert pose_landmarks.shape == (self._n_landmarks, self._n_dimensions), \
            'Unexpected shape: {}'.format(pose_landmarks.shape)

        # Get given pose embedding, plus the embedding of the pose mirrored
        # along X so left/right-flipped poses match the same samples.
        pose_embedding = self._pose_embedder(pose_landmarks)
        flipped_pose_embedding = self._pose_embedder(pose_landmarks * np.array([-1, 1, 1]))

        # Filter by max distance.
        #
        # That helps to remove outliers - poses that are almost the same as the
        # given one, but has one joint bent into another direction and actually
        # represent a different pose class.
        max_dist_heap = []
        for sample_idx, sample in enumerate(self._pose_samples):
            max_dist = min(
                np.max(np.abs(sample.embedding - pose_embedding) * self._axes_weights),
                np.max(np.abs(sample.embedding - flipped_pose_embedding) * self._axes_weights),
            )
            max_dist_heap.append([max_dist, sample_idx])

        max_dist_heap = sorted(max_dist_heap, key=lambda x: x[0])
        max_dist_heap = max_dist_heap[:self._top_n_by_max_distance]

        # Filter by mean distance.
        #
        # After removing outliers we can find the nearest pose by mean distance.
        mean_dist_heap = []
        for _, sample_idx in max_dist_heap:
            sample = self._pose_samples[sample_idx]
            mean_dist = min(
                np.mean(np.abs(sample.embedding - pose_embedding) * self._axes_weights),
                np.mean(np.abs(sample.embedding - flipped_pose_embedding) * self._axes_weights),
            )
            mean_dist_heap.append([mean_dist, sample_idx])

        mean_dist_heap = sorted(mean_dist_heap, key=lambda x: x[0])
        mean_dist_heap = mean_dist_heap[:self._top_n_by_mean_distance]

        # Collect results into map: (class_name -> n_samples), in one pass.
        result = {}
        for _, sample_idx in mean_dist_heap:
            class_name = self._pose_samples[sample_idx].class_name
            result[class_name] = result.get(class_name, 0) + 1

        return result

PoseSample

Code:

1
2
3
4
5
6
7
8
class PoseSample(object):
    """One labelled pose sample together with its precomputed embedding."""

    def __init__(self, name, landmarks, class_name, embedding):
        """Stores the sample fields as public attributes.

        Args:
          name: Identifier of the sample.
          landmarks: Landmarks of the sample.
          class_name: Name of the pose class the sample belongs to.
          embedding: Embedding computed from `landmarks`.
        """
        self.name = name
        self.landmarks = landmarks
        self.class_name = class_name

        self.embedding = embedding

PoseSampleOutlier

Code:

1
2
3
4
5
6
class PoseSampleOutlier(object):
    """A pose sample whose classification disagrees with its own label."""

    def __init__(self, sample, detected_class, all_classes):
        """Stores the outlier fields as public attributes.

        Args:
          sample: The sample that was flagged.
          detected_class: What the classifier detected for the sample.
          all_classes: The full classification result for the sample.
        """
        self.sample = sample
        self.detected_class = detected_class
        self.all_classes = all_classes

RepetitionCounter

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class RepetitionCounter(object):
    """Counts number of repetitions of given target pose class."""

    def __init__(self, class_name, enter_threshold=6, exit_threshold=4):
        """Creates the counter.

        Args:
          class_name: Pose class whose repetitions are counted.
          enter_threshold: Confidence that must be exceeded to enter the pose.
          exit_threshold: Confidence to drop below to exit the pose.
        """
        self._class_name = class_name

        # If pose counter passes given threshold, then we enter the pose.
        self._enter_threshold = enter_threshold
        self._exit_threshold = exit_threshold

        # Either we are in given pose or not.
        self._pose_entered = False

        # Number of times we exited the pose.
        self._n_repeats = 0

    @property
    def n_repeats(self):
        """Number of repetitions counted so far."""
        return self._n_repeats

    def __call__(self, pose_classification):
        """Counts number of repetitions happened until given frame.

        We use two thresholds. First you need to go above the higher one to enter
        the pose, and then you need to go below the lower one to exit it. Difference
        between the thresholds makes it stable to prediction jittering (which will
        cause wrong counts in case of having only one threshold).

        Args:
          pose_classification: Pose classification dictionary on current frame.
            Sample:
              {
                'pushups_down': 8.3,
                'pushups_up': 1.7,
              }

        Returns:
          Integer counter of repetitions.
        """
        # Get pose confidence; a class missing from the frame counts as 0.
        pose_confidence = pose_classification.get(self._class_name, 0.0)

        # On the very first frame or if we were out of the pose, just check if we
        # entered it on this frame and update the state.
        if not self._pose_entered:
            self._pose_entered = pose_confidence > self._enter_threshold
            return self._n_repeats

        # If we were in the pose and are exiting it, then increase the counter and
        # update the state.
        if pose_confidence < self._exit_threshold:
            self._n_repeats += 1
            self._pose_entered = False

        return self._n_repeats

6 个方法

Show

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import cv2
from matplotlib import pyplot as plt


def show_image(img):
    """Shows an image (expected in RGB channel order) in a 10x10 inch figure."""
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.show()


def look_image(img):
    """Shows an OpenCV (BGR) image in a 10x10 inch figure."""
    # OpenCV loads images as BGR while matplotlib displays RGB, so the
    # channels have to be converted before plotting.
    img_RGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(10, 10))
    plt.imshow(img_RGB)
    plt.show()

ProcessFrame

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import time
import cv2
from mediapipe.python.solutions import drawing_utils as mp_drawing
from mediapipe.python.solutions import pose as mp_pose


def process_frame(img, pose):
    """Run pose estimation on one BGR frame and draw the result onto it.

    When landmarks are detected, the skeleton is drawn and every keypoint gets
    a filled, colour-coded circle. Otherwise 'No Person' is written on the
    frame. In both cases the processing rate of this frame is stamped on the
    image as 'FPS <n>'.

    Args:
      img: BGR image array as read by OpenCV. It is drawn on in place.
      pose: Pose estimator whose ``process(rgb_image)`` result exposes
        ``pose_landmarks`` (callers in this project pass ``mp_pose.Pose``).

    Returns:
      The annotated image.
    """
    # perf_counter is monotonic and high resolution, unlike time.time(), which
    # may not advance at all between two nearby calls on some platforms.
    start_time = time.perf_counter()

    scaler = 1
    radius = 10
    text_color = (255, 0, 255)
    default_color = (0, 255, 0)

    # Landmark index groups and their BGR colours. Groups are listed in the
    # same priority order as the original elif chain; the first group that
    # contains an index wins (27 and 28 therefore use the wrist/ankle colour).
    keypoint_groups = (
        ((0,), (0, 0, 255)),                          # nose
        ((11, 12), (223, 155, 6)),                    # shoulders
        ((23, 24), (1, 240, 255)),                    # hips
        ((13, 14), (140, 47, 240)),                   # elbows
        ((25, 26), (0, 0, 255)),                      # knees
        ((15, 16, 27, 28), (223, 155, 60)),           # wrists and ankles
        ((17, 19, 21), (94, 218, 121)),               # left hand
        ((18, 20, 22), (16, 144, 247)),               # right hand
        ((27, 29, 31), (29, 123, 243)),               # left foot
        ((28, 30, 32), (193, 182, 255)),              # right foot
        ((9, 10), (205, 235, 255)),                   # mouth
        ((1, 2, 3, 4, 5, 6, 7, 8), (94, 218, 121)),   # eyes and cheeks
    )
    keypoint_colors = {}
    for indices, color in keypoint_groups:
        for index in indices:
            keypoint_colors.setdefault(index, color)

    # Image height and width, used to turn normalised coordinates into pixels.
    h, w = img.shape[0], img.shape[1]

    # The model expects RGB input while OpenCV frames are BGR.
    img_RGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    results = pose.process(img_RGB)

    if results.pose_landmarks:
        # Draw the keypoints and the skeleton connections.
        mp_drawing.draw_landmarks(img, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)

        # Overlay a colour-coded dot on every detected keypoint.
        for index, landmark in enumerate(results.pose_landmarks.landmark):
            cx = int(landmark.x * w)
            cy = int(landmark.y * h)
            color = keypoint_colors.get(index, default_color)
            img = cv2.circle(img, (cx, cy), radius, color, -1)
    else:
        failure_str = 'No Person'
        img = cv2.putText(img, failure_str, (25 * scaler, 100 * scaler), cv2.FONT_HERSHEY_SIMPLEX,
                          1.25 * scaler, text_color, 2 * scaler)

    # Frames per second for this frame. Guard against a zero elapsed time,
    # which would otherwise raise ZeroDivisionError.
    elapsed = time.perf_counter() - start_time
    FPS = 1 / elapsed if elapsed > 0 else 0

    # putText arguments: image, text, bottom-left corner of the text, font,
    # font scale, colour, thickness.
    img = cv2.putText(img, 'FPS ' + str(int(FPS)), (25 * scaler, 50 * scaler), cv2.FONT_HERSHEY_SIMPLEX,
                      1.25 * scaler, text_color, 2 * scaler)
    return img

Image

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import cv2
import mediapipe as mp
import matplotlib.pyplot as plt
import numpy as np

from Function.Show import look_image


def image(path):
    """Run pose estimation on a single image file and display every stage.

    Shown in order: the original image, the image with the skeleton drawn, the
    3D world-landmark plot, the segmentation mask, the foreground and
    background cut-outs, the left-elbow coordinates (printed), the left knee
    marked with a circle, and finally all keypoints as colour-coded circles.

    If no person is detected, 'No Person' is printed and the function returns
    after showing the original image.

    Args:
      path: Path of the image file to analyse.

    Raises:
      FileNotFoundError: If OpenCV cannot read an image from ``path``.
    """
    mp_pose = mp.solutions.pose
    mp_drawing = mp.solutions.drawing_utils

    # cv2.imread signals failure by returning None rather than raising, so
    # check explicitly before loading the (slow) model.
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError('Cannot read image: {}'.format(path))
    print('original')
    print('-----')
    look_image(img)

    radius = 10
    default_color = (0, 255, 0)
    # Landmark index groups and their BGR colours, in the priority order of
    # the original elif chain: the first group containing an index wins.
    keypoint_groups = (
        ((0,), (0, 0, 255)),                          # nose
        ((11, 12), (223, 155, 6)),                    # shoulders
        ((23, 24), (1, 240, 255)),                    # hips
        ((13, 14), (140, 47, 240)),                   # elbows
        ((25, 26), (0, 0, 255)),                      # knees
        ((15, 16, 27, 28), (223, 155, 60)),           # wrists and ankles
        ((17, 19, 21), (94, 218, 121)),               # left hand
        ((18, 20, 22), (16, 144, 247)),               # right hand
        ((27, 29, 31), (29, 123, 243)),               # left foot
        ((28, 30, 32), (193, 182, 255)),              # right foot
        ((9, 10), (205, 235, 255)),                   # mouth
        ((1, 2, 3, 4, 5, 6, 7, 8), (94, 218, 121)),   # eyes and cheeks
    )
    keypoint_colors = {}
    for indices, color in keypoint_groups:
        for index in indices:
            keypoint_colors.setdefault(index, color)

    # The context manager closes the model even if a later step raises.
    with mp_pose.Pose(static_image_mode=True,         # still image, not a video stream
                      model_complexity=2,             # 0 fast/less accurate ... 2 slow/most accurate
                      smooth_landmarks=True,          # landmark smoothing
                      enable_segmentation=True,       # also produce a person segmentation mask
                      min_detection_confidence=0.5,   # detection confidence threshold
                      min_tracking_confidence=0.5) as pose:  # tracking confidence threshold
        # The model expects RGB input while OpenCV images are BGR.
        img_RGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        results = pose.process(img_RGB)

        # Everything below dereferences the landmarks and the mask, so stop
        # here when nothing was detected instead of failing on None.
        if results.pose_landmarks is None or results.segmentation_mask is None:
            print('No Person')
            return
        landmarks = results.pose_landmarks.landmark

        mp_drawing.draw_landmarks(img, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
        print('result')
        print('-----')
        look_image(img)

        # Landmarks in the 3D world coordinate system.
        print('world')
        print('-----')
        mp_drawing.plot_landmarks(results.pose_world_landmarks, mp_pose.POSE_CONNECTIONS)

        # Segmentation: the mask is thresholded at 0.5 to get a boolean
        # person/background map.
        mask = results.segmentation_mask
        print(mask.shape)
        print(img.shape)
        mask = mask > 0.5

        print('-----')
        print('mask')
        print('-----')
        plt.imshow(mask)
        plt.show()

        # Expand the single-channel mask to three channels.
        mask_3 = np.stack((mask, mask, mask), axis=-1)
        MASK_COLOR = [0, 200, 0]
        fg_image = np.zeros(img.shape, dtype=np.uint8)
        fg_image[:] = MASK_COLOR
        # Person kept, everything else filled with MASK_COLOR.
        FG_img = np.where(mask_3, img, fg_image)
        # Background kept, the person filled with MASK_COLOR.
        BG_img = np.where(~mask_3, img, fg_image)

        print('forward')
        print('-----')
        look_image(FG_img)

        print('back')
        print('-----')
        look_image(BG_img)

        # Normalised coordinates of the left elbow.
        left_elbow = landmarks[mp_pose.PoseLandmark.LEFT_ELBOW]
        print(left_elbow)

        # img.shape is (height, width, channels).
        h = img.shape[0]
        w = img.shape[1]

        # Left elbow in pixel coordinates: x first, then y.
        print(left_elbow.x * w)
        print(left_elbow.y * h)

        # Mark the left knee. circle arguments: image, centre, radius, BGR
        # colour, thickness (-1 fills the circle).
        left_knee = landmarks[mp_pose.PoseLandmark.LEFT_KNEE]
        cx = int(left_knee.x * w)
        cy = int(left_knee.y * h)
        img = cv2.circle(img, (cx, cy), 15, (255, 0, 0), -1)
        print('-----')
        print('knee')
        print('-----')
        look_image(img)

        # Overlay a colour-coded dot on every detected keypoint.
        for index, landmark in enumerate(landmarks):
            cx = int(landmark.x * w)
            cy = int(landmark.y * h)
            color = keypoint_colors.get(index, default_color)
            img = cv2.circle(img, (cx, cy), radius, color, -1)

        print('all')
        look_image(img)

HandleWithDataset

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import csv
import os
from Class.BootstrapHelperClass import BootstrapHelper
from Class.FullBodyPoseEmbedderClass import FullBodyPoseEmbedder
from Class.PoseClassifierClass import PoseClassifier


def handle_with_dataset(images_in_folder_path, images_out_folder_path, csvs_out_folder_path):
    """Bootstrap a pose dataset and drop outlier samples.

    Steps, in order: print the input image statistics, bootstrap all images,
    print the output statistics, find outliers with a ``PoseClassifier``,
    analyze and remove them, re-align images with CSVs, and print the final
    output statistics.

    Args:
      images_in_folder_path: Passed to ``BootstrapHelper`` as ``images_in_folder``.
      images_out_folder_path: Passed to ``BootstrapHelper`` as ``images_out_folder``.
      csvs_out_folder_path: Passed to ``BootstrapHelper`` as ``csvs_out_folder``
        and to ``PoseClassifier`` as ``pose_samples_folder``.
    """
    separator = '------'

    helper = BootstrapHelper(
        images_in_folder=images_in_folder_path,
        images_out_folder=images_out_folder_path,
        csvs_out_folder=csvs_out_folder_path,
    )

    # How many images each pose class has.
    print('dataset')
    helper.print_images_in_statistics()
    print(separator)

    # Extract features, with no per-class limit.
    helper.bootstrap(per_pose_class_limit=None)

    # How many images per class had features extracted.
    print('recognize')
    helper.print_images_out_statistics()
    print(separator)

    # Classifier parameters.
    classifier = PoseClassifier(
        pose_samples_folder=csvs_out_folder_path,
        pose_embedder=FullBodyPoseEmbedder(),
        top_n_by_max_distance=30,
        top_n_by_mean_distance=10,
    )

    # Find the anomalous samples.
    found_outliers = classifier.find_pose_sample_outliers()
    print('Number of outliers: ', len(found_outliers))
    print(separator)

    # Inspect the outliers, then remove them.
    helper.analyze_outliers(found_outliers)
    helper.remove_outliers(found_outliers)

    print(separator)
    helper.align_images_and_csvs(print_removed_items=False)
    print('last')
    helper.print_images_out_statistics()


def dump_for_the_app(path, pose_samples_csv_path='squat_csvs_out_basic.csv'):
    """Merge the per-class sample CSVs of a folder into one CSV file.

    Every ``<class_name>.csv`` file in ``path`` is read and each of its rows
    is written to the output with the class name inserted as second column,
    i.e. ``sample_00001,x1,y1,...`` becomes ``sample_00001,<class_name>,x1,y1,...``.
    Files are processed in sorted name order so the output is deterministic.

    Args:
      path: Folder containing one CSV file per pose class.
      pose_samples_csv_path: Path of the merged CSV to write. Defaults to
        'squat_csvs_out_basic.csv' in the current working directory.
    """
    file_suffix = '.csv'
    file_separator = ','
    output_abspath = os.path.abspath(pose_samples_csv_path)

    # Each file in the folder represents one pose class. Skip the output file
    # itself in case it lives in the same folder (e.g. from a previous run),
    # otherwise it would be read while it is being rewritten.
    file_names = sorted(
        name for name in os.listdir(path)
        if name.endswith(file_suffix)
        and os.path.abspath(os.path.join(path, name)) != output_abspath
    )

    # The csv module requires files opened with newline=''; without it the
    # writer emits an extra blank line after every row on Windows.
    with open(pose_samples_csv_path, 'w', newline='') as csv_out:
        csv_out_writer = csv.writer(csv_out, delimiter=file_separator, quoting=csv.QUOTE_MINIMAL)
        for file_name in file_names:
            # Use file name (without the suffix) as pose class name.
            class_name = file_name[:-len(file_suffix)]

            with open(os.path.join(path, file_name), newline='') as csv_in:
                for row in csv.reader(csv_in, delimiter=file_separator):
                    # A blank input line yields an empty row; writing it would
                    # produce a bogus sample consisting of the class name only.
                    if not row:
                        continue
                    row.insert(1, class_name)
                    csv_out_writer.writerow(row)

Camera

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import cv2
import mediapipe as mp
from Function.ProcessFrame import process_frame


def camera():
    """Show a live camera preview with pose landmarks drawn on each frame.

    Frames are read until the capture fails or the user presses 'q' or Esc
    (with an English keyboard layout active). The capture device, the pose
    model and the preview window are released even if processing a frame
    raises.
    """
    mp_pose = mp.solutions.pose

    # NOTE(review): the capture is created for device 1 and then immediately
    # re-opened on device 0, so frames come from device 0. Kept as in the
    # original; presumably a platform workaround - confirm before simplifying.
    cap = cv2.VideoCapture(1)
    cap.open(0)

    try:
        with mp_pose.Pose(static_image_mode=False,        # continuous video frames
                          model_complexity=2,             # 0 fast/less accurate ... 2 slow/most accurate
                          smooth_landmarks=True,          # landmark smoothing
                          min_detection_confidence=0.5,   # detection confidence threshold
                          min_tracking_confidence=0.5) as pose:  # tracking confidence threshold
            # Loop until a break is triggered.
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break

                # Annotate the frame with landmarks and FPS.
                frame = process_frame(img=frame, pose=pose)

                # Show the processed three-channel image.
                cv2.imshow('my_window', frame)

                # Quit on 'q' or Esc.
                if cv2.waitKey(1) in [ord('q'), 27]:
                    break
    finally:
        # Release the camera and close the preview window on every exit path.
        cap.release()
        cv2.destroyAllWindows()

Application

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import numpy as np
import cv2

from Class.FullBodyPoseEmbedderClass import FullBodyPoseEmbedder
from Class.PoseClassifierClass import PoseClassifier
from Class.EMADictSmoothingClass import EMADictSmoothing
from Class.PoseClassificationVisualizerClass import PoseClassificationVisualizer
from Class.RepetitionCounterClass import RepetitionCounter

from mediapipe.python.solutions import drawing_utils as mp_drawing
from mediapipe.python.solutions import pose as mp_pose

from Function.Show import show_image, look_image
from Function.ProcessFrame import process_frame


def application(datasource, v_path, c_path, key, enter_threshold, exit_threshold, frames, is_debug):
    """Count repetitions of a pose class on a video file or a live camera.

    Every frame is run through the pose tracker, classified, smoothed and fed
    to the repetition counter; the classification plot and the counter are
    rendered onto the frame, which is then annotated by ``process_frame`` and
    shown in the 'camera' window. The loop ends when no frame can be read or
    when 'q' or Esc is pressed (with an English keyboard layout active).

    Args:
      datasource: 'video' or 'camera'.
      v_path: Video file path; only used when ``datasource`` is 'video'.
      c_path: Folder with the pose sample CSVs, passed to ``PoseClassifier``.
      key: Name of the pose class to count and plot.
      enter_threshold: Confidence above which the pose counts as entered.
      exit_threshold: Confidence below which the pose counts as exited.
      frames: Used as the plot's x-axis maximum when ``datasource`` is
        'camera'; for 'video' the frame count of the file is used instead.
      is_debug: When true, print and display intermediate images per frame.

    Raises:
      ValueError: If ``datasource`` is neither 'video' nor 'camera'.
    """
    if datasource == 'video':
        cap = cv2.VideoCapture(v_path)
        # CAP_PROP_FRAME_COUNT: number of frames in the video (fps * duration).
        cap_n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    elif datasource == 'camera':
        # NOTE(review): created for device 1, then re-opened on device 0, so
        # frames come from device 0. Kept as in the original - confirm intent.
        cap = cv2.VideoCapture(1)
        cap.open(0)
        cap_n_frames = frames
    else:
        # Previously an unknown source fell through and failed later with a
        # NameError on `cap`.
        raise ValueError("datasource must be 'video' or 'camera', got {!r}".format(datasource))

    try:
        print('cap‘s frames is ', cap_n_frames)

        class_name = key

        # Initialize embedder.
        pose_embedder = FullBodyPoseEmbedder()

        # Initialize classifier.
        # Check that you are using the same parameters as during bootstrapping.
        pose_classifier = PoseClassifier(
            pose_samples_folder=c_path,
            pose_embedder=pose_embedder,
            top_n_by_max_distance=30,
            top_n_by_mean_distance=10
        )

        # Initialize EMA smoothing.
        pose_classification_filter = EMADictSmoothing(
            window_size=10,
            alpha=0.2)

        # The two thresholds of the counted pose.
        repetition_counter = RepetitionCounter(
            class_name=class_name,
            enter_threshold=enter_threshold,
            exit_threshold=exit_threshold)

        # Initialize renderer.
        pose_classification_visualizer = PoseClassificationVisualizer(
            class_name=class_name,
            plot_x_max=cap_n_frames,
            # Graphic looks nicer if it's the same as `top_n_by_mean_distance`.
            plot_y_max=10)  # height of the graph's y axis

        # The context manager closes the pose model on every exit path.
        with mp_pose.Pose() as pose_tracker:
            # Loop until a break is triggered.
            while True:
                success, input_frame = cap.read()
                if not success:
                    break

                # Quit on 'q' or Esc.
                if cv2.waitKey(1) in [ord('q'), 27]:
                    break

                # Run pose tracker (it expects RGB, OpenCV delivers BGR).
                input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)

                if is_debug:
                    print('STEP I')
                    print('Image Type is ', type(input_frame))  # <class 'numpy.ndarray'>
                    print('Image Shape is ', input_frame.shape)
                    show_image(input_frame)
                    look_image(input_frame)
                    print('-----')

                result = pose_tracker.process(image=input_frame)
                pose_landmarks = result.pose_landmarks

                # Draw pose prediction on a copy so the input stays untouched.
                output_frame = input_frame.copy()
                if pose_landmarks is not None:
                    mp_drawing.draw_landmarks(image=output_frame, landmark_list=pose_landmarks,
                                              connections=mp_pose.POSE_CONNECTIONS)

                if is_debug:
                    print('STEP II')
                    print('Image Type is ', type(output_frame))  # <class 'numpy.ndarray'>
                    print('Image Shape is ', output_frame.shape)
                    show_image(output_frame)
                    look_image(output_frame)
                    print('-----')

                if pose_landmarks is not None:
                    # Get landmarks, scaled from normalised to pixel units
                    # (z is scaled by the frame width as well).
                    frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
                    pose_landmarks = np.array([[lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                                               for lmk in pose_landmarks.landmark], dtype=np.float32)
                    assert pose_landmarks.shape == (33, 3), 'Unexpected landmarks shape: {}'.format(pose_landmarks.shape)

                    # Classify the pose on the current frame.
                    pose_classification = pose_classifier(pose_landmarks)

                    # Smooth classification using EMA.
                    pose_classification_filtered = pose_classification_filter(pose_classification)

                    # Count repetitions.
                    repetitions_count = repetition_counter(pose_classification_filtered)
                else:
                    # No pose => no classification on current frame.
                    pose_classification = None

                    # Still add empty classification to the filter to maintain
                    # correct smoothing for future frames; its result is not
                    # used for this frame.
                    pose_classification_filter(dict())
                    pose_classification_filtered = None

                    # Don't update the counter presuming that person is
                    # 'frozen'. Just take the latest repetitions count.
                    repetitions_count = repetition_counter.n_repeats

                # Draw classification plot and repetition counter.
                output_frame = pose_classification_visualizer(
                    frame=output_frame,
                    pose_classification=pose_classification,
                    pose_classification_filtered=pose_classification_filtered,
                    repetitions_count=repetitions_count)

                if is_debug:
                    print('STEP III')
                    print('Image Type is ', type(output_frame))  # <class 'PIL.Image.Image'>
                    print('Image Shape is ', np.array(output_frame).shape)  # PIL images have no .shape
                    show_image(output_frame)
                    # look_image cannot open the PIL image format, so it is skipped here.
                    print('-----')

                # Convert the PIL image back to an OpenCV BGR array.
                output_frame = cv2.cvtColor(np.asarray(output_frame), cv2.COLOR_RGB2BGR)

                if is_debug:
                    print('STEP IV')
                    print(type(output_frame))  # <class 'numpy.ndarray'>
                    print(output_frame.shape)
                    show_image(output_frame)
                    look_image(output_frame)
                    print('-----')

                output_frame = process_frame(img=output_frame, pose=pose_tracker)

                if is_debug:
                    print('STEP V')
                    print(type(output_frame))  # <class 'numpy.ndarray'>
                    print(output_frame.shape)
                    show_image(output_frame)
                    look_image(output_frame)
                    print('-----')

                # Show the processed three-channel image.
                cv2.imshow('camera', output_frame)
    finally:
        # Release the capture and close the window on every exit path.
        cap.release()
        cv2.destroyAllWindows()

主函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from Function.HandleWithDataset import handle_with_dataset, dump_for_the_app
from Function.Application import application
from Function.Camera import camera
from Function.Image import image

if __name__ == '__main__':

    dataset_path = '/Users/bakako/Downloads/dataset'        # dataset folder
    images_out_path = '/Users/bakako/Downloads/images_out'  # images after keypoint detection
    csvs_out_path = '/Users/bakako/Downloads/csvs_out'      # coordinates after keypoint detection
    video_path = '/Users/bakako/Downloads/squat-4.mp4'      # video
    image_path = '/Users/bakako/Downloads/pose.jpg'         # image
    key = 'down'
    is_debug = False

    # YOLO does object detection, not keypoint detection.
    def switch(num):
        """Run the demo selected by ``num``; print 'Error Number' if unknown.

        Modes:
          0: recognise a single image
          1: open the camera
          2: process (train on) the dataset
          3: count repetitions from the camera
          4: count repetitions from a video

        ``application`` arguments used by modes 3 and 4:
          datasource: camera or video
          v_path: video path
          c_path: coordinates folder
          key: pose class name
          enter_threshold: confidence threshold for entering the pose
          exit_threshold: confidence threshold for falling back out of it
          frames: frame rate * duration
          is_debug: debug flag
        """
        actions = {
            0: lambda: image(path=image_path),
            1: lambda: camera(),
            # dump_for_the_app(csvs_out_path) could follow the dataset step.
            2: lambda: handle_with_dataset(dataset_path, images_out_path, csvs_out_path),
            3: lambda: application(datasource='camera', v_path=None, c_path=csvs_out_path,
                                   key=key, enter_threshold=4, exit_threshold=2, frames=300, is_debug=False),
            4: lambda: application(datasource='video', v_path=video_path, c_path=csvs_out_path,
                                   key=key, enter_threshold=4, exit_threshold=2, frames=None, is_debug=False),
        }
        action = actions.get(num)
        if action is None:
            print('Error Number')
        else:
            action()

    switch(0)

Mediapipe 实现姿态识别
https://wonderhoi.com/2023/09/08/Mediapipe-实现姿态识别/
作者
wonderhoi
发布于
2023年9月8日
许可协议