Skip to content

Commit 9332649

Browse files
authored
picklable batch_fn (#7826)
1 parent d10e846 commit 9332649

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/datasets/iterable_dataset.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3581,15 +3581,12 @@ def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableData
3581 3581           ```
3582 3582           """
3583 3583
3584      -         def batch_fn(unbatched):
3585      -             return {k: [v] for k, v in unbatched.items()}
3586      -
3587 3584           if self.features:
3588 3585               features = Features({col: List(feature) for col, feature in self.features.items()})
3589 3586           else:
3590 3587               features = None
3591 3588           return self.map(
3592      -             batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch, features=features
     3589 +             _batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch, features=features
3593 3590           )
3594 3591
3595 3592       def to_dict(self, batch_size: Optional[int] = None, batched: bool = False) -> Union[dict, Iterator[dict]]:
def to_dict(self, batch_size: Optional[int] = None, batched: bool = False) -> Union[dict, Iterator[dict]]:
@@ -4659,3 +4656,7 @@ async def _apply_async(pool, func, x):
4659 4656               return future.get()
4660 4657           else:
4661 4658               await asyncio.sleep(0)
     4659 +
     4660 +
     4661 + def _batch_fn(unbatched):
     4662 +     return {k: [v] for k, v in unbatched.items()}

0 commit comments

Comments
 (0)