darasb committed
Commit d686ee3 · verified · 1 Parent(s): f872df4

Upload 3 files

Files changed (3)
  1. model_epoch_36.pth +3 -0
  2. train.py +263 -0
  3. utils.py +134 -0
model_epoch_36.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a670da1c1ba61f510c1c5957ce863be26a12837e9f96f320636c92a43eee83ad
+ size 20397970
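model_epoch_36.pth is a Git LFS pointer; the roughly 20 MB state dict itself is fetched by git lfs on checkout. A minimal restore sketch, assuming the TextureContrastClassifier class from train.py below has been copied into its own importable module (hypothetically named model_def here, since importing train.py directly would start training) and the (128, 256) input shape used there:

import torch

from model_def import TextureContrastClassifier  # hypothetical module holding the class from train.py

model = TextureContrastClassifier(input_shape=(128, 256))
state_dict = torch.load('model_epoch_36.pth', map_location='cpu')  # the LFS-tracked file above
model.load_state_dict(state_dict)
model.eval()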
train.py ADDED
@@ -0,0 +1,263 @@
+ # remember to run preprocess.py before training
+ # preprocessing while training is not as efficient
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import MultiheadAttention
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader, random_split
+ import json
+ import time
+ import os
+ import h5py
+ import numpy as np
+ from tqdm import tqdm
+
+ class AttentionBlock(nn.Module):
+     def __init__(self, input_dim, num_heads, key_dim, ff_dim, rate=0.1):
+         super(AttentionBlock, self).__init__()
+         self.multihead_attn = MultiheadAttention(embed_dim=input_dim, num_heads=num_heads)
+         self.dropout1 = nn.Dropout(rate)
+         self.layer_norm1 = nn.LayerNorm(input_dim, eps=1e-6)
+
+         self.ffn = nn.Sequential(
+             nn.Linear(input_dim, ff_dim),
+             nn.ReLU(),
+             nn.Dropout(rate),
+             nn.Linear(ff_dim, input_dim),
+             nn.Dropout(rate)
+         )
+         self.layer_norm2 = nn.LayerNorm(input_dim, eps=1e-6)
+
+     def forward(self, x):
+         attn_output, _ = self.multihead_attn(x, x, x)
+         attn_output = self.dropout1(attn_output)
+         out1 = self.layer_norm1(x + attn_output)
+
+         ffn_output = self.ffn(out1)
+         out2 = self.layer_norm2(out1 + ffn_output)
+
+         return out2
+
+ class TextureContrastClassifier(nn.Module):
+     def __init__(self, input_shape, num_heads=4, key_dim=64, ff_dim=256, rate=0.5):
+         super(TextureContrastClassifier, self).__init__()
+         input_dim = input_shape[1]  # assuming the input shape is (seq_len, feature_dim)
+
+         self.rich_texture_attention = AttentionBlock(input_dim, num_heads, key_dim, ff_dim, rate)
+         self.poor_texture_attention = AttentionBlock(input_dim, num_heads, key_dim, ff_dim, rate)
+
+         self.rich_texture_dense = nn.Sequential(
+             nn.Linear(input_dim, 128),
+             nn.ReLU(),
+             nn.Dropout(rate)
+         )
+
+         self.poor_texture_dense = nn.Sequential(
+             nn.Linear(input_dim, 128),
+             nn.ReLU(),
+             nn.Dropout(rate)
+         )
+
+         self.fc = nn.Sequential(
+             nn.Flatten(),
+             nn.Linear(input_shape[0] * 128, 256),
+             nn.ReLU(),
+             nn.Dropout(rate),
+             nn.Linear(256, 128),
+             nn.ReLU(),
+             nn.Dropout(rate),
+             nn.Linear(128, 64),
+             nn.ReLU(),
+             nn.Dropout(rate),
+             nn.Linear(64, 32),
+             nn.ReLU(),
+             nn.Dropout(rate),
+             nn.Linear(32, 16),
+             nn.ReLU(),
+             nn.Dropout(rate),
+             nn.Linear(16, 1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, rich_texture, poor_texture):
+         rich_texture = self.rich_texture_attention(rich_texture)
+         rich_texture = self.rich_texture_dense(rich_texture)
+
+         poor_texture = self.poor_texture_attention(poor_texture)
+         poor_texture = self.poor_texture_dense(poor_texture)
+
+         difference = rich_texture - poor_texture
+         output = self.fc(difference)
+
+         return output
+
+ def load_and_split_data(h5_dir, train_ratio=0.8, max_num=40):
+     # NOTE: only the first 60 .h5 files in the directory are read; max_num is currently unused
+     train_rich, train_poor, train_labels = [], [], []
+     test_rich, test_poor, test_labels = [], [], []
+
+     for file_name in tqdm(os.listdir(h5_dir)[:60]):
+         if file_name.endswith('.h5'):
+             file_path = os.path.join(h5_dir, file_name)
+             try:
+                 with h5py.File(file_path, 'r') as h5f:
+                     rich = h5f['rich'][:]
+                     poor = h5f['poor'][:]
+                     labels = h5f['labels'][:]
+
+                     dataset_size = len(labels)
+                     train_size = int(train_ratio * dataset_size)
+                     indices = np.random.permutation(dataset_size)
+                     train_indices = indices[:train_size]
+                     test_indices = indices[train_size:]
+
+                     train_rich.append(rich[train_indices])
+                     train_poor.append(poor[train_indices])
+                     train_labels.append(labels[train_indices])
+
+                     test_rich.append(rich[test_indices])
+                     test_poor.append(poor[test_indices])
+                     test_labels.append(labels[test_indices])
+
+             except Exception as e:
+                 print(f"Error processing {file_name}: {e}")
+
+     train_rich = np.concatenate(train_rich, axis=0)
+     train_poor = np.concatenate(train_poor, axis=0)
+     train_labels = np.concatenate(train_labels, axis=0)
+
+     test_rich = np.concatenate(test_rich, axis=0)
+     test_poor = np.concatenate(test_poor, axis=0)
+     test_labels = np.concatenate(test_labels, axis=0)
+
+     return train_rich, train_poor, train_labels, test_rich, test_poor, test_labels
+
+ class TextureDataset(Dataset):
+     def __init__(self, rich, poor, labels):
+         self.rich = rich
+         self.poor = poor
+         self.labels = labels
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, idx):
+         rich = torch.tensor(self.rich[idx], dtype=torch.float32)
+         poor = torch.tensor(self.poor[idx], dtype=torch.float32)
+         label = torch.tensor(self.labels[idx], dtype=torch.float32)
+         return rich, poor, label
+
+ def validate(model, test_loader, criterion, device):
+     model.eval()
+     val_loss = 0.0
+     correct = 0
+     total = 0
+
+     with torch.no_grad():
+         for rich, poor, labels in test_loader:
+             rich, poor, labels = rich.to(device), poor.to(device), labels.to(device)
+
+             outputs = model(rich, poor)
+             outputs = outputs.squeeze()
+
+             loss = criterion(outputs, labels)
+             val_loss += loss.item()
+
+             predicted = (outputs > 0.5).float()
+             total += labels.size(0)
+             correct += (predicted == labels).sum().item()
+
+     val_loss /= len(test_loader)
+     val_accuracy = correct / total
+     print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
+     return val_loss, val_accuracy
+
+
+
+ h5_dir = '/content/drive/MyDrive/h5saves'
+ train_rich, train_poor, train_labels, test_rich, test_poor, test_labels = load_and_split_data(h5_dir, train_ratio=0.8)
+ print(f"Training data: {len(train_labels)} samples")
+ print(f"Testing data: {len(test_labels)} samples")
+ train_dataset = TextureDataset(train_rich, train_poor, train_labels)
+ test_dataset = TextureDataset(test_rich, test_poor, test_labels)
+ batch_size = 2048
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+
+ input_shape = (128, 256)
+ model = TextureContrastClassifier(input_shape)
+ criterion = nn.BCELoss()
+ optimizer = optim.Adam(model.parameters(), lr=0.0001)
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model.to(device)
+
+ history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}
+ save_dir = '/content/drive/MyDrive/model_checkpoints'
+ if not os.path.exists(save_dir):
+     os.makedirs(save_dir)
+ num_epochs = 100
+
+
+
+ for epoch in range(num_epochs):
+     model.train()
+     running_loss = 0.0
+     correct = 0
+     total = 0
+
+     batch_loss = 0.0
+
+     for batch_idx, (rich, poor, labels) in enumerate(train_loader):
+         rich, poor, labels = rich.to(device), poor.to(device), labels.to(device)
+
+         optimizer.zero_grad()
+
+         outputs = model(rich, poor)
+         outputs = outputs.squeeze()
+
+         loss = criterion(outputs, labels)
+         loss.backward()
+         optimizer.step()
+
+         running_loss += loss.item()
+         batch_loss += loss.item()
+
+         predicted = (outputs > 0.5).float()
+         total += labels.size(0)
+         correct += (predicted == labels).sum().item()
+
+         if (batch_idx + 1) % 5 == 0:
+             print(f'\rEpoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}], Loss: {batch_loss / 5:.4f}, Accuracy: {correct / total:.2f}', end='')
+             batch_loss = 0.0
+
+     avg_train_loss = running_loss / len(train_loader)
+     train_accuracy = correct / total
+
+     val_loss, val_accuracy = validate(model, test_loader, criterion, device)
+
+     history['train_loss'].append(avg_train_loss)
+     history['val_loss'].append(val_loss)
+     history['val_accuracy'].append(val_accuracy)
+     history['train_accuracy'].append(train_accuracy)
+
+     scheduler.step(val_loss)
+
+     checkpoint_path = os.path.join(save_dir, f'model_epoch_{epoch+1}.pth')
+     torch.save(model.state_dict(), checkpoint_path)
+     print(f'\nModel checkpoint saved for epoch {epoch+1}')
+
+     print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
+
+ history_path = os.path.join(save_dir, 'training_history.json')
+ with open(history_path, 'w') as f:
+     json.dump(history, f)
+
+ print('Finished Training')
+ print(f'Training history saved at {history_path}')
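Each epoch ends with a checkpoint save, and the per-epoch loss/accuracy curves are written to training_history.json once the loop finishes. A small sketch for inspecting those curves afterwards, assuming matplotlib is available and the Drive path used above:

import json
import matplotlib.pyplot as plt

with open('/content/drive/MyDrive/model_checkpoints/training_history.json') as f:
    history = json.load(f)

# Compare training and validation loss per epoch.
plt.plot(history['train_loss'], label='train loss')
plt.plot(history['val_loss'], label='val loss')
plt.xlabel('epoch')
plt.ylabel('BCE loss')
plt.legend()
plt.show()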
utils.py ADDED
@@ -0,0 +1,134 @@
+ import numpy as np
+ import cv2
+ import PIL.Image
+ from scipy.interpolate import griddata
+
+ def RGB2gray(rgb):
+     r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2]
+     gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+     return gray
+
+ def img_to_patches(img: PIL.Image.Image) -> tuple:
+     patch_size = 16
+     img = img.convert('RGB')
+
+     grayscale_imgs = []
+     imgs = []
+     coordinates = []
+
+     for i in range(0, img.height, patch_size):
+         for j in range(0, img.width, patch_size):
+             box = (j, i, j + patch_size, i + patch_size)
+             img_color = np.asarray(img.crop(box))
+             grayscale_image = cv2.cvtColor(src=img_color, code=cv2.COLOR_RGB2GRAY)
+             grayscale_imgs.append(grayscale_image.astype(dtype=np.int32))
+             imgs.append(img_color)
+             normalized_coord = (i + patch_size // 2, j + patch_size // 2)
+             coordinates.append(normalized_coord)
+
+     return grayscale_imgs, imgs, coordinates, (img.height, img.width)
+
+ def get_l1(v):
+     return np.sum(np.abs(v[:, :-1] - v[:, 1:]))
+
+ def get_l2(v):
+     return np.sum(np.abs(v[:-1, :] - v[1:, :]))
+
+ def get_l3l4(v):
+     l3 = np.sum(np.abs(v[:-1, :-1] - v[1:, 1:]))
+     l4 = np.sum(np.abs(v[1:, :-1] - v[:-1, 1:]))
+     return l3 + l4
+
+ def get_pixel_var_degree_for_patch(patch: np.ndarray) -> int:
+     l1 = get_l1(patch)
+     l2 = get_l2(patch)
+     l3l4 = get_l3l4(patch)
+     return l1 + l2 + l3l4
+
+ def get_rich_poor_patches(img: PIL.Image.Image, coloured=True):
+     gray_scale_patches, color_patches, coordinates, img_size = img_to_patches(img)
+     var_with_patch = []
+     for i, patch in enumerate(gray_scale_patches):
+         if coloured:
+             var_with_patch.append((get_pixel_var_degree_for_patch(patch), color_patches[i], coordinates[i]))
+         else:
+             var_with_patch.append((get_pixel_var_degree_for_patch(patch), patch, coordinates[i]))
+
+     var_with_patch.sort(reverse=True, key=lambda x: x[0])
+     mid_point = len(var_with_patch) // 2
+     r_patch = [(patch, coor) for var, patch, coor in var_with_patch[:mid_point]]
+     p_patch = [(patch, coor) for var, patch, coor in var_with_patch[mid_point:]]
+     p_patch.reverse()
+     return r_patch, p_patch, img_size
+
+ def azimuthalAverage(image, center=None):
+     y, x = np.indices(image.shape)
+     if center is None:
+         center = np.array([(x.max() - x.min()) / 2.0, (y.max() - y.min()) / 2.0])
+     r = np.hypot(x - center[0], y - center[1])
+     ind = np.argsort(r.flat)
+     r_sorted = r.flat[ind]
+     i_sorted = image.flat[ind]
+     r_int = r_sorted.astype(int)
+     deltar = r_int[1:] - r_int[:-1]
+     rind = np.where(deltar)[0]
+     nr = rind[1:] - rind[:-1]
+     csim = np.cumsum(i_sorted, dtype=float)
+     tbin = csim[rind[1:]] - csim[rind[:-1]]
+     radial_prof = tbin / nr
+     return radial_prof
+
+ def azimuthal_integral(img, epsilon=1e-8, N=50):
+     if len(img.shape) == 3 and img.shape[2] == 3:
+         img = RGB2gray(img)
+     f = np.fft.fft2(img)
+     fshift = np.fft.fftshift(f)
+     fshift += epsilon
+     magnitude_spectrum = 20 * np.log(np.abs(fshift))
+     psd1D = azimuthalAverage(magnitude_spectrum)
+     points = np.linspace(0, N, num=psd1D.size)
+     xi = np.linspace(0, N, num=N)
+     interpolated = griddata(points, psd1D, xi, method='cubic')
+     interpolated = (interpolated - np.min(interpolated)) / (np.max(interpolated) - np.min(interpolated))
+     return interpolated.astype(np.float32)
+
+ def positional_emb(coor, im_size, N):
+     img_height, img_width = im_size
+     center_y, center_x = coor
+     normalized_y = center_y / img_height
+     normalized_x = center_x / img_width
+     pos_emb = np.zeros(N)
+     indices = np.arange(N)
+     div_term = 10000 ** (2 * (indices // 2) / N)
+     pos_emb[0::2] = np.sin(normalized_y / div_term[0::2]) + np.sin(normalized_x / div_term[0::2])
+     pos_emb[1::2] = np.cos(normalized_y / div_term[1::2]) + np.cos(normalized_x / div_term[1::2])
+     return pos_emb
+
+ def azi_diff(img: PIL.Image.Image, patch_num, N):
+     r, p, im_size = get_rich_poor_patches(img)
+     r_len = len(r)
+     p_len = len(p)
+     patch_emb_r = np.zeros((patch_num, N))
+     patch_emb_p = np.zeros((patch_num, N))
+     positional_emb_r = np.zeros((patch_num, N))
+     positional_emb_p = np.zeros((patch_num, N))
+     coor_r = []
+     coor_p = []
+     if r_len != 0:
+         for idx in range(patch_num):
+             tmp_patch1 = r[idx % r_len][0]
+             tmp_coor1 = r[idx % r_len][1]
+             patch_emb_r[idx] = azimuthal_integral(tmp_patch1, N=N)
+             positional_emb_r[idx] = positional_emb(tmp_coor1, im_size, N)
+             coor_r.append(tmp_coor1)
+     if p_len != 0:
+         for idx in range(patch_num):
+             tmp_patch2 = p[idx % p_len][0]
+             tmp_coor2 = p[idx % p_len][1]
+             patch_emb_p[idx] = azimuthal_integral(tmp_patch2, N=N)
+             positional_emb_p[idx] = positional_emb(tmp_coor2, im_size, N)
+             coor_p.append(tmp_coor2)
+     output = {"total_emb": [patch_emb_r + positional_emb_r / 5, patch_emb_p + positional_emb_p / 5],
+               "positional_emb": [positional_emb_r / 5, positional_emb_p / 5], "coor": [coor_r, coor_p],
+               "image_size": im_size}
+     return output
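Putting the two files together, a rough single-image inference sketch; the patch_num=128 / N=256 choice (to match the (128, 256) input_shape in train.py), the model_def module name, and the file paths are assumptions, not part of this commit:

import PIL.Image
import torch

from utils import azi_diff
from model_def import TextureContrastClassifier  # hypothetical module holding the class from train.py

# Rebuild the classifier and load the uploaded epoch-36 weights.
model = TextureContrastClassifier(input_shape=(128, 256))
model.load_state_dict(torch.load('model_epoch_36.pth', map_location='cpu'))
model.eval()

# azi_diff returns rich/poor patch embeddings, each of shape (patch_num, N).
img = PIL.Image.open('example.png')  # placeholder image path
emb = azi_diff(img, patch_num=128, N=256)
rich = torch.tensor(emb['total_emb'][0], dtype=torch.float32).unsqueeze(0)  # (1, 128, 256)
poor = torch.tensor(emb['total_emb'][1], dtype=torch.float32).unsqueeze(0)

with torch.no_grad():
    score = model(rich, poor).item()  # sigmoid output in [0, 1]
print(score)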