WebDataset

1. WebDataset

WebDataset stores data in tar files; each tar file contains a set of samples. Some common formats are:

  • .cls, .cls2, .index, .inx, and .id are commonly used for integers
  • .json and .mp (msgpack) are used for more complex annotations of images
  • .jpg and .png are recommended for images
  • .pickle is recommended for general Python data structures (but is not portable)
  • .pth and .pyd are PyTorch dumps, good for torch tensors
  • .npy and .npz are NumPy array dumps

1.1. Prepare data

To prepare data for WebDataset, we can use the following code:

1
import os
2
import time
3
4
import PIL.PngImagePlugin
5
import webdataset as wds
6
from datasets import load_dataset
7
from huggingface_hub import HfApi
8
from tqdm import tqdm
9
10
# Raise Pillow's limit on the size of PNG text chunks.
PIL.PngImagePlugin.MAX_TEXT_CHUNK = 10 * (1024**2)  # 10MB
# NOTE(review): huggingface_hub appears to read this variable when it is
# imported, so it may need to be set before the imports above — confirm.
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "300"

hf_api = HfApi()
output_dir = ""  # left empty here: fill in the directory that receives the shards
# os.makedirs("") raises FileNotFoundError; an empty output_dir simply means
# the shards are written relative to the current working directory.
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Opening the streaming dataset can fail; retry a few times before giving up.
max_retries = 3
for attempt in range(max_retries):
    try:
        train_dataset = load_dataset(
            "",  # left empty here: fill in the source dataset to convert
            split="train",
            streaming=True,
        )
        # The writer below reads data["image"]["bytes"], so image decoding is
        # switched off here.
        train_dataset = train_dataset.decode(False)
        break
    except Exception as e:
        print(f"Load dataset failed (attempt {attempt + 1}/{max_retries}): {e}")
        if attempt < max_retries - 1:
            time.sleep(30 * (attempt + 1))  # linear backoff: 30s, then 60s
        else:
            raise

repo_id = ""  # left empty here: fill in the dataset repo that receives the shards

# Resume support: set start_shard to the index of the first shard to write.
start_shard = 0
maxcount = 100000
# Number of leading samples to skip when resuming; this arithmetic assumes
# every earlier shard holds exactly `maxcount` samples.
skip_samples = start_shard * maxcount
39
40
41
def upload_to_hf(fname):
    """Upload a completed shard to Hugging Face, then delete the local file.

    The function is best-effort: failures are printed, never raised.

    Args:
        fname: Path of the shard file on local disk. Its base name is used as
            the path inside the repository.
    """
    try:
        print(f"Uploading {fname} to Hugging Face...")
        hf_api.upload_file(
            path_or_fileobj=fname,
            path_in_repo=os.path.basename(fname),
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=f"Add shard {os.path.basename(fname)}",
        )
    except Exception as e:
        # Keep the local file so the shard can be uploaded again later.
        print(f"❌ Upload failed {fname}: {e}")
        return

    print(f"✅ Successfully uploaded {fname}")

    # Optional: delete local file to save space. Handled separately so that a
    # failed delete is not reported as a failed upload.
    try:
        os.remove(fname)
    except OSError as e:
        print(f"Could not delete local file {fname}: {e}")
59
60
61
# Stream the dataset into numbered tar shards; `upload_to_hf` is passed as
# the writer's `post` hook.
with wds.writer.ShardWriter(
    os.path.join(output_dir, "%05d.tar"),
    maxcount=maxcount,
    maxsize=50000000000,  # 50GB
    post=upload_to_hf,
    start_shard=start_shard,
) as shard_writer:
    error_count = 0
    for i, data in enumerate(tqdm(train_dataset, desc="store to webdataset format")):
        # When resuming, skip the leading samples.
        if i < skip_samples:
            continue
        try:
            sample = {
                "__key__": f"sample_{i:08d}",
                # The "bytes" entry of the undecoded image feature.
                # NOTE(review): stored under a .jpg name regardless of the
                # actual encoding — confirm the source images are JPEG.
                "image.jpg": data["image"]["bytes"],
                "class.txt": data["class"],  # string
                "id.txt": data["id"],  # string
                "recaption.txt": data["recaption"],  # string
                "recaption_short.txt": data["recaption_short"],  # string
                "height.cls": data["height"],  # int
                "width.cls": data["width"],  # int
            }
            shard_writer.write(sample)
        except Exception as e:
            # Best-effort conversion: count and report the failure, then move on.
            error_count += 1
            print(f"Error {error_count}: {e}")
            continue

print(f"Error count: {error_count}")

1.2. Use data

1
import ast
2
import json
3
import math
4
import os
5
import random
6
from collections.abc import Callable, Iterable
7
from multiprocessing import Value
8
9
import braceexpand
10
import webdataset as wds
11
from huggingface_hub import HfFileSystem, get_token, hf_hub_url, repo_exists
12
from PIL import Image
13
from torch.utils.data import default_collate
14
from torchvision import transforms
15
from webdataset.filters import _shuffle
16
17
# Remove Pillow's limit on image size so very large images can be opened.
# NOTE(review): this also disables Pillow's decompression-bomb check; only
# appropriate when the shards come from a trusted source.
Image.MAX_IMAGE_PIXELS = None

# Buffer size and initial fill of the shard-level shuffle.
_SHARD_SHUFFLE_SIZE = 1000
_SHARD_SHUFFLE_INITIAL = 100
# Buffer size and initial fill of the sample-level shuffle.
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000
22
23
24
def filter_by_res_ratio(
    min_res: int = 256, min_ratio: float = 0.5, max_ratio: float = 2.0
):
    """Build a predicate that keeps samples with an acceptable size and shape.

    Args:
        min_res: Minimum length of the longer image side.
        min_ratio: Minimum allowed height / width ratio.
        max_ratio: Maximum allowed height / width ratio.

    Returns:
        A function taking a sample dict and returning True when the sample
        should be kept. The sample must provide numeric ``"height"`` and
        ``"width"`` entries.
        NOTE(review): confirm the samples carry these keys at the point in
        the pipeline where the filter is applied.
    """

    def _f(sample: dict) -> bool:
        h, w = sample["height"], sample["width"]
        # A zero width has no aspect ratio; drop the sample instead of
        # raising ZeroDivisionError in the middle of the pipeline.
        if w == 0:
            return False
        ratio = h / w
        longer_side = max(h, w)
        return min_ratio <= ratio <= max_ratio and longer_side >= min_res

    return _f
34
35
36
class ImageTransform:
    """Builds the torchvision transforms used for training and evaluation.

    Attributes:
        train_transform: Crop (random-resized, or resize + center crop), an
            optional random horizontal flip, then ToTensor and Normalize.
        eval_transform: Resize, center crop, ToTensor and Normalize.
    """

    def __init__(
        self,
        crop_size: int = 256,
        random_crop: bool = True,
        random_flip: bool = True,
        normalize_mean: tuple[float, float, float] = (0.0, 0.0, 0.0),
        normalize_std: tuple[float, float, float] = (1.0, 1.0, 1.0),
    ) -> None:
        """Create both transform pipelines and print them.

        Args:
            crop_size: Output size passed to the crop and resize transforms.
            random_crop: Use RandomResizedCrop for training when True,
                otherwise Resize followed by CenterCrop.
            random_flip: Add RandomHorizontalFlip to the training transform.
            normalize_mean: Mean passed to Normalize.
            normalize_std: Standard deviation passed to Normalize.
        """

        def _to_normalized_tensor() -> list:
            # normalize_mean = [0, 0, 0] and normalize_std = [1, 1, 1] will normalize images into [0, 1],
            # normalize_mean = [0.5, 0.5, 0.5] and normalize_std = [0.5, 0.5, 0.5] will normalize images into [-1, 1].
            return [
                transforms.ToTensor(),
                transforms.Normalize(normalize_mean, normalize_std),
            ]

        train_transform = []
        if random_crop:
            train_transform.append(transforms.RandomResizedCrop(crop_size))
        else:
            train_transform.extend(
                [
                    transforms.Resize(crop_size),
                    transforms.CenterCrop(crop_size),
                ]
            )
        if random_flip:
            train_transform.append(transforms.RandomHorizontalFlip())
        train_transform.extend(_to_normalized_tensor())

        self.train_transform = transforms.Compose(train_transform)
        self.eval_transform = transforms.Compose(
            [
                transforms.Resize(crop_size),
                transforms.CenterCrop(crop_size),
                *_to_normalized_tensor(),
            ]
        )
        print(f"self.train_transform: {self.train_transform}")
        print(f"self.eval_transform: {self.eval_transform}")
77
78
79
class SharedEpoch:
    """Epoch counter stored in a ``multiprocessing.Value``."""

    def __init__(self, epoch: int = 0) -> None:
        """Create the counter with the given starting epoch."""
        # "i" is the typecode for a signed C int.
        self.shared_epoch = Value("i", epoch)

    def set_value(self, epoch: int) -> None:
        """Store a new epoch number."""
        self.shared_epoch.value = epoch

    def get_value(self) -> int:
        """Return the currently stored epoch number."""
        return self.shared_epoch.value
88
89
90
def _curl_pipe(url: str, token: str | None) -> str:
    """Wrap one URL in a ``pipe:`` command that downloads it with curl."""
    if token:
        return f"pipe: curl -s -L -H 'Authorization:Bearer {token}' {url}"
    # Without a token, send no Authorization header at all rather than the
    # literal string "Bearer None".
    return f"pipe: curl -s -L {url}"


def normalize_shards_path(shards_path: str) -> str:
    """Turn a shard specification into something the shard list can open.

    * A spec starting with ``http`` is wrapped in a curl ``pipe:`` command.
    * The name of an existing Hugging Face dataset repo is expanded to the
      URLs of all ``.tar`` files in it, each wrapped the same way.
    * Anything else is returned unchanged.

    Several shard specs are joined with ``::``. The curl wrapper is applied to
    every component, so that each one keeps the download command and the
    authorization header after the list is split on ``::``.

    Args:
        shards_path: URL(s), a dataset repo id, or any other shard spec.

    Returns:
        The normalized shard specification.

    Raises:
        ValueError: If the dataset repo exists but contains no ``.tar`` file.
    """
    if shards_path.startswith("http"):
        token = get_token()
        return "::".join(_curl_pipe(url, token) for url in shards_path.split("::"))

    try:
        is_hub_dataset = repo_exists(shards_path, repo_type="dataset")
    except ValueError:
        # huggingface_hub rejects strings that are not valid repo ids (for
        # example local paths) with a ValueError subclass; such a spec is
        # simply not a Hub repo.
        is_hub_dataset = False

    if not is_hub_dataset:
        return shards_path

    fs = HfFileSystem()
    glob_pattern = f"hf://datasets/{shards_path}/**/*.tar"
    files = [fs.resolve_path(path) for path in fs.glob(glob_pattern)]
    if not files:
        raise ValueError(
            f"No tar files found in HuggingFace dataset: {shards_path}"
        )
    token = get_token()
    return "::".join(
        _curl_pipe(
            hf_hub_url(file.repo_id, file.path_in_repo, repo_type="dataset"),
            token,
        )
        for file in files
    )
110
111
112
def expand_urls(urls: str, weights: str | list[float] | None = None):
    """Expand a shard specification into a flat list of shard URLs.

    Args:
        urls: One or more shard specs joined with ``::``; each may contain
            brace patterns.
        weights: Optional weight per ``::`` component, given either as a
            ``::``-joined string or as a list of numbers.

    Returns:
        ``(all_urls, all_weights)``. Without weights, ``all_weights`` is None.
        With weights, every URL expanded from a component carries that
        component's weight.

    Raises:
        AssertionError: If the number of weights and components differ.
    """
    if weights is None:
        return wds.shardlists.expand_urls(urls), None

    url_list = urls.split("::")
    # The signature allows a list of weights; only a string needs splitting.
    if isinstance(weights, str):
        weights = weights.split("::")
    assert len(weights) == len(url_list), (
        f"Expected the number of data components ({len(url_list)}) and weights({len(weights)}) to match."
    )
    weights = [float(weight) for weight in weights]

    all_urls, all_weights = [], []
    for url, weight in zip(url_list, weights, strict=True):
        expanded_url = list(braceexpand.braceexpand(url))
        all_urls.extend(expanded_url)
        all_weights.extend([weight] * len(expanded_url))
    return all_urls, all_weights
130
131
132
def get_dataset_size(shards: str):
    """Look up the number of samples and shards of a dataset.

    The sample count is read from a file next to the first shard: either
    ``sizes.json`` (a mapping from shard file name to sample count) or
    ``__len__`` (a single Python literal).

    Args:
        shards: Shard specification accepted by ``expand_urls``.

    Returns:
        ``(total_size, num_shards)``.

    Raises:
        ValueError: If neither ``sizes.json`` nor ``__len__`` exists.
    """
    shards_list, _ = expand_urls(shards)
    dir_path = os.path.dirname(shards_list[0])
    sizes_filename = os.path.join(dir_path, "sizes.json")
    len_filename = os.path.join(dir_path, "__len__")

    if os.path.exists(sizes_filename):
        with open(sizes_filename) as f:
            sizes = json.load(f)
        total_size = sum(int(sizes[os.path.basename(shard)]) for shard in shards_list)
    elif os.path.exists(len_filename):
        with open(len_filename) as f:
            # literal_eval parses the literal without executing code.
            total_size = ast.literal_eval(f.read())
    else:
        raise ValueError("Please specify the number of samples in the dataset.")

    num_shards = len(shards_list)
    return total_size, num_shards
148
149
150
class det_shuffle(wds.PipelineStage):
    """Shuffle stage whose random generator is seeded with ``seed + epoch``."""

    def __init__(
        self,
        bufsize: int = 1000,
        initial: int = 100,
        seed: int = 0,
        epoch: "SharedEpoch | int" = -1,
    ):
        """Store the shuffle settings.

        Args:
            bufsize: Buffer size passed to the shuffle.
            initial: Initial buffer fill passed to the shuffle.
            seed: Base seed; the epoch number is added to it on every run.
            epoch: Either a ``SharedEpoch`` to read the epoch from, or a plain
                int that is incremented locally on every run.
        """
        super().__init__()
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src: Iterable[dict]):
        """Return ``src`` shuffled with a generator seeded for this epoch."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # A plain int (such as the default -1) has no get_value(); count
            # the runs locally instead. The counter lives in this object only.
            self.epoch += 1
            epoch = self.epoch
        rng = random.Random()
        rng.seed(self.seed + epoch)
        return _shuffle(src, self.bufsize, self.initial, rng)
169
170
171
class SimpleImageDataset:
    """Builds WebDataset pipelines and loaders for image training/evaluation.

    The results are exposed through the ``train_dataset``,
    ``train_dataloader``, ``eval_dataset`` and ``eval_dataloader``
    properties. A property is None when the matching shards path was not
    given; both loaders are also None when ``apply_transform`` is False.
    """

    def __init__(
        self,
        train_shards_path: str | list[str] | None = None,
        eval_shards_path: str | list[str] | None = None,
        per_gpu_batch_size: int = 32,
        world_size: int = 1,
        train_field_names: dict[str, str] | None = None,
        eval_field_names: dict[str, str] | None = None,
        num_train_examples: int | None = None,
        num_workers_per_gpu: int = 4,
        crop_size: int = 256,
        random_crop: bool = True,
        random_flip: bool = True,
        normalize_mean: tuple[float, float, float] = (0.0, 0.0, 0.0),
        normalize_std: tuple[float, float, float] = (1.0, 1.0, 1.0),
        res_ratio_filtering: bool = False,
        res_filter_params: tuple[int, float, float] = (256, 0.5, 2.0),
        seed: int = 42,
        tokenizer: Callable | None = None,
        apply_transform: bool = True,
    ):
        """Build the training and/or evaluation pipeline.

        Args:
            train_shards_path: Training shard spec; None disables training.
            eval_shards_path: Evaluation shard spec; None disables evaluation.
            per_gpu_batch_size: Batch size produced by each pipeline.
            world_size: Number of processes sharing the training data; used
                for the batch-count arithmetic.
            train_field_names: Output name -> ``;``-separated source
                extensions for training samples. Only these output names are
                kept. Defaults to image and text fields.
            eval_field_names: Same for evaluation; defaults to the image only.
            num_train_examples: Number of training samples. When None it is
                read with ``get_dataset_size``.
            num_workers_per_gpu: Number of loader worker processes.
            crop_size: See ``ImageTransform``.
            random_crop: See ``ImageTransform``.
            random_flip: See ``ImageTransform``.
            normalize_mean: See ``ImageTransform``.
            normalize_std: See ``ImageTransform``.
            res_ratio_filtering: Apply ``filter_by_res_ratio`` to training
                samples.
            res_filter_params: Arguments for ``filter_by_res_ratio``.
            seed: Base seed of the shard shuffle.
            tokenizer: Optional callable applied to the ``text`` field; the
                first element of its result is kept.
            apply_transform: When False, samples are neither transformed nor
                batched and no loader is created.

        Raises:
            ValueError: If neither shards path is given.
            AssertionError: If the sample count is looked up and there are
                fewer shards than total workers.
        """
        if train_shards_path is None and eval_shards_path is None:
            raise ValueError("At least one of train_shards_path or eval_shards_path must be provided")

        if train_shards_path is not None:
            train_shards_path = normalize_shards_path(train_shards_path)
        if eval_shards_path is not None:
            eval_shards_path = normalize_shards_path(eval_shards_path)

        if train_field_names is None:
            train_field_names = {"image": "png;jpg;jpeg;webp", "text": "txt"}
        if eval_field_names is None:
            eval_field_names = {"image": "png;jpg;jpeg;webp"}

        transform = ImageTransform(
            crop_size=crop_size,
            random_crop=random_crop,
            random_flip=random_flip,
            normalize_mean=normalize_mean,
            normalize_std=normalize_std,
        )

        if train_shards_path is not None:
            self.num_train_examples = num_train_examples
            if self.num_train_examples is None:
                # num_shards is only known (and only checked) on this path.
                self.num_train_examples, self.num_shards = get_dataset_size(
                    train_shards_path
                )
                assert self.num_shards >= num_workers_per_gpu * world_size, (
                    "number of shards must be >= total workers"
                )
            self.shared_epoch = SharedEpoch(epoch=0)
            train_pipeline = [
                wds.SimpleShardList(train_shards_path),
                # Shard order is seeded with seed + epoch.
                det_shuffle(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=seed,
                    epoch=self.shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
                wds.tarfile_to_samples(handler=wds.warn_and_continue),
                wds.shuffle(
                    bufsize=_SAMPLE_SHUFFLE_SIZE,
                    initial=_SAMPLE_SHUFFLE_INITIAL,
                ),
            ]
            if res_ratio_filtering:
                train_pipeline.append(wds.select(filter_by_res_ratio(*res_filter_params)))

            train_pipeline.extend(
                [
                    wds.decode(
                        wds.autodecode.ImageHandler("pil"),
                        handler=wds.warn_and_continue,
                    ),
                    wds.rename(
                        **train_field_names,
                        handler=wds.warn_and_continue,
                    ),
                    # Keep only the requested output fields.
                    wds.map(
                        lambda sample: {
                            k: v for k, v in sample.items() if k in train_field_names
                        }
                    ),
                ]
            )
            if apply_transform:
                if tokenizer is not None:
                    train_pipeline.append(
                        wds.map_dict(
                            image=transform.train_transform,
                            text=lambda text: tokenizer(text)[0],
                            handler=wds.warn_and_continue,
                        )
                    )
                else:
                    train_pipeline.append(wds.map_dict(image=transform.train_transform))
                train_pipeline.append(
                    wds.batched(
                        per_gpu_batch_size, partial=False, collation_fn=default_collate
                    )
                )

                # Round the batch count up to a multiple of the worker count.
                # With zero workers, count one "worker" so the division below
                # stays defined.
                num_workers = max(1, num_workers_per_gpu)
                num_worker_batches = math.ceil(
                    self.num_train_examples
                    / (per_gpu_batch_size * world_size * num_workers)
                )
                num_batches = num_worker_batches * num_workers
                num_samples = num_batches * per_gpu_batch_size * world_size

                self._train_dataset = wds.DataPipeline(*train_pipeline).with_epoch(
                    num_worker_batches
                )
                self._train_dataloader = wds.WebLoader(
                    self._train_dataset,
                    batch_size=None,  # batching already happens in the pipeline
                    shuffle=False,
                    num_workers=num_workers_per_gpu,
                    pin_memory=True,
                    # persistent_workers is only valid with at least one worker
                    persistent_workers=num_workers_per_gpu > 0,
                )
                self._train_dataloader.num_batches = num_batches
                self._train_dataloader.num_samples = num_samples
            else:
                # No batching when apply_transform=False, just return raw samples
                self._train_dataset = wds.DataPipeline(*train_pipeline)
                self._train_dataloader = None
        else:
            self.shared_epoch = None
            self._train_dataset = None
            self._train_dataloader = None

        if eval_shards_path is not None:
            eval_pipeline = [
                wds.SimpleShardList(eval_shards_path),
                wds.split_by_node,
                wds.split_by_worker,
                wds.tarfile_to_samples(handler=wds.ignore_and_continue),
                wds.decode(
                    wds.autodecode.ImageHandler("pil"),
                    handler=wds.warn_and_continue,
                ),
                wds.rename(
                    **eval_field_names,
                    handler=wds.warn_and_continue,
                ),
                # Keep only the requested output fields.
                wds.map(
                    lambda sample: {
                        k: v for k, v in sample.items() if k in eval_field_names
                    }
                ),
            ]
            if apply_transform:
                eval_pipeline.extend(
                    [
                        wds.map_dict(
                            image=transform.eval_transform,
                            handler=wds.warn_and_continue,
                        ),
                        wds.batched(
                            per_gpu_batch_size,
                            partial=False,
                            collation_fn=default_collate,
                        ),
                    ]
                )

            self._eval_dataset = wds.DataPipeline(*eval_pipeline)

            if apply_transform:
                self._eval_dataloader = wds.WebLoader(
                    self._eval_dataset,
                    batch_size=None,  # batching already happens in the pipeline
                    shuffle=False,
                    num_workers=num_workers_per_gpu,
                    pin_memory=True,
                    persistent_workers=False,
                )
            else:
                self._eval_dataloader = None
        else:
            self._eval_dataset = None
            self._eval_dataloader = None

    def set_epoch(self, epoch: int):
        """Publish the epoch number read by the training shard shuffle."""
        if self.shared_epoch is not None:
            self.shared_epoch.set_value(epoch)

    @property
    def train_dataset(self):
        """Training pipeline, or None without a training shards path."""
        return self._train_dataset

    @property
    def train_dataloader(self):
        """Training loader, or None (no training path or no transform)."""
        return self._train_dataloader

    @property
    def eval_dataset(self):
        """Evaluation pipeline, or None without an evaluation shards path."""
        return self._eval_dataset

    @property
    def eval_dataloader(self):
        """Evaluation loader, or None (no evaluation path or no transform)."""
        return self._eval_dataloader
381
382
383
if __name__ == "__main__":
    # Demo: build both loaders and print the first batch of each.
    train_url = "https://huggingface.co/datasets/timm/imagenet-12k-wds/resolve/main/imagenet12k-train-{{0000..1023}}.tar"
    eval_url = "https://huggingface.co/datasets/timm/imagenet-12k-wds/resolve/main/imagenet12k-validation-{{0000..0511}}.tar"

    data = SimpleImageDataset(
        train_url,
        eval_url,
        per_gpu_batch_size=4,
        world_size=4,
        # Given explicitly, so the sample count is not looked up from files.
        num_train_examples=120000,
        train_field_names={
            "image": "jpg",
            "cls": "cls",
        },
    )

    train_dataloader, eval_dataloader = data.train_dataloader, data.eval_dataloader

    for batch in train_dataloader:
        print(batch)
        break

    for batch in eval_dataloader:
        print(batch)
        break

References

  1. WebDataset
  2. https://github.com/bytedance/1d-tokenizer/blob/main/data/webdataset_reader.py
  3. https://github.com/mlfoundations/open_clip/blob/main/src/open_clip_train/data.py