Setting sort_within_batch of pool to True when called in the method create_batches of BucketIterator? #641

Open
LeoLaugier opened this issue Nov 14, 2019 · 1 comment


LeoLaugier commented Nov 14, 2019

Hi,

I am working with torchtext and I had a question about the pool function in the iterator module.

I have a train_dataset, valid_dataset and test_dataset. I want to create a train iterator with minibatches of similar lengths, with random internal order, and eventually shuffle the order of the minibatches. For the validation and test sets, I want to keep their initial order and create batches sequentially, based on that order.

I found the splits / __init__ methods of BucketIterator rather counterintuitive to use, and I agree with this post:

While Torchtext is brilliant, it’s sort_key based batching leaves a little to be desired. Often the sentences aren’t of the same length at all, and you end up feeding a lot of padding into your network

I'm a bit confused about why the argument sort_within_batch of pool is set to self.sort_within_batch when pool is called in the method create_batches of BucketIterator.
My issue is that if I want to effectively create minibatches of similar lengths, I have to set sort_within_batch to True when I call data.BucketIterator.splits.

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_dataset, valid_dataset, test_dataset),
    batch_sizes=(train_batch_size, valid_batch_size, test_batch_size),
    sort_key=lambda x: len(x.text),
    sort=False,
    sort_within_batch=True)

But then, in addition to sorting the samples in the chunks / buckets, it will also sort the samples in the created minibatches (in __iter__), which is not needed. As a side effect, it will also sort both the chunks / buckets and the minibatches in the validation and test iterators, which I don't want.
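To make the bucketing step concrete, here is a minimal pure-Python sketch of the chunk, sort, batch idea discussed above. It is an illustration under assumptions, not torchtext's code: the names chunked, pool_sketch, sort_chunks and chunk_factor are hypothetical, and the shuffling is simplified.

```python
import random


def chunked(items, size):
    """Yield consecutive slices of `items` with at most `size` elements."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


def pool_sketch(items, batch_size, key, sort_chunks,
                chunk_factor=100, shuffle=False, rng=None):
    """Hypothetical stand-in for the pooling step (not the library function).

    Splits `items` into chunks of batch_size * chunk_factor examples,
    optionally sorts each chunk by `key`, cuts it into minibatches, and
    optionally shuffles the minibatches of that chunk.
    """
    rng = rng or random.Random(0)
    for chunk in chunked(items, batch_size * chunk_factor):
        if sort_chunks:
            chunk = sorted(chunk, key=key)
        batches = list(chunked(chunk, batch_size))
        if shuffle:
            rng.shuffle(batches)
        yield from batches


# Toy "sentences" of different lengths.
sentences = [["w"] * n for n in (7, 2, 9, 1, 8, 3, 6, 4)]

unsorted_batches = list(pool_sketch(sentences, 2, len, sort_chunks=False))
bucketed_batches = list(pool_sketch(sentences, 2, len, sort_chunks=True))

unsorted_lengths = [[len(s) for s in b] for b in unsorted_batches]
bucketed_lengths = [[len(s) for s in b] for b in bucketed_batches]

print(unsorted_lengths)  # [[7, 2], [9, 1], [8, 3], [6, 4]]
print(bucketed_lengths)  # [[1, 2], [3, 4], [6, 7], [8, 9]]
```

Only the sorted variant groups examples of similar length, which is why the flag that controls sorting inside the chunks matters for padding, independently of whether each final minibatch is sorted again.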

Wouldn't it be more intuitive to set the argument sort_within_batch of pool to True when called in the method create_batches of BucketIterator?
In that case, if you don't want to reorder the samples in the minibatches of the train, validation, or test set, you would do

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_dataset, valid_dataset, test_dataset),
    batch_sizes=(train_batch_size, valid_batch_size, test_batch_size),
    sort_key=lambda x: len(x.text),
    sort=False,
    sort_within_batch=False)

If you want to sort the samples in the minibatches of the validation and test set, you would do

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_dataset, valid_dataset, test_dataset),
    batch_sizes=(train_batch_size, valid_batch_size, test_batch_size),
    sort_key=lambda x: len(x.text))

And if you want to sort the samples in the minibatches of the train, validation and test set, you would do

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_dataset, valid_dataset, test_dataset),
    batch_sizes=(train_batch_size, valid_batch_size, test_batch_size),
    sort_key=lambda x: len(x.text),
    sort_within_batch=True)

Besides, you may want to expose the factor 100 (for the chunk / bucket size) as a parameter, because it can be useful to tune it when working with toy datasets.
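As a rough illustration of why the factor matters on small data, the sketch below counts padding tokens for a toy list of sentence lengths under different chunk factors. The function padding_tokens and its chunk_factor parameter are hypothetical names for this example; the library itself hard-codes the factor, as noted above.

```python
def padding_tokens(lengths, batch_size, chunk_factor):
    """Count pad tokens when each chunk is sorted and then cut into batches.

    A chunk holds batch_size * chunk_factor examples; every example in a
    batch is padded up to the longest example of that batch.
    """
    chunk_size = batch_size * chunk_factor
    total = 0
    for c in range(0, len(lengths), chunk_size):
        chunk = sorted(lengths[c:c + chunk_size])
        for b in range(0, len(chunk), batch_size):
            batch = chunk[b:b + batch_size]
            total += sum(max(batch) - n for n in batch)
    return total


lengths = [7, 2, 9, 1, 8, 3, 6, 4]

pad_small = padding_tokens(lengths, batch_size=2, chunk_factor=1)
pad_medium = padding_tokens(lengths, batch_size=2, chunk_factor=2)
pad_large = padding_tokens(lengths, batch_size=2, chunk_factor=100)

print(pad_small, pad_medium, pad_large)  # 20 6 4
```

With a factor of 100 the whole toy dataset falls into a single chunk, so it is effectively sorted globally; a smaller factor trades some extra padding for more variety in which examples end up batched together.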

Please tell me if I am missing something.

Thanks.

@zhangguanheng66
Contributor

@mttk I think this is a good example to decouple those functionals.

@LeoLaugier If you don't want to sort the valid/test datasets, can you pass them separately to data.BucketIterator.splits()?
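One way to read this suggestion is to give each split its own keyword arguments instead of sharing one set across all three. The sketch below only builds the argument dictionaries so that it runs without torchtext installed. It assumes the legacy torchtext Iterator / BucketIterator constructors accept the train, shuffle, sort, sort_within_batch and sort_key keywords; the names train_kwargs and eval_kwargs are hypothetical.

```python
# Training split: bucket by length and shuffle the resulting minibatches.
train_kwargs = dict(
    sort_key=lambda x: len(x.text),
    train=True,
    shuffle=True,
    sort=False,
    sort_within_batch=True,
)

# Validation / test splits: keep the dataset order, batch sequentially.
eval_kwargs = dict(
    train=False,
    shuffle=False,
    sort=False,
    sort_within_batch=False,
)

# With legacy torchtext available, these would be used roughly as follows
# (assumption: `data` is the legacy torchtext data module from the issue):
#
#   train_iterator = data.BucketIterator(
#       train_dataset, batch_size=train_batch_size, **train_kwargs)
#   valid_iterator = data.Iterator(
#       valid_dataset, batch_size=valid_batch_size, **eval_kwargs)
#   test_iterator = data.Iterator(
#       test_dataset, batch_size=test_batch_size, **eval_kwargs)
```

Note that with sort_within_batch=True the training minibatches are still sorted internally, so this separates the splits from each other but does not by itself decouple chunk sorting from minibatch sorting.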
