WebDataset
1. WebDataset
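WebDataset stores data in tar files (shards); each tar file contains a set of samples. A sample is a group of files that share the same key (the file name without its extension), and each file's extension tells the reader how that field is encoded. Some common extensions are: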
- `.cls`, `.cls2`, `.index`, `.inx`, and `.id` are commonly used for integers.
- `.json` and `.mp` (msgpack) are used for more complex annotations of images.
- `.jpg` is recommended for images.
- `.pickle` is recommended for general Python data structures (but is not portable).
- `.pth` and `.pyd` are PyTorch dumps, good for torch tensors.
- `.npy` and `.npz` are NumPy array dumps.
1.1. Prepare data
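To prepare data for WebDataset, the following script streams a Hugging Face dataset, writes its samples into WebDataset tar shards, and uploads each finished shard to a Hugging Face dataset repository: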
"""Convert a streaming Hugging Face dataset into WebDataset shards.

Each finished shard is uploaded to a Hugging Face dataset repo and then deleted
locally to save disk space.
"""

import os
import time

# huggingface_hub reads HF_HUB_DOWNLOAD_TIMEOUT once, when it is first imported
# (directly or through `datasets`), so it must be set before those imports.
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "300"

import PIL.PngImagePlugin  # noqa: E402
import webdataset as wds  # noqa: E402
from datasets import load_dataset  # noqa: E402
from huggingface_hub import HfApi  # noqa: E402
from tqdm import tqdm  # noqa: E402

# Accept PNG text chunks (metadata) up to 10MB instead of PIL's smaller default limit.
PIL.PngImagePlugin.MAX_TEXT_CHUNK = 10 * (1024**2)  # 10MB

# --- Configuration: fill in the empty placeholders before running. ---
source_dataset = ""  # Hugging Face dataset to convert
repo_id = ""  # Hugging Face dataset repo that receives the shards
output_dir = ""  # local directory holding shards until they are uploaded

start_shard = 0  # shard index to (re)start from
maxcount = 100000  # maximum samples per shard
# Samples already written by earlier runs. Assumes every earlier shard was
# closed by `maxcount` (not by the byte-size limit below).
skip_samples = start_shard * maxcount

for _setting, _value in (
    ("source_dataset", source_dataset),
    ("repo_id", repo_id),
    ("output_dir", output_dir),
):
    if not _value:
        raise ValueError(f"Please set `{_setting}` at the top of this script.")

hf_api = HfApi()
os.makedirs(output_dir, exist_ok=True)

max_retries = 3
for attempt in range(max_retries):
    try:
        train_dataset = load_dataset(
            source_dataset,
            split="train",
            streaming=True,
        )
        # Keep the raw encoded image bytes instead of decoding them to PIL images.
        train_dataset = train_dataset.decode(False)
        break
    except Exception as e:
        print(f"Load dataset failed (attempt {attempt + 1}/{max_retries}): {e}")
        if attempt < max_retries - 1:
            time.sleep(30 * 2**attempt)  # exponential backoff: 30s, 60s, 120s, ...
        else:
            raise


def upload_to_hf(fname):
    """Upload a completed shard to Hugging Face, then delete the local file.

    Registered as the ShardWriter ``post`` callback below. Failures are printed
    rather than raised, so one failed upload does not abort the conversion. When
    the upload fails, ``os.remove`` is never reached and the local shard is kept.
    """
    try:
        print(f"Uploading {fname} to Hugging Face...")
        hf_api.upload_file(
            path_or_fileobj=fname,
            path_in_repo=os.path.basename(fname),
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=f"Add shard {os.path.basename(fname)}",
        )
        print(f"✅ Successfully uploaded {fname}")

        # Optional: delete local file to save space
        os.remove(fname)

    except Exception as e:
        print(f"❌ Upload failed {fname}: {e}")


error_count = 0
with wds.writer.ShardWriter(
    os.path.join(output_dir, "%05d.tar"),
    maxcount=maxcount,
    maxsize=50_000_000_000,  # 50GB
    post=upload_to_hf,
    start_shard=start_shard,
) as shard_writer:
    for i, data in enumerate(tqdm(train_dataset, desc="store to webdataset format")):
        # Resuming: fast-forward the stream past samples already in earlier shards.
        if i < skip_samples:
            continue
        try:
            sample = {
                # Global sample index, so keys stay unique across resumed runs.
                "__key__": f"sample_{i:08d}",
                # Raw encoded image bytes, written verbatim. The ".jpg" suffix is only
                # a name. NOTE(review): non-JPEG sources (e.g. PNG) keep their own
                # encoding inside the file.
                "image.jpg": data["image"]["bytes"],
                "class.txt": data["class"],  # string
                "id.txt": data["id"],  # string
                "recaption.txt": data["recaption"],  # string
                "recaption_short.txt": data["recaption_short"],  # string
                "height.cls": data["height"],  # int
                "width.cls": data["width"],  # int
            }
            shard_writer.write(sample)
        except Exception as e:
            # Best effort: skip malformed samples but keep a running count.
            error_count += 1
            print(f"Error {error_count}: {e}")
print(f"Error count: {error_count}")
1.2. Use data
import ast
import json
import math
import os
import random
from collections.abc import Callable, Iterable
from multiprocessing import Value

import braceexpand
import webdataset as wds
from huggingface_hub import HfFileSystem, get_token, hf_hub_url, repo_exists
from PIL import Image
from torch.utils.data import default_collate
from torchvision import transforms
from webdataset.filters import _shuffle  # NOTE: private webdataset API; may change.

# Disables PIL's decompression-bomb guard so very large images can be opened.
# NOTE(review): only appropriate for trusted data sources.
Image.MAX_IMAGE_PIXELS = None
_SHARD_SHUFFLE_SIZE = 1000
_SHARD_SHUFFLE_INITIAL = 100
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000


def _read_int_field(sample: dict, name: str):
    """Return the numeric field ``name`` of a WebDataset sample.

    Looks up the plain key (``"height"``) first, then the raw tar-member key
    (``"height.cls"``). Raw values are bytes (ASCII digits) because the
    resolution filter runs before decoding in the training pipeline.

    Raises:
        KeyError: if neither key is present.
    """
    for key in (name, f"{name}.cls"):
        if key in sample:
            value = sample[key]
            if isinstance(value, (bytes, bytearray)):
                value = value.decode("ascii")
            if isinstance(value, str):
                value = int(value.strip())
            return value
    raise KeyError(f"sample has neither '{name}' nor '{name}.cls'")


def filter_by_res_ratio(
    min_res: int = 256, min_ratio: float = 0.5, max_ratio: float = 2.0
):
    """Build a ``wds.select`` predicate on image size and aspect ratio.

    A sample is kept when ``min_ratio <= height / width <= max_ratio`` and its
    longer side is at least ``min_res``. Samples with a non-positive side are
    rejected instead of raising ZeroDivisionError.
    """

    def _f(sample: dict) -> bool:
        h = _read_int_field(sample, "height")
        w = _read_int_field(sample, "width")
        if h <= 0 or w <= 0:
            return False
        ratio = h / w
        return min_ratio <= ratio <= max_ratio and max(h, w) >= min_res

    return _f


class ImageTransform:
    """Torchvision train/eval transforms producing normalized square crops.

    Args:
        crop_size: Side length of the square output crop.
        random_crop: Train with ``RandomResizedCrop``; otherwise resize + center crop.
        random_flip: Add a random horizontal flip to the train transform.
        normalize_mean: Per-channel mean passed to ``transforms.Normalize``.
        normalize_std: Per-channel std passed to ``transforms.Normalize``.
    """

    def __init__(
        self,
        crop_size: int = 256,
        random_crop: bool = True,
        random_flip: bool = True,
        normalize_mean: tuple[float, float, float] = (0.0, 0.0, 0.0),
        normalize_std: tuple[float, float, float] = (1.0, 1.0, 1.0),
    ):
        train_transform = []
        if random_crop:
            train_transform.append(transforms.RandomResizedCrop(crop_size))
        else:
            train_transform.extend(
                [
                    transforms.Resize(crop_size),
                    transforms.CenterCrop(crop_size),
                ]
            )
        if random_flip:
            train_transform.append(transforms.RandomHorizontalFlip())
        # normalize_mean = [0, 0, 0] and normalize_std = [1, 1, 1] will normalize images into [0, 1],
        # normalize_mean = [0.5, 0.5, 0.5] and normalize_std = [0.5, 0.5, 0.5] will normalize images into [-1, 1].
        train_transform.extend(
            [
                transforms.ToTensor(),
                transforms.Normalize(normalize_mean, normalize_std),
            ]
        )

        self.train_transform = transforms.Compose(train_transform)
        self.eval_transform = transforms.Compose(
            [
                transforms.Resize(crop_size),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor(),
                transforms.Normalize(normalize_mean, normalize_std),
            ]
        )
        print(f"self.train_transform: {self.train_transform}")
        print(f"self.eval_transform: {self.eval_transform}")


class SharedEpoch:
    """Epoch counter stored in a ``multiprocessing.Value`` so that dataloader
    worker processes can read the value set by the main process."""

    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value("i", epoch)

    def set_value(self, epoch: int):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value


def _curl_pipe(url: str) -> str:
    """Wrap ``url`` in a webdataset ``pipe:`` command that downloads it with curl.

    The Hugging Face token is sent as a bearer header only when one is
    configured; an ``Authorization: Bearer None`` header would be invalid.
    """
    token = get_token()
    auth = f"-H 'Authorization:Bearer {token}' " if token else ""
    return f"pipe: curl -s -L {auth}{url}"


def _is_hf_dataset_repo_id(shards_path: str) -> bool:
    """Return True if ``shards_path`` names an existing Hugging Face dataset repo.

    Strings that cannot be a Hub repo id are rejected without a network call:
    paths that exist locally, anything containing braces, ':' or '~', absolute
    or relative filesystem paths, '.tar' files, and multi-level paths.
    ``repo_exists`` can raise a validation error on such strings.
    """
    if os.path.exists(shards_path):
        return False
    if any(ch in shards_path for ch in "{}:~ "):
        return False
    if shards_path.startswith(("/", ".")) or shards_path.endswith(".tar"):
        return False
    if shards_path.count("/") > 1:
        return False
    return repo_exists(shards_path, repo_type="dataset")


def normalize_shards_path(shards_path: str) -> str:
    """Turn a shard spec into a string that ``wds.SimpleShardList`` can open.

    - URLs starting with ``http`` (optionally ``::``-joined) are each wrapped in
      an authenticated ``curl`` pipe.
    - A Hugging Face dataset repo id is expanded to all ``*.tar`` files in the
      repo, each wrapped in a ``curl`` pipe and joined with ``::``.
    - Anything else (local paths, brace patterns, existing ``pipe:`` specs) is
      returned unchanged.

    Raises:
        ValueError: if a Hub dataset repo contains no tar files.
    """
    parts = shards_path.split("::")
    if any(part.startswith("http") for part in parts):
        # Every URL needs its own pipe prefix; otherwise only the first shard of a
        # "::"-joined list would carry the auth header.
        return "::".join(
            _curl_pipe(part) if part.startswith("http") else part for part in parts
        )

    if _is_hf_dataset_repo_id(shards_path):
        fs = HfFileSystem()
        glob_pattern = f"hf://datasets/{shards_path}/**/*.tar"
        files = [fs.resolve_path(path) for path in sorted(fs.glob(glob_pattern))]
        if not files:
            raise ValueError(
                f"No tar files found in HuggingFace dataset: {shards_path}"
            )
        return "::".join(
            _curl_pipe(hf_hub_url(file.repo_id, file.path_in_repo, repo_type="dataset"))
            for file in files
        )

    return shards_path


def _normalize_shards(shards: str | list[str]) -> str:
    """Apply ``normalize_shards_path`` to a string or to each entry of a list,
    joining the results with ``::``."""
    if isinstance(shards, str):
        return normalize_shards_path(shards)
    return "::".join(normalize_shards_path(part) for part in shards)


def expand_urls(urls: str, weights: str | list[float] | None = None):
    """Expand a ``::``-joined, brace-patterned shard spec into a list of URLs.

    Args:
        urls: Shard spec, e.g. ``"a-{000..009}.tar::b-{00..04}.tar"``.
        weights: Optional per-component weights, either ``"1::0.5"`` or a list
            with one float per ``::`` component.

    Returns:
        ``(urls, None)`` without weights, otherwise ``(urls, weights)`` where
        every expanded URL inherits its component's weight.
    """
    if weights is None:
        expanded_urls = wds.shardlists.expand_urls(urls)
        return expanded_urls, None

    url_list = urls.split("::")
    weight_list = weights.split("::") if isinstance(weights, str) else list(weights)
    assert len(weight_list) == len(url_list), (
        f"Expected the number of data components ({len(url_list)}) and weights({len(weight_list)}) to match."
    )
    all_urls, all_weights = [], []
    for url, weight in zip(url_list, weight_list, strict=True):
        expanded_url = list(braceexpand.braceexpand(url))
        all_urls.extend(expanded_url)
        all_weights.extend([float(weight)] * len(expanded_url))
    return all_urls, all_weights


def get_dataset_size(shards: str):
    """Return ``(total_samples, num_shards)`` for a local shard spec.

    The sample count is read from ``sizes.json`` (shard basename -> count) or
    from ``__len__`` in the directory of the first shard.

    Raises:
        ValueError: if neither metadata file exists.
    """
    shards_list, _ = expand_urls(shards)
    dir_path = os.path.dirname(shards_list[0])
    sizes_filename = os.path.join(dir_path, "sizes.json")
    len_filename = os.path.join(dir_path, "__len__")
    if os.path.exists(sizes_filename):
        with open(sizes_filename, encoding="utf-8") as f:
            sizes = json.load(f)
        total_size = sum(int(sizes[os.path.basename(shard)]) for shard in shards_list)
    elif os.path.exists(len_filename):
        with open(len_filename, encoding="utf-8") as f:
            total_size = ast.literal_eval(f.read())
    else:
        raise ValueError("Please specify the number of samples in the dataset.")
    num_shards = len(shards_list)
    return total_size, num_shards


class det_shuffle(wds.PipelineStage):
    """Shuffle stage seeded by ``seed + epoch``.

    With a ``SharedEpoch``, every node/worker shuffles identically within an
    epoch, and the order changes whenever the epoch is advanced via
    ``set_value``. With a plain int epoch, the counter is incremented on every
    pass instead. Each dataloader worker holds its own copy of that int, so
    prefer ``SharedEpoch`` for multi-worker loading.
    """

    def __init__(
        self,
        bufsize: int = 1000,
        initial: int = 100,
        seed: int = 0,
        epoch: int | SharedEpoch = -1,
    ):
        super().__init__()
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src: Iterable[dict]):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            self.epoch += 1
            epoch = self.epoch
        rng = random.Random()
        rng.seed(self.seed + epoch)
        return _shuffle(src, self.bufsize, self.initial, rng)


class SimpleImageDataset:
    """Build WebDataset train/eval pipelines and ``wds.WebLoader`` loaders.

    Shard paths may be local paths or brace patterns, ``http`` URLs, Hugging Face
    dataset repo ids, ``::``-joined combinations, or lists of any of these.

    With ``apply_transform=True``, samples are transformed and batched, and
    ``train_dataloader``/``eval_dataloader`` are available. With
    ``apply_transform=False``, only the raw (decoded, renamed) sample pipelines
    are built and the dataloaders are ``None``.
    """

    def __init__(
        self,
        train_shards_path: str | list[str] | None = None,
        eval_shards_path: str | list[str] | None = None,
        per_gpu_batch_size: int = 32,
        world_size: int = 1,
        train_field_names: dict[str, str] | None = None,
        eval_field_names: dict[str, str] | None = None,
        num_train_examples: int | None = None,
        num_workers_per_gpu: int = 4,
        crop_size: int = 256,
        random_crop: bool = True,
        random_flip: bool = True,
        normalize_mean: tuple[float, float, float] = (0.0, 0.0, 0.0),
        normalize_std: tuple[float, float, float] = (1.0, 1.0, 1.0),
        res_ratio_filtering: bool = False,
        res_filter_params: tuple[int, float, float] = (256, 0.5, 2.0),
        seed: int = 42,
        tokenizer: Callable | None = None,
        apply_transform: bool = True,
    ):
        if train_shards_path is None and eval_shards_path is None:
            raise ValueError("At least one of train_shards_path or eval_shards_path must be provided")

        if train_shards_path is not None:
            train_shards_path = _normalize_shards(train_shards_path)
        if eval_shards_path is not None:
            eval_shards_path = _normalize_shards(eval_shards_path)

        # Map output field name -> ";"-separated tar extensions (see wds.rename).
        if train_field_names is None:
            train_field_names = {"image": "png;jpg;jpeg;webp", "text": "txt"}
        if eval_field_names is None:
            eval_field_names = {"image": "png;jpg;jpeg;webp"}

        transform = ImageTransform(
            crop_size=crop_size,
            random_crop=random_crop,
            random_flip=random_flip,
            normalize_mean=normalize_mean,
            normalize_std=normalize_std,
        )

        # DataLoader rejects persistent_workers=True when num_workers == 0.
        use_persistent_workers = num_workers_per_gpu > 0

        if train_shards_path is not None:
            self.num_train_examples = num_train_examples
            self.num_shards = None
            if self.num_train_examples is None:
                self.num_train_examples, self.num_shards = get_dataset_size(
                    train_shards_path
                )
                assert self.num_shards >= num_workers_per_gpu * world_size, (
                    "number of shards must be >= total workers"
                )
            self.shared_epoch = SharedEpoch(epoch=0)
            train_pipeline = [
                wds.SimpleShardList(train_shards_path),
                # Shard-level shuffle must be identical on every node/worker
                # before splitting, hence the shared deterministic seed.
                det_shuffle(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=seed,
                    epoch=self.shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
                wds.tarfile_to_samples(handler=wds.warn_and_continue),
                wds.shuffle(
                    bufsize=_SAMPLE_SHUFFLE_SIZE,
                    initial=_SAMPLE_SHUFFLE_INITIAL,
                ),
            ]
            if res_ratio_filtering:
                # Runs before decoding, on the raw "height.cls"/"width.cls" bytes.
                train_pipeline.append(wds.select(filter_by_res_ratio(*res_filter_params)))

            train_pipeline.extend(
                [
                    wds.decode(
                        wds.autodecode.ImageHandler("pil"),
                        handler=wds.warn_and_continue,
                    ),
                    wds.rename(
                        **train_field_names,
                        handler=wds.warn_and_continue,
                    ),
                    # Keep only the requested fields.
                    wds.map(
                        lambda sample: {
                            k: v for k, v in sample.items() if k in train_field_names
                        }
                    ),
                ]
            )
            if apply_transform:
                if tokenizer is not None:
                    train_pipeline.append(
                        wds.map_dict(
                            image=transform.train_transform,
                            text=lambda text: tokenizer(text)[0],
                            handler=wds.warn_and_continue,
                        )
                    )
                else:
                    train_pipeline.append(
                        wds.map_dict(
                            image=transform.train_transform,
                            handler=wds.warn_and_continue,
                        )
                    )
                train_pipeline.append(
                    wds.batched(
                        per_gpu_batch_size, partial=False, collation_fn=default_collate
                    )
                )

                # Round up so every worker yields the same number of batches per epoch.
                num_worker_batches = math.ceil(
                    self.num_train_examples
                    / (per_gpu_batch_size * world_size * num_workers_per_gpu)
                )
                num_batches = num_worker_batches * num_workers_per_gpu
                num_samples = num_batches * per_gpu_batch_size * world_size

                self._train_dataset = wds.DataPipeline(*train_pipeline).with_epoch(
                    num_worker_batches
                )
                self._train_dataloader = wds.WebLoader(
                    self._train_dataset,
                    batch_size=None,
                    shuffle=False,
                    num_workers=num_workers_per_gpu,
                    pin_memory=True,
                    persistent_workers=use_persistent_workers,
                )
                self._train_dataloader.num_batches = num_batches
                self._train_dataloader.num_samples = num_samples
            else:
                # No batching when apply_transform=False, just return raw samples
                self._train_dataset = wds.DataPipeline(*train_pipeline)
                self._train_dataloader = None
        else:
            self.shared_epoch = None
            self._train_dataset = None
            self._train_dataloader = None

        if eval_shards_path is not None:
            eval_pipeline = [
                wds.SimpleShardList(eval_shards_path),
                wds.split_by_node,
                wds.split_by_worker,
                wds.tarfile_to_samples(handler=wds.warn_and_continue),
                wds.decode(
                    wds.autodecode.ImageHandler("pil"),
                    handler=wds.warn_and_continue,
                ),
                wds.rename(
                    **eval_field_names,
                    handler=wds.warn_and_continue,
                ),
                wds.map(
                    lambda sample: {
                        k: v for k, v in sample.items() if k in eval_field_names
                    }
                ),
            ]
            if apply_transform:
                eval_pipeline.extend(
                    [
                        wds.map_dict(
                            image=transform.eval_transform,
                            handler=wds.warn_and_continue,
                        ),
                        # Keep the final short batch so no eval samples are dropped.
                        wds.batched(
                            per_gpu_batch_size,
                            partial=True,
                            collation_fn=default_collate,
                        ),
                    ]
                )

            self._eval_dataset = wds.DataPipeline(*eval_pipeline)

            if apply_transform:
                self._eval_dataloader = wds.WebLoader(
                    self._eval_dataset,
                    batch_size=None,
                    shuffle=False,
                    num_workers=num_workers_per_gpu,
                    pin_memory=True,
                    persistent_workers=False,
                )
            else:
                self._eval_dataloader = None
        else:
            self._eval_dataset = None
            self._eval_dataloader = None

    def set_epoch(self, epoch: int):
        """Set the epoch used to seed the deterministic shard shuffle."""
        if self.shared_epoch is not None:
            self.shared_epoch.set_value(epoch)

    @property
    def train_dataset(self):
        return self._train_dataset

    @property
    def train_dataloader(self):
        return self._train_dataloader

    @property
    def eval_dataset(self):
        return self._eval_dataset

    @property
    def eval_dataloader(self):
        return self._eval_dataloader


if __name__ == "__main__":
    # Plain (non f-) strings: single braces are brace-expansion ranges. Doubled
    # braces would leave literal "{...}" in the expanded URLs.
    train_url = "https://huggingface.co/datasets/timm/imagenet-12k-wds/resolve/main/imagenet12k-train-{0000..1023}.tar"
    eval_url = "https://huggingface.co/datasets/timm/imagenet-12k-wds/resolve/main/imagenet12k-validation-{0000..0511}.tar"

    data = SimpleImageDataset(
        train_url,
        eval_url,
        per_gpu_batch_size=4,
        world_size=4,
        num_train_examples=120000,
        train_field_names={
            "image": "jpg",
            "cls": "cls",
        },
    )

    train_dataloader, eval_dataloader = data.train_dataloader, data.eval_dataloader

    for batch in train_dataloader:
        print(batch)
        break

    for batch in eval_dataloader:
        print(batch)
        break