= committed on
Commit
391f8df
·
1 Parent(s): 459b89f

add injection

Browse files
Files changed (4) hide show
  1. utils/comp_two_files.py +75 -0
  2. utils/paraProcess.py +80 -0
  3. utils/paraShow.py +163 -0
  4. utils/swin.py +298 -0
utils/comp_two_files.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bitstring import BitArray
2
+ import os
3
+
4
def compare_bits(file1, file2, max_diff_display=20):
    """Compare two files at the bit level.

    :param file1: path of the first file
    :param file2: path of the second file
    :param max_diff_display: cap on how many differing positions are recorded
    :return: dict with bit counts, number of differing bits, recorded
             positions, and length-mismatch info
    """
    # Load both files as bit arrays.
    bits1 = BitArray(filename=file1)
    bits2 = BitArray(filename=file2)

    len1, len2 = len(bits1), len(bits2)

    stats = {
        'total_bits_file1': len1,
        'total_bits_file2': len2,
        'differing_bits': 0,
        'diff_positions': [],
        'bit_length_mismatch': len1 != len2,
    }

    # Walk the overlapping prefix bit by bit.
    for pos in range(min(len1, len2)):
        if bits1[pos] != bits2[pos]:
            stats['differing_bits'] += 1
            # Only remember the first max_diff_display positions.
            if len(stats['diff_positions']) < max_diff_display:
                stats['diff_positions'].append(pos)

    # Record how many bits one file has beyond the other (0 when equal).
    stats['extra_bits'] = abs(len1 - len2)

    return stats
43
+
44
def print_diff_report(diff_stats):
    """Pretty-print the statistics dict produced by compare_bits."""
    size1 = diff_stats['total_bits_file1']
    size2 = diff_stats['total_bits_file2']
    print(f"Bit长度比较:")
    print(f" File1: {size1} bits")
    print(f" File2: {size2} bits")

    # Flag a size mismatch before reporting bit differences.
    if diff_stats['bit_length_mismatch']:
        print(f"\n! 文件长度不一致,相差 {diff_stats['extra_bits']} bits")

    print(f"\n差异bit总数: {diff_stats['differing_bits']}")

    if diff_stats['differing_bits'] > 0:
        positions = diff_stats['diff_positions']
        print(f"\n前 {len(positions)} 个差异位置 (0-based):")
        for pos in positions:
            print(f" Bit位置 {pos}")
59
+
60
if __name__ == "__main__":
    # Usage example: compare the generated payload against its extracted copy.
    # NOTE(review): these relative paths assume the script is run from utils/.
    file1 = "../malwares/generated_malware"
    file2 = "../malwares/generated_malware_extracted"

    # Compare the two files bit by bit.
    diff_stats = compare_bits(file1, file2)

    # Print the human-readable report.
    print_diff_report(diff_stats)

    # Advanced usage: read the diff data directly from the stats dict.
    print("\n高级访问:")
    print(f"总差异bit数: {diff_stats['differing_bits']}")
    if diff_stats['differing_bits'] > 0:
        print(f"第一个差异位置: {diff_stats['diff_positions'][0]}")
utils/paraProcess.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 处理参数,如参数翻转、参数替换等等
3
+ """
4
+ import os
5
+ import torch
6
+ import random
7
+ import struct
8
+ import pandas as pd
9
+ from bitstring import BitArray
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ import numpy as np
12
+ import math
13
+
14
+
15
def split_file(file_path, chunk_size=8):
    """Read a file and split its bits into fixed-size BitArray chunks.

    :param file_path: path of the source file
    :param chunk_size: granularity of each chunk, in bits
    :return: list of BitArray chunks (the last one may be shorter)
    """
    # Load the whole file as one bit array, then slice it.
    bit_data = BitArray(filename=file_path)
    total = len(bit_data)
    return [bit_data[start:start + chunk_size] for start in range(0, total, chunk_size)]
26
+
27
+
28
def merge_file(output_file, chunks):
    """Concatenate BitArray chunks and write them out as a single file.

    :param output_file: destination path
    :param chunks: iterable of BitArray pieces
    :return: None (the merged data is written to disk)
    """
    merged = BitArray()
    for piece in chunks:
        merged.append(piece)

    with open(output_file, 'wb') as out:
        merged.tofile(out)
41
+
42
+
43
+
44
def layer_low_n_bit_fLip(initParaPath, flipParaPath, bit_n, *layers):
    """Flip the low ``bit_n`` bits of every value in the given layers.

    Loads the checkpoint at ``initParaPath``, XORs the low ``bit_n`` bits of
    each listed tensor (zero-dim scalars are skipped) and saves the result
    to ``flipParaPath``.

    :param initParaPath: path of the original parameter .pth
    :param flipParaPath: path where the flipped parameters are saved
    :param bit_n: number of low bits to flip in each 32-bit value
    :param layers: state-dict keys to process
    :return: None
    """
    para = torch.load(initParaPath)

    # Bug fix: the previous code XORed with the value ``bit_n`` itself, which
    # only toggles the bits set in bit_n's binary form. Flipping the low n
    # bits requires the mask (1 << n) - 1 (this matches the correct mask used
    # by swin.layer_low_n_bit_fLip in utils/swin.py).
    mask = (1 << bit_n) - 1

    for layer in layers:  # every requested layer
        if len(para[layer].data.shape) < 1:
            continue  # skip zero-dim (scalar) tensors
        layerTensor = para[layer].data
        # Reinterpret the float32 storage as int32 so bitwise ops are legal.
        layerTensor_initView = layerTensor.view(torch.int32)
        layerTensor_embedded_int = layerTensor_initView ^ mask
        layerTensor_embedded = layerTensor_embedded_int.view(torch.float32)

        para[layer].data = layerTensor_embedded

    torch.save(para, flipParaPath)
    return
70
+
71
+
72
+
73
+
74
+
75
+
76
+
77
+
78
if __name__ == "__main__":
    # Placeholder smoke test: nothing is exercised yet.

    print("Test Done")
utils/paraShow.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 展示参数分布、属性
3
+ """
4
+ import os
5
+ import torch
6
+ import random
7
+ import struct
8
+ import pandas as pd
9
+ from bitstring import BitArray
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ import numpy as np
12
+ import math
13
+
14
+ from numpy.ma.core import shape
15
+
16
+
17
def showDif(file1, file2):
    """Compare an extracted malware file against the original, bit by bit.

    Prints every differing bit position and the total count.

    :param file1: path of the original file
    :param file2: path of the extracted file
    :return: number of differing bits, or None when the sizes differ
    """
    malwareStr1 = BitArray(filename=file1).bin
    malwareStr2 = BitArray(filename=file2).bin
    diffNum = 0
    # Bug fix: the old code compared the full bit strings here, so ANY content
    # difference returned early and the counting loop below was dead code; it
    # also printed the bit string itself (the same variable twice) instead of
    # the two lengths. Only a length mismatch should abort.
    if len(malwareStr1) != len(malwareStr2):
        print("两个恶意软件大小不同,第一个Bit数为:", len(malwareStr1), " 第二个Bit数为:", len(malwareStr2))
        return
    for i in range(len(malwareStr1)):
        if malwareStr1[i] != malwareStr2[i]:  # report every differing bit position
            print("pos:", i, "initBit:", malwareStr1[i], "extractedBit:", malwareStr2[i])
            diffNum += 1
    print("different bit Num between the two files: ", diffNum)
    return diffNum
36
+
37
+
38
def get_file_bit_num(file_path):
    """Return the size of the file at ``file_path`` in bits.

    Bug fix: ``os.path.getSize`` does not exist (it raised AttributeError on
    every call); the correct function is ``os.path.getsize``.

    :param file_path: path of the file to measure
    :return: file size in bits (bytes * 8)
    """
    return os.path.getsize(file_path) * 8
43
+
44
+
45
def getExpEmbeddSize(initParaPath, layers, interval=1, correct=1):
    """Return the maximum exponent-field embedding capacity per layer, in bytes.

    :param initParaPath: path of the checkpoint to inspect
    :param layers: list of state-dict keys to measure
    :param interval: embed into one value out of every ``interval``
    :param correct: redundancy factor (values per embedded bit)
    :return: list of per-layer capacities, in bytes
    """
    para = torch.load(initParaPath, map_location=torch.device("cpu"))
    capacities = []
    # One bit per selected value, folded down to whole bytes.
    divisor = interval * correct * 8
    for layer in layers:
        flat = para[layer].data.flatten()
        capacities.append(len(flat) // divisor)
    return capacities
63
+
64
+
65
def generate_file_with_bits(file_path, num_bits):
    """Generate ``num_bits`` random bits and write them to ``file_path``.

    The payload is padded with zero bits up to a whole byte.

    :param file_path: destination path
    :param num_bits: number of meaningful random bits wanted
    :return: None
    """
    # Bytes needed to hold num_bits, rounding up.
    num_bytes = (num_bits + 7) // 8
    print("Byte Num:", num_bytes)

    # Fill the buffer with random bytes.
    byte_array = bytearray(random.getrandbits(8) for _ in range(num_bytes))

    # If the last byte is only partially used, clear its unused bits.
    if num_bits % 8 != 0:
        last_byte_bits = num_bits % 8
        # Bug fix: bits are consumed MSB-first (BitArray's .bin string), so
        # the unused padding bits of the last byte are the TRAILING (low)
        # ones. The old mask (1 << r) - 1 kept the low bits and zeroed the
        # leading meaningful bits instead.
        mask = (0xFF << (8 - last_byte_bits)) & 0xFF
        byte_array[-1] &= mask

    # Write the buffer out.
    with open(file_path, 'wb') as f:
        f.write(byte_array)

    print(f"File '{file_path}' generated with {num_bits} bits.")
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
if __name__ == "__main__":
    """测试路径"""
    # Checkpoint paths for the four model families under inspection.
    # NOTE(review): relative paths — assumes the script runs from utils/.
    path_swin = "../parameters/classification/swin/swin_b.pth"
    path_yolo = "../parameters/detection/yolov10/yolov10b.pt"
    path_rt = "../parameters/detection/rt_dert/rt.pth"
    path_sam = "../parameters/segmentation/samv2/sam.pth"




    """测试swin"""
    # swin_keys = torch.load(path_swin).keys()
    # print(type(torch.load(path_swin)))




    # NOTE(review): this unpickles a full model object, so the YOLO model
    # classes must be importable in this environment — confirm.
    pth_keys = torch.load(path_yolo)
    print(pth_keys.keys())
    print(type(pth_keys['model'].model.named_modules()))
    print(pth_keys['model'].model.named_modules())
    # for idx, layer in enumerate(pth_keys['model'].model):
    #     print(f"层 {idx}: {layer}")


    # Walk every submodule of the model, including nested ones.
    for idx, (name, module) in enumerate(pth_keys['model'].model.named_modules()):
        print(f"模块索引 {idx} - 名称 {name}: {module}")








    # print((pth_keys['model'].model[23].cv2[0]))
    # print((pth_keys['model']))
    # print(len(pth_keys['model'].model.conv.weight))
    # print(type(pth_keys['model'].model[1]))
    # print(type(pth_keys['model'].model[0].conv.weight))
    # print(type(pth_keys['model'].model[0].conv.weight.data))
    # print(shape(pth_keys['model'].model[0].conv.weight.data))
    # print(pth_keys['model'].model[0].conv.weight.data[0][0][0][0].dtype)



    # path3 = "../parameters/detection/rt_dert/rt.pth"
    # a = torch.load(path3)
    # print(a['ema'].keys())
    # print(a['ema']['module']['backbone.conv1.conv1_1.conv.weight'][0][0][0][0].dtype)
    #
    # path4 = "../parameters/segmentation/samv2/sam.pth"
    # b = torch.load(path4)
    # print(b.keys())
    # print(b['image_encoder.neck.0.weight'][0][0][0][0].dtype)




    # print(get_file_bit_num(path1))

    # print(pth_keys['train_args'])


    print("Test Done")
163
+
utils/swin.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ from bitstring import BitArray
5
+
6
+ # from utils import generate_file_with_bits, showDif
7
+
8
+
9
+
10
def generate_file_with_bits(file_path, num_bits):
    """Generate ``num_bits`` random bits and write them to ``file_path``.

    The payload is padded with zero bits up to a whole byte.

    :param file_path: destination path
    :param num_bits: number of meaningful random bits wanted
    :return: None
    """
    # Bytes needed to hold num_bits, rounding up.
    num_bytes = (num_bits + 7) // 8
    print("Byte Num:", num_bytes)

    # Fill the buffer with random bytes.
    byte_array = bytearray(random.getrandbits(8) for _ in range(num_bytes))

    # If the last byte is only partially used, clear its unused bits.
    if num_bits % 8 != 0:
        last_byte_bits = num_bits % 8
        # Bug fix: bits are consumed MSB-first (BitArray's .bin string), so
        # the unused padding bits of the last byte are the TRAILING (low)
        # ones. The old mask (1 << r) - 1 kept the low bits and zeroed the
        # leading meaningful bits instead.
        mask = (0xFF << (8 - last_byte_bits)) & 0xFF
        byte_array[-1] &= mask

    # Write the buffer out.
    with open(file_path, 'wb') as f:
        f.write(byte_array)

    print(f"File '{file_path}' generated with {num_bits} bits.")
36
+
37
+
38
+
39
+
40
class swin:
    """Wrapper around a model checkpoint (.pth) for bit-level payload
    injection and extraction experiments on its float32 tensors."""

    def __init__(self, path):
        """
        Initialize with the path of the parameter checkpoint.
        :param path: path of the .pth file
        """
        self.path = path
        return


    def get_pth_keys(self):
        """
        Return all state-dict keys of the checkpoint.
        :return: dict_keys of layer names
        """
        return torch.load(self.path, map_location=torch.device("cpu")).keys()


    def get_pth_keys_float32(self):
        """
        Return only the state-dict keys whose tensors are float32.
        :return: list of layer names
        """
        para = torch.load(self.path, map_location=torch.device("cpu"))
        layers = []
        for key in para.keys():
            if para[key].data.dtype == torch.float32:
                layers.append(key)
        return layers


    def get_file_bit_num(self):
        """
        Return the checkpoint file size in bits.
        Bug fix: ``os.path.getSize`` does not exist (AttributeError); the
        correct function is ``os.path.getsize``.
        :return: bit size (bytes * 8)
        """
        return os.path.getsize(self.path) * 8


    def layer_low_n_bit_fLip(self, flip_path, bit_n, *layers):
        """
        Flip the low ``bit_n`` bits of the given layers and save a new .pth.
        :param flip_path: path for the flipped checkpoint
        :param bit_n: how many low bits to flip per 32-bit value
        :param layers: state-dict keys to process
        :return: None
        """
        para = torch.load(self.path)
        mask = (1 << bit_n) - 1  # low-n-bit mask
        for layer in layers:
            if len(para[layer].data.shape) < 1:
                continue  # skip zero-dim scalars
            layer_tensor = para[layer].data
            # Only float32 tensors are reinterpreted; other dtypes are left alone.
            if layer_tensor.dtype == torch.float32:
                # Reinterpret the float storage as int32 so bit ops are legal.
                layer_tensor_view = layer_tensor.view(torch.int32)
                layer_tensor_view ^= mask  # flip the low bits in place
                para[layer].data = layer_tensor_view.view(torch.float32)
        torch.save(para, flip_path)


    def all_layers_low_n_bit_fLip(self, flip_path, bit_n):
        """
        Flip the low ``bit_n`` bits of every float32 layer.
        :param flip_path: path for the flipped checkpoint
        :param bit_n: how many low bits to flip
        :return: None
        """
        self.layer_low_n_bit_fLip(flip_path, bit_n, *self.get_pth_keys_float32())


    def get_layers_low_n_bit_size(self, layers, bit_n):
        """
        usage:
            size, size_list = agent.get_layers_low_n_bit_size(agent.get_pth_keys(), 16)
        Capacity available when hiding ``bit_n`` bits in each float32 value.
        Non-float32 layers are skipped (they contribute no list entry).
        :param layers: list of state-dict keys
        :param bit_n: bits hidden per value
        :return: (total capacity in bits, per-layer capacities in bits)
        """
        para = torch.load(self.path, map_location=torch.device("cpu"))
        size_with_layers_list = []
        total_size = 0
        for layer in layers:

            if para[layer].data.dtype != torch.float32:
                continue  # only float32 values are used for embedding

            paraTensor = para[layer].data
            paraTensor_flat = paraTensor.flatten()
            layerSize = len(paraTensor_flat) * bit_n
            size_with_layers_list.append(layerSize)
            total_size += layerSize
        return total_size, size_with_layers_list


    def all_layers_low_n_bit_inject(self, inject_path, bit_n, malware, malware_len):
        """
        Embed a payload file into the low ``bit_n`` bits of every float32 layer.
        :param inject_path: output path of the injected checkpoint
        :param bit_n: low bits used per 32-bit value
        :param malware: path of the payload file to embed
        :param malware_len: payload length, in bits
        :return: None
        """
        paras = torch.load(self.path, map_location=torch.device("cpu"))
        malware_str = BitArray(filename=malware).bin
        mal_index = 0  # start of the next payload bit window to embed
        _, size_list = self.get_layers_low_n_bit_size(self.get_pth_keys_float32(), bit_n)
        for layer, size in zip(self.get_pth_keys_float32(), size_list):
            if paras[layer].data.dtype != torch.float32:
                continue  # only 32-bit floats carry payload
            print(layer, size)
            para_tensor_flat = paras[layer].flatten()
            # NOTE: the payload bit index and the tensor index are NOT the same!
            print(mal_index, mal_index + size)
            print(para_tensor_flat.size(), size // bit_n)
            para_index = 0  # index of the parameter slot being written
            for inject_pos in range(mal_index, min(mal_index + size, malware_len), bit_n):
                # Next bit_n payload bits (the final slice may be shorter).
                current_write_content = malware_str[inject_pos: inject_pos + bit_n]
                para_tensor_flat_str = BitArray(int=para_tensor_flat[para_index].view(torch.int32), length=32).bin
                new_para_tensor_flat_str = para_tensor_flat_str[:32 - bit_n] + current_write_content

                # Reinterpret the bit string as a signed int32 (two's complement).
                if int(new_para_tensor_flat_str, 2) >= 2 ** 31:
                    newParaInt = torch.tensor(int(new_para_tensor_flat_str, 2) - 2 ** 32, dtype=torch.int32)
                    para_tensor_flat[para_index] = newParaInt.view(torch.float32)
                else:
                    newParaInt = torch.tensor(int(new_para_tensor_flat_str, 2), dtype=torch.int32)
                    para_tensor_flat[para_index] = newParaInt.view(torch.float32)
                para_index += 1  # advance to the next parameter slot

            if mal_index + size >= malware_len:
                break
            else:
                mal_index = mal_index + size
            # NOTE(review): when the loop breaks on the final layer this reshape
            # is skipped; that only works if flatten() returned a view so the
            # in-place writes above already hit paras[layer] — verify.
            paras[layer] = para_tensor_flat.reshape(paras[layer].data.shape)
        torch.save(paras, inject_path)
        return


    def all_layers_low_n_bit_extract(self, inject_path, bit_n, extract_malware, malware_len):
        """
        Recover the embedded payload from an injected checkpoint.
        :param inject_path: path of the injected checkpoint
        :param bit_n: low bits used per value during injection
        :param extract_malware: output path for the recovered payload
        :param malware_len: payload length, in bits
        :return: None
        """
        paras = torch.load(inject_path, map_location="cpu")
        bits, idx = BitArray(), 0
        _, size_list = self.get_layers_low_n_bit_size(self.get_pth_keys_float32(), bit_n)
        for layer, _ in zip(self.get_pth_keys_float32(), size_list):
            p = paras[layer].data
            if p.dtype != torch.float32:
                continue
            # Read only as many values as are needed to cover the remaining bits.
            for x in p.flatten()[:min(len(p.flatten()), (malware_len - idx + bit_n - 1) // bit_n)]:
                bits.append(f'0b{BitArray(int=int(x.view(torch.int32)), length=32).bin[-bit_n:]}')
                idx += bit_n
                if idx >= malware_len:
                    break
            if idx >= malware_len:
                break
        with open(extract_malware, 'wb') as f:
            # The final value may carry fewer than bit_n payload bits: stitch the
            # bit_n-aligned prefix to the trailing remainder, then clip to the
            # exact payload length. (When malware_len % bit_n == 0 the second
            # slice is the whole array, but the final [:malware_len] clip still
            # yields the correct prefix.)
            bits_clip = bits[:(malware_len - (malware_len % bit_n))] + bits[-(malware_len % bit_n):]
            bits_clip[:malware_len].tofile(f)
        return


    # def all_layers_low_n_bit_extract(self, inject_path, bit_n, extract_malware, malware_len):
    #     paras = torch.load(inject_path, map_location="cpu");
    #     pl = torch.load(self.path, map_location="cpu");
    #     layers = [k for k, v in pl.items() if v.dtype == torch.float32];
    #     bits, idx = BitArray(), 0
    #     for l in layers:
    #         f = paras[l].data.flatten();
    #         r = min(len(f), (malware_len - idx + bit_n - 1) // bit_n)
    #         for x in f[:r]:
    #             bits.append(f'0b{BitArray(int=int(x.view(torch.int32)), length=32).bin[-bit_n:]}');
    #             idx += bit_n
    #             if idx >= malware_len: break
    #         if idx >= malware_len: break
    #     with open(extract_malware, 'wb') as f:
    #         bits[:malware_len].tofile(f); return
224
+
225
+
226
if __name__ == "__main__":
    # Demo pipeline: size the capacity, generate a random payload, inject it
    # into the low bits of every float32 layer, then extract it back out.


    path = "../parameters/classification/swin_face/swin_face.pth"
    # flip_path = "../parametersProcess/swin/swin_flip_16.pth"
    inject_path = "../parametersProcess/swin_face/swin_evilfiles_16.pth"
    malware = "../malwares/generated_malware"
    extract_malware = "../malwares/generated_malware_extracted"
    agent = swin(path)


    # print("layers name: ", agent.get_pth_keys_float32())
    # print("type: ", type(agent.get_pth_keys_float32()))
    # print("layers num: ", len(agent.get_pth_keys_float32()))

    # NOTE(review): capacity is computed with bit_n=16 here, but the inject /
    # extract calls below use bit_n=20 — confirm the mismatch is intentional.
    size, size_list = agent.get_layers_low_n_bit_size(agent.get_pth_keys_float32(), 16)
    # print("all layers injection size with low-16 bits: ", size / 8000000, " MB")
    # print(size_list)
    # print(len(size_list))
    '''随机生成一个恶意软件,全部嵌入模型的层(简化流程)'''
    # Generate a random payload that exactly fills the measured capacity.
    generate_file_with_bits(malware, size)


    print("malware bit size: ",os.path.getsize(malware) * 8)


    '''嵌入'''
    # Inject the payload into the low 20 bits of every float32 layer.
    agent.all_layers_low_n_bit_inject(inject_path, 20, malware, os.path.getsize(malware) * 8)

    '''提取'''
    # Extract the payload back out of the injected checkpoint.

    # def all_layers_low_n_bit_extract(ip, bn, em, ml):
    #     p = torch.load(ip, map_location="cpu");
    #     b, i = BitArray(), 0;
    #     lrs = [k for k, v in p.items() if v.dtype == torch.float32]
    #     for l in lrs:
    #         for x in p[l].data.flatten()[:min(len(p[l].data.flatten()), (ml - i + bn - 1) // bn)]:
    #             b.append(f'0b{BitArray(int=int(x.view(torch.int32)), length=32).bin[-bn:]}');
    #             i += bn
    #             if i >= ml: break
    #         if i >= ml: break
    #     with open(em, 'wb') as f:
    #         b[:ml].tofile(f);return


    agent.all_layers_low_n_bit_extract(inject_path, 20, extract_malware, os.path.getsize(malware) * 8)
    # agent.all_layers_low_n_bit_extract(inject_path, 16, extract_malware, 36272)


    # (duplicate commented copy of the ad-hoc extract helper, kept as-is)
    # def all_layers_low_n_bit_extract(ip, bn, em, ml):
    #     p = torch.load(ip, map_location="cpu");
    #     b, i = BitArray(), 0;
    #     lrs = [k for k, v in p.items() if v.dtype == torch.float32]
    #     for l in lrs:
    #         for x in p[l].data.flatten()[:min(len(p[l].data.flatten()), (ml - i + bn - 1) // bn)]:
    #             b.append(f'0b{BitArray(int=int(x.view(torch.int32)), length=32).bin[-bn:]}');
    #             i += bn
    #             if i >= ml: break
    #         if i >= ml: break
    #     with open(em, 'wb') as f:
    #         b[:ml].tofile(f);return
    # all_layers_low_n_bit_extract("~/data/ATATK/parametersProcess/swin/swin_inject_16.pth", 16, "~/data/ATATK/malwares/Zherkov_extract.EXE", 36272)








    # agent.all_layers_low_n_bit_fLip(flip_path, 20)