alaaelnouby committed 89d2039 (verified) · 1 parent: af6cc8c

Update README.md

Files changed (1): README.md (+201, -5)
README.md CHANGED
```diff
@@ -1,5 +1,201 @@
- ---
- license: other
- license_name: apple-sample-code-license
- license_link: LICENSE
- ---
```
# AIM: Autoregressive Image Models

*Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai, Miguel Angel Bautista, Alexander Toshev, Vaishaal Shankar,
Joshua M Susskind, and Armand Joulin*


This software project accompanies the research paper, *Scalable Pre-training of Large Autoregressive Image Models*.

We introduce **AIM**, a collection of vision models pre-trained with an autoregressive generative objective.
We show that autoregressive pre-training of image features exhibits scaling properties similar to its
textual counterpart (i.e., Large Language Models). Specifically, we highlight two findings:
1. the model capacity can be trivially scaled to billions of parameters, and
2. AIM effectively leverages large collections of uncurated image data.

## Installation
Please install PyTorch using the official [installation instructions](https://pytorch.org/get-started/locally/).
Afterward, install the package as:
```commandline
pip install git+https://git@github.com/apple/ml-aim.git
```
We also offer [MLX](https://github.com/ml-explore/mlx) backend support for research and experimentation on Apple silicon.
To enable MLX support, simply run:
```commandline
pip install mlx
```

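To confirm the installation, and whether the optional MLX backend is available, the imports can be checked directly. This is a minimal sketch that only verifies the packages import:
```python
# minimal import check (sketch): verifies the aim package and, optionally, the MLX backend
from aim.utils import load_pretrained  # core package, PyTorch backend

try:
    import mlx.core as mx  # optional backend for Apple silicon
    print("MLX backend available")
except ImportError:
    print("MLX not installed; PyTorch backend only")
```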

## Usage
Below we provide an example of usage in [PyTorch](https://pytorch.org/):
```python
from PIL import Image

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="torch")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
logits, _ = model(inp)
```

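The logits can then be mapped to class predictions. The snippet below is only a sketch; it assumes the loaded checkpoint bundles the ImageNet-1k attention probe head described under [Pre-trained attention heads](#pre-trained-attention-heads), i.e. 1,000 output classes:
```python
import torch

# assumes `logits` from the example above and an ImageNet-1k probe head (1,000 classes)
probs = torch.softmax(logits, dim=-1)
top5_prob, top5_idx = probs.topk(5, dim=-1)
print(top5_idx.tolist(), top5_prob.tolist())
```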

<details>
<summary>and in <a href="https://ml-explore.github.io/mlx/">MLX</a></summary>

```python
from PIL import Image
import mlx.core as mx

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="mlx")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
inp = mx.array(inp.numpy())
logits, _ = model(inp)
```
</details>

<details>
<summary>and in <a href="https://jax.readthedocs.io/">JAX</a></summary>

```python
from PIL import Image
import jax.numpy as jnp

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model, params = load_pretrained("aim-600M-2B-imgs", backend="jax")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
inp = jnp.array(inp)
(logits, _), _ = model.apply(params, inp, mutable=['batch_stats'])
```
</details>

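For repeated JAX inference, wrapping the forward pass in `jax.jit` is a natural next step. This is only a sketch built on the Flax-style `model.apply` call from the JAX example above, with `model` and `params` defined as shown there:
```python
import jax

@jax.jit
def forward(params, inp):
    # same call as in the JAX example; the mutated batch_stats are simply discarded here
    (logits, _), _ = model.apply(params, inp, mutable=["batch_stats"])
    return logits
```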


## Pre-trained checkpoints

The pre-trained models can be accessed via [PyTorch Hub](https://pytorch.org/hub/) as:
```python
import torch

aim_600m = torch.hub.load("apple/ml-aim", "aim-600M")
aim_1b = torch.hub.load("apple/ml-aim", "aim-1B")
aim_3b = torch.hub.load("apple/ml-aim", "aim-3B")
aim_7b = torch.hub.load("apple/ml-aim", "aim-7B")
```

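As a quick sanity check, the parameter counts reported in the tables below can be verified on a hub-loaded model. A small illustrative sketch using the 0.6B variant:
```python
import torch

aim_600m = torch.hub.load("apple/ml-aim", "aim-600M")
n_params = sum(p.numel() for p in aim_600m.parameters())
print(f"{n_params / 1e9:.1f}B parameters")  # roughly 0.6B, matching the AIM-0.6B row below
```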

### Pre-trained backbones

The following table contains the pre-trained backbones used in our paper.

<table style="margin: auto">
<thead>
<tr>
<th>model</th>
<th>#params</th>
<th>attn (best layer)</th>
<th>backbone, SHA256</th>
</tr>
</thead>
<tbody>
<tr>
<td>AIM-0.6B</td>
<td>0.6B</td>
<td>79.4%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_600m_2bimgs_attnprobe_backbone.pth">link</a>, 0d6f6b8f</td>
</tr>
<tr>
<td>AIM-1B</td>
<td>1B</td>
<td>82.3%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_1b_5bimgs_attnprobe_backbone.pth">link</a>, d254ecd3</td>
</tr>
<tr>
<td>AIM-3B</td>
<td>3B</td>
<td>83.3%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_3b_5bimgs_attnprobe_backbone.pth">link</a>, 8475ce4e</td>
</tr>
<tr>
<td>AIM-7B</td>
<td>7B</td>
<td>84.0%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_7b_5bimgs_attnprobe_backbone.pth">link</a>, 184ed94c</td>
</tr>
</tbody>
</table>

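The SHA256 column shows a short prefix of each checkpoint's digest (the attention-head table below uses the same convention). A downloaded file can be checked with Python's `hashlib`; the local filename here is just an example:
```python
import hashlib

def sha256sum(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so large checkpoints need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# compare the prefix with the table entry, e.g. 0d6f6b8f for the AIM-0.6B backbone
print(sha256sum("aim_600m_2bimgs_attnprobe_backbone.pth")[:8])
```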

### Pre-trained attention heads

The table below contains the classification results on the ImageNet-1k validation set.

<table style="margin: auto">
<thead>
<tr>
<th rowspan="2">model</th>
<th colspan="2">top-1 IN-1k</th>
<th colspan="2">attention head, SHA256</th>
</tr>
<tr>
<th>last layer</th>
<th>best layer</th>
<th>last layer</th>
<th>best layer</th>
</tr>
</thead>

<tbody>
<tr>
<td>AIM-0.6B</td>
<td>78.5%</td>
<td>79.4%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_600m_2bimgs_attnprobe_head_last_layers.pth">link</a>, 5ce5a341</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_600m_2bimgs_attnprobe_head_best_layers.pth">link</a>, ebd45c05</td>
</tr>
<tr>
<td>AIM-1B</td>
<td>80.6%</td>
<td>82.3%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_1b_5bimgs_attnprobe_head_last_layers.pth">link</a>, db3be2ad</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_1b_5bimgs_attnprobe_head_best_layers.pth">link</a>, f1ed7852</td>
</tr>
<tr>
<td>AIM-3B</td>
<td>82.2%</td>
<td>83.3%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_3b_5bimgs_attnprobe_head_last_layers.pth">link</a>, 5c057b30</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_3b_5bimgs_attnprobe_head_best_layers.pth">link</a>, ad380e16</td>
</tr>
<tr>
<td>AIM-7B</td>
<td>82.4%</td>
<td>84.0%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_7b_5bimgs_attnprobe_head_last_layers.pth">link</a>, 1e5c99ba</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_7b_5bimgs_attnprobe_head_best_layers.pth">link</a>, 73ecd732</td>
</tr>
</tbody>
</table>

## Reproducing the IN-1k classification results
The commands below reproduce the [attention probe results](#pre-trained-attention-heads) on the ImageNet-1k
validation set. We run the evaluation on a single node with 8 GPUs:
```commandline
torchrun --standalone --nnodes=1 --nproc-per-node=8 main_attnprobe.py \
  --model=aim-7B \
  --batch-size=64 \
  --data-path=/path/to/imagenet \
  --probe-layers=last \
  --backbone-ckpt-path=/path/to/backbone_ckpt.pth \
  --head-ckpt-path=/path/to/head_ckpt.pth
```
By default, we probe the last 6 layers. To probe the best layers instead, simply pass `--probe-layers=best`.

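The `--backbone-ckpt-path` and `--head-ckpt-path` arguments expect local checkpoint files. One way to fetch them from this repository is via `huggingface_hub`; this is only a sketch, assuming `huggingface_hub` is installed and using the AIM-7B filenames linked in the tables above:
```python
from huggingface_hub import hf_hub_download

# download the AIM-7B backbone and its last-layers probe head from this repository
backbone_ckpt = hf_hub_download("apple/AIM", "aim_7b_5bimgs_attnprobe_backbone.pth")
head_ckpt = hf_hub_download("apple/AIM", "aim_7b_5bimgs_attnprobe_head_last_layers.pth")
print(backbone_ckpt, head_ckpt)
```
The printed paths can then be passed to `--backbone-ckpt-path` and `--head-ckpt-path`.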