''' Import required packages '''
import wfdb
import pywt
import seaborn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import torch
import torch.utils.data as Data
from torch import nn
''' Load the dataset '''
RATIO = 0.2  # fraction of heartbeats held out as the test set
def denoise(data):
    # Wavelet denoising: 9-level db5 decomposition, threshold the detail
    # coefficients and drop the two highest-frequency bands (cD1, cD2).
    coeffs = pywt.wavedec(data=data, wavelet='db5', level=9)
    cA9, cD9, cD8, cD7, cD6, cD5, cD4, cD3, cD2, cD1 = coeffs
    # Universal threshold sigma * sqrt(2 * ln N), with sigma estimated from the
    # median absolute deviation of the finest detail coefficients.
    threshold = (np.median(np.abs(cD1)) / 0.6745) * (np.sqrt(2 * np.log(len(cD1))))
    cD1.fill(0)
    cD2.fill(0)
    for i in range(1, len(coeffs) - 2):
        coeffs[i] = pywt.threshold(coeffs[i], threshold)
    rdata = pywt.waverec(coeffs=coeffs, wavelet='db5')
    return rdata
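# Quick sanity check of denoise() on a synthetic noisy sine wave (illustration
# only, not part of the original pipeline; _demo_sig is a made-up name). For a
# power-of-two length like 8192 the reconstructed signal should keep the input
# length.
_demo_sig = np.sin(np.linspace(0, 20 * np.pi, 8192)) + 0.1 * np.random.randn(8192)
print(_demo_sig.shape, denoise(_demo_sig).shape)  # lengths should match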
def getDataSet(number, X_data, Y_data):
    ecgClassSet = ['N', 'A', 'V', 'L', 'R']
    # Read the MLII lead of the record and denoise it.
    record = wfdb.rdrecord('C:/mycode/dataset_make/mit-bih-arrhythmia-database-1.0.0/' + number,
                           channel_names=['MLII'])
    data = record.p_signal.flatten()
    rdata = denoise(data=data)
    # Read the annotations: R-peak sample indices and beat symbols.
    annotation = wfdb.rdann('C:/mycode/dataset_make/mit-bih-arrhythmia-database-1.0.0/' + number, 'atr')
    Rlocation = annotation.sample
    Rclass = annotation.symbol
    # Skip the first 10 and last 5 beats to avoid truncated windows.
    start = 10
    end = 5
    i = start
    j = len(annotation.symbol) - end
    while i < j:
        try:
            label = ecgClassSet.index(Rclass[i])
            # One beat = 100 samples before the R peak and 200 samples after it.
            x_train = rdata[Rlocation[i] - 100:Rlocation[i] + 200]
            X_data.append(x_train)
            Y_data.append(label)
            i += 1
        except ValueError:
            # Symbols outside the five target classes are skipped.
            i += 1
    return
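# Note on the beat window (added illustration, not in the original script): the
# MIT-BIH arrhythmia records are sampled at 360 Hz, so the 300-sample window
# (100 before + 200 after the R peak) covers roughly 0.83 s around each beat.
FS = 360  # illustrative constant for the MIT-BIH sampling frequency
print(f'beat window: {300 / FS:.2f} s')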
def loadData():
    numberSet = ['100', '101', '103', '105', '106', '107', '108', '109', '111', '112',
                 '113', '114', '115', '116', '117', '119', '121', '122', '123', '124',
                 '200', '201', '202', '203', '205', '208', '210', '212', '213', '214',
                 '215', '217', '219', '220', '221', '222', '223', '228', '230', '231',
                 '232', '233', '234']
    dataSet = []
    labelSet = []
    for n in numberSet:
        getDataSet(n, dataSet, labelSet)
    # Stack the 300-sample beats with their labels and shuffle them together.
    dataSet = np.array(dataSet).reshape(-1, 300)
    labelSet = np.array(labelSet).reshape(-1, 1)
    train_ds = np.hstack((dataSet, labelSet))
    np.random.shuffle(train_ds)
    X = train_ds[:, :300].reshape(-1, 1, 300)
    Y = train_ds[:, 300]

    # Hold out RATIO of the beats as the test set.
    shuffle_index = np.random.permutation(len(X))
    test_length = int(RATIO * len(shuffle_index))
    test_index = shuffle_index[:test_length]
    train_index = shuffle_index[test_length:]
    X_test, Y_test = X[test_index], Y[test_index]
    X_train, Y_train = X[train_index], Y[train_index]
    return X_train, Y_train, X_test, Y_test
X_train, Y_train, X_test, Y_test = loadData()
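# Sanity check of the loaded data (illustration only): print the array shapes and
# the per-class beat counts; labels 0-4 correspond to ['N', 'A', 'V', 'L', 'R'].
print(X_train.shape, Y_train.shape, X_test.shape, Y_test.shape)
print(dict(zip(*np.unique(Y_train, return_counts=True))))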
''' Data processing '''
train_Data = Data.TensorDataset(torch.Tensor(X_train), torch.Tensor(Y_train))
train_loader = Data.DataLoader(dataset=train_Data, batch_size=128)
test_Data = Data.TensorDataset(torch.Tensor(X_test), torch.Tensor(Y_test))
test_loader = Data.DataLoader(dataset=test_Data, batch_size=128)
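# Optional check (added illustration): the DataLoader should yield inputs of shape
# (batch, 1, 300) and a 1-D batch of labels; _inputs and _labels are throwaway names.
_inputs, _labels = next(iter(train_loader))
print(_inputs.shape, _labels.shape)  # e.g. torch.Size([128, 1, 300]) torch.Size([128])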
''' Build the model '''
class RnnModel(nn.Module):
    def __init__(self):
        super(RnnModel, self).__init__()
        ''' nn.RNN arguments: (input size, hidden size, number of layers) '''
        # batch_first=True so the input is read as (batch, seq_len, features);
        # with the default layout (seq_len, batch, features) the recurrence
        # would run across the batch dimension instead of time.
        self.rnn = nn.RNN(300, 50, 3, nonlinearity='tanh', batch_first=True)
        self.linear = nn.Linear(50, 5)
    def forward(self, x):
        r_out, h_state = self.rnn(x)
        # Classify from the hidden state at the last time step.
        output = self.linear(r_out[:, -1, :])
        return output
model = RnnModel()
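# Quick forward-pass check (illustration only): a dummy batch of 8 beats should
# map to logits of shape (8, 5), one score per class.
with torch.no_grad():
    print(model(torch.randn(8, 1, 300)).shape)  # torch.Size([8, 5])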
''' Set the loss function and optimizer '''
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
''' Train the model '''
EPOCHS = 5
for epoch in range(EPOCHS):
    running_loss = 0
    for i, data in enumerate(train_loader):
        inputs, label = data
        y_predict = model(inputs)
        loss = criterion(y_predict, label.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Evaluate on the held-out test set after each epoch.
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, label = data
            y_pred = model(inputs)
            _, predicted = torch.max(y_pred.data, dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    print(f'Epoch: {epoch + 1}, '
          f'train loss: {running_loss / len(train_loader):.4f}, '
          f'ACC on test: {correct / total:.4f}')
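# After training, the imported confusion_matrix and seaborn can be used to look at
# per-class errors on the test set. This block is a sketch added here, not part of
# the original script; it assumes the five-class label order ['N', 'A', 'V', 'L', 'R'].
all_labels, all_preds = [], []
with torch.no_grad():
    for inputs, label in test_loader:
        _, predicted = torch.max(model(inputs).data, dim=1)
        all_labels.extend(label.long().tolist())
        all_preds.extend(predicted.tolist())
cm = confusion_matrix(all_labels, all_preds)
seaborn.heatmap(cm, annot=True, fmt='d',
                xticklabels=['N', 'A', 'V', 'L', 'R'],
                yticklabels=['N', 'A', 'V', 'L', 'R'])
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()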