Sampler
BucketBatchSampler
基类:Sampler[List[int]]
一个批采样器,用于创建批次,其中批次中元素的大小来自每个预定义的桶范围。
数据集的元素首先根据桶范围和元素的大小分组到每个桶中。然后,为每个桶使用一个基本批采样器来创建小批量。
桶范围由 bucket_boundaries
指定,它将首先在内部排序,并用于创建 len(bucket_boundaries) - 1
个左闭右开区间。例如,如果 bucket_boundaries 张量为 [10, 5, 0, 16],它将被排序为 [0, 5, 10, 16],并将创建 3 个桶,范围为:[0, 5)、[5, 10)、[10, 16)。
基本批采样器将通过传递每个桶中的元素索引作为数据源,并将 base_batch_sampler_shared_kwargs
和 base_batch_sampler_individual_kwargs
传递给指定为 base_batch_sampler_class
的基本批采样器类的构造函数来创建。例如,base_batch_sampler_shared_kwargs = {'drop_last': True}
和 base_batch_sampler_individual_kwargs = {'batch_size': [8,10,12]}
将用于创建 3 个批采样器,其中 drop_last=True,batch_size 分别为 8、10 和 12,并像 base_batch_sampler_class(bucket_element_indices[0], batch_size=8, drop_last=True)
这样初始化。
在 __iter__
方法中,如果 shuffle
为 True
,则每个桶中的元素索引将被洗牌,并且每次随机选择一个桶来创建小批量。如果 shuffle
为 False
,则不会对元素索引进行洗牌,并且桶将按其区间边界的升序选择。
此类用于创建同质数据批次,用于训练或评估,并减少对齐元素形状所需的填充。
修改自 https://github.com/rssrwn/semla-flow/blob/main/semlaflow/data/util.py
示例
>>> import torch
>>> from bionemo.size_aware_batching.sampler import BucketBatchSampler
>>> # Define the sizes for a dataset
>>> sizes = torch.arange(25)
>>> # Define bucket ranges
>>> bucket_boundaries = torch.tensor([0, 6, 15, 25])
>>> # Create a bucket batch sampler with torch.utils.data.BatchSampler as base batch sampler
>>> # As there are 3 buckets, there will be 3 base batch samplers with batch sizes 2, 3, and 5.
>>> batch_sampler = BucketBatchSampler(
sizes=sizes,
bucket_boundaries=bucket_boundaries,
base_batch_sampler_class=torch.utils.data.BatchSampler,
base_batch_sampler_shared_kwargs={'drop_last': False},
base_batch_sampler_individual_kwargs={'batch_size': [2,3,5]},
shuffle=False,
)
>>> # Iterate over batches of indices that lies in the same bucket and with different batch sizes.
>>> print(list(batch_sampler))
[[0, 1], [2, 3], [4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]
>>> # randomize the dataset and buckets
>>> batch_sampler = BucketBatchSampler(
sizes=sizes,
bucket_boundaries=bucket_boundaries,
base_batch_sampler_class=torch.utils.data.BatchSampler,
base_batch_sampler_shared_kwargs={'drop_last': False},
base_batch_sampler_individual_kwargs={'batch_size': [2,3,5]},
shuffle=True,
generator=torch.Generator().manual_seed(0),
)
>>> print(list(batch_sampler))
[[24, 17, 16, 22, 19], [2, 5], [12, 10, 11], [3, 0], [15, 18, 20, 21, 23], [7, 13, 6], [14, 9, 8], [1, 4]]
>>> print(list(batch_sampler))
[[14, 9, 13], [23, 16, 20, 21, 15], [5, 0], [8, 10, 11], [17, 24, 22, 18, 19], [12, 6, 7], [4, 2], [3, 1]]
>>> # Combine with SizeAwareBatchSampler to control the cost of each batch
>>> from bionemo.size_aware_batching.sampler import SizeAwareBatchSampler
>>> item_costs = sizes.tolist()
>>> def cost_of_element(index):
return item_costs[index]
>>> batch_sampler = BucketBatchSampler(
sizes=sizes,
bucket_boundaries=bucket_boundaries,
base_batch_sampler_class=SizeAwareBatchSampler,
base_batch_sampler_shared_kwargs={"sizeof": cost_of_element, "max_total_size": 40},
base_batch_sampler_individual_kwargs={},
shuffle=True,
generator=torch.Generator().manual_seed(0),
)
>>> print(list(iter(batch_sampler)))
[[24], [2, 5, 3, 0, 1, 4], [12, 10, 11, 7], [13, 6, 14], [17, 16], [22], [19, 15], [9, 8], [18, 20], [21], [23]]
源代码位于 bionemo/size_aware_batching/sampler.py
中
278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 |
|
__init__(sizes, bucket_boundaries, base_batch_sampler_class, base_batch_sampler_shared_kwargs=None, base_batch_sampler_individual_kwargs=None, shuffle=True, generator=None)
初始化 BucketBatchSampler。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
sizes
|
Tensor
|
表示数据集中每个元素大小的实数一维张量。 |
必需 |
bucket_boundaries
|
Tensor
|
表示桶范围边界的实数一维张量。它将首先被排序,并用于创建 |
必需 |
base_batch_sampler_class
|
Type[S]
|
基本批采样器类类型,它将用于每个桶,并使用桶元素索引、 |
必需 |
base_batch_sampler_shared_kwargs
|
Optional[Dict[str, Any]]
|
用于初始化所有桶的所有基本批采样器的共享关键字参数字典。应为 |
None
|
base_batch_sampler_individual_kwargs
|
Optional[Dict[str, Iterable]]
|
用于使用相应的键值对初始化每个桶批采样器的关键字参数字典。此字典中每个值的长度必须等于 len(bucket_boundaries) - 1(桶的数量)。应为 |
None
|
shuffle
|
Optional[bool]
|
一个布尔值,指示是否洗牌数据集和桶。默认为 True。 |
True
|
generator
|
Optional[Generator]
|
采样中使用的生成器。默认为 None。 |
None
|
引发
类型 | 描述 |
---|---|
ValueError
|
如果 |
ValueError
|
如果 |
ValueError
|
如果 |
ValueError
|
如果 |
RuntimeError
|
如果桶范围 |
源代码位于 bionemo/size_aware_batching/sampler.py
中
365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 |
|
__iter__()
迭代索引批次。
此函数产生来自每个桶范围的大小元素的索引批次。
产生
类型 | 描述 |
---|---|
List[int]
|
List[int]:来自每个桶范围的大小元素的索引批次。 |
源代码位于 bionemo/size_aware_batching/sampler.py
中
561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 |
|
__len__()
获取批次数量。
只有在 base_batch_sampler_class
实现了 len() 时才能调用
返回
名称 | 类型 | 描述 |
---|---|---|
int |
int
|
批次数量 |
源代码位于 bionemo/size_aware_batching/sampler.py
中
550 551 552 553 554 555 556 557 558 559 |
|
SizeAwareBatchSampler
基类:Sampler[List[int]]
可变大小批处理数据采样器类,确保批次大小不超过最大值。
一种采样器,用于批处理大小不一的元素,同时确保每个批次的总大小不超过指定的最大值。
这在处理每个元素大小不同的数据集时非常有用,例如图形或长度不一的序列。采样器使用提供的 sizeof
函数来确定数据集中每个元素的大小,并确保每个批次的总大小不超过指定的 max_total_size
。
示例
>>> import torch
>>> from bionemo.size_aware_batching.sampler import SizeAwareBatchSampler
>>> # Define a sample dataset with torch.tensor
>>> dataset = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]),
... torch.tensor([7, 8]), torch.tensor([9, 10])]
>>> # Define a function that returns the size of each element in the dataset.
>>> def sizeof(index):
... return dataset[index].numel()
>>> # Create a SizeAwareBatchSampler with a maximum total batch size of 10.
>>> batch_sampler = SizeAwareBatchSampler(
... sampler=torch.utils.data.SequentialSampler(dataset),
... sizeof=sizeof,
... max_total_size=4
... )
>>> # Iterate over batches of indices that do not exceed the maximum total size.
>>> print(list(batch_sampler))
[[0, 1], [2, 3], [4]]
源代码位于 bionemo/size_aware_batching/sampler.py
中
172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 |
|
__init__(sampler, sizeof, max_total_size, info_logger=None, warn_logger=None)
初始化 SizeAwareBatchSampler。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
sampler
|
Union[Sampler[List[int]], Iterable[int]]
|
底层采样器。 |
必需 |
sizeof
|
Callable[[int], Real]
|
一个函数,它返回每个索引处的大小。例如,这可以用于确定一个元素消耗多少内存。其返回类型必须与 |
必需 |
max_total_size
|
Real
|
小批量的最大总大小。“大小”的语义由 |
必需 |
info_logger
|
Optional[Callable[[str], None]]
|
用于记录信息的函数。默认为 None。 |
None
|
warn_logger
|
Optional[Callable[[str], None]]
|
用于记录警告的函数。默认为 None。 |
None
|
引发
类型 | 描述 |
---|---|
TypeError
|
如果采样器不是 Sampler 或 Iterable 的实例,或者如果 sizeof 不是可调用对象、字典或序列容器。 |
ValueError
|
如果 max_total_size 不是正数。 |
源代码位于 bionemo/size_aware_batching/sampler.py
中
216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 |
|
__iter__()
迭代索引批次。
此函数产生不超过最大总大小的索引批次。
产生
类型 | 描述 |
---|---|
List[int]
|
不超过最大总大小的索引批次。 |
源代码位于 bionemo/size_aware_batching/sampler.py
中
260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 |
|
size_aware_batching(dataset, sizeof, max_total_size, collate_fn=None, info_logger=None, warn_logger=None)
创建一个批处理迭代器,其中每个批次大小根据内存消耗而变化(在最大限制内)。
一个生成器,用于从可迭代对象中批处理元素,同时确保每个批次的总大小不超过指定的最大值。这里的大小可以是批次中元素内存消耗的度量。这对于可索引数据或不可索引但可迭代的数据都很有用。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset
|
Iterable[Data]
|
输入可迭代对象。 |
必需 |
sizeof
|
Callable[[Data], Real]
|
一个函数或映射,它返回 |
必需 |
max_total_size
|
Real
|
每个批次的最大总“大小”。“大小”的语义由 |
必需 |
collate_fn
|
Optional[Callable[[Iterable[Data]], BatchCollated]]
|
一个可选的函数,用于整理批次。默认为 None,在这种情况下,每个批次都是来自输入数据集的元素列表 |
None
|
info_logger
|
Optional[Callable[[str], None]]
|
用于记录信息的函数。默认为 None。 |
None
|
warn_logger
|
Optional[Callable[[str], None]]
|
用于记录警告的函数。默认为 None。 |
None
|
产生
类型 | 描述 |
---|---|
Union[List[Data], BatchCollated]
|
一个从 |
假设 1. 线性复杂度。此函数消耗给定的数据 Iterable (dataset
) 一次,通过逐个遍历数据项来构建批次,并在将下一个数据项添加到批次中会超过 max_total_size
或如果批次是最后一个批次(迭代结束)时立即产生批次。2. 可加性大小测量。对于构建具有批次内存消耗阈值的小批量的通用用例,它假设批次的大小是批次中所有元素的总和(可加性)。3. max_total_size
和 sizeof
返回的可比较类型。 sizeof
的返回值必须与 max_total_size
进行比较,以阈值化批次的大小
注意事项 1:生成的批次大小可能具有很大的差异 - 如何解决:使用批次大小阈值过滤此生成器的输出 2:不同 epoch 之间的批次数量可能会有很大差异。 - 如何解决:增加构成一个 epoch 的步骤数,例如,在 Lightning 训练/验证循环中,这有效地增加了每个 epoch 的输入数据集大小
示例
>>> import torch
>>> from torch.utils.data import default_collate
>>> from bionemo.size_aware_batching.sampler import size_aware_batching
>>> # Define a sample dataset with torch.tensor
>>> dataset = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]),
... torch.tensor([7, 8]), torch.tensor([9, 10])]
>>> # Define a sizeof function that returns the size of each tensor
>>> def sizeof(x):
... return x.numel()
>>> # Create a generator with max_total_size=4 and default_collate_fn
>>> gen = size_aware_batching(dataset, sizeof, 4, collate_fn=default_collate)
>>> batches = list(gen)
>>> print(batches)
[tensor([[1, 2], [3, 4]]), tensor([[5, 6], [7, 8]]), tensor([[9, 10]])]
源代码位于 bionemo/size_aware_batching/sampler.py
中
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
|