diff --git a/src/backend/executor/nodeHash.c b/src/backend/executor/nodeHash.c index c90fe40b3c9b21073943b2b32117069052917e87..5d0fc77c3015a739e2270d16a704818aa3bbedc9 100644 --- a/src/backend/executor/nodeHash.c +++ b/src/backend/executor/nodeHash.c @@ -500,7 +500,9 @@ ExecChooseHashTableSize(double ntuples, int tupwidth, bool useskew, * Both nbuckets and nbatch must be powers of 2 to make * ExecHashGetBucketAndBatch fast. We already fixed nbatch; now inflate * nbuckets to the next larger power of 2. We also force nbuckets to not - * be real small, by starting the search at 2^10. + * be real small, by starting the search at 2^10. (Note: above we made + * sure that nbuckets is not more than INT_MAX / 2, so this loop cannot + * overflow, nor can the final shift to recalculate nbuckets.) */ i = 10; while ((1 << i) < nbuckets) diff --git a/src/backend/utils/hash/dynahash.c b/src/backend/utils/hash/dynahash.c index 31ac2b25e8ffa2f50dfa4c9bde0d248baf2c0979..07f7a84943029f876abcd853c7b2b85bdd0e048a 100644 --- a/src/backend/utils/hash/dynahash.c +++ b/src/backend/utils/hash/dynahash.c @@ -68,6 +68,8 @@ #include "postgres.h" +#include <limits.h> + #include "access/xact.h" #include "storage/shmem.h" #include "storage/spin.h" @@ -205,6 +207,8 @@ static void hdefault(HTAB *hashp); static int choose_nelem_alloc(Size entrysize); static bool init_htab(HTAB *hashp, long nelem); static void hash_corrupted(HTAB *hashp); +static long next_pow2_long(long num); +static int next_pow2_int(long num); static void register_seq_scan(HTAB *hashp); static void deregister_seq_scan(HTAB *hashp); static bool has_seq_scans(HTAB *hashp); @@ -379,8 +383,13 @@ hash_create(const char *tabname, long nelem, HASHCTL *info, int flags) { /* Doesn't make sense to partition a local hash table */ Assert(flags & HASH_SHARED_MEM); - /* # of partitions had better be a power of 2 */ - Assert(info->num_partitions == (1L << my_log2(info->num_partitions))); + + /* + * The number of partitions had better be a power of 2. Also, it must + * be less than INT_MAX (see init_htab()), so call the int version of + * next_pow2. + */ + Assert(info->num_partitions == next_pow2_int(info->num_partitions)); hctl->num_partitions = info->num_partitions; } @@ -523,7 +532,6 @@ init_htab(HTAB *hashp, long nelem) { HASHHDR *hctl = hashp->hctl; HASHSEGMENT *segp; - long lnbuckets; int nbuckets; int nsegs; @@ -538,9 +546,7 @@ init_htab(HTAB *hashp, long nelem) * number of buckets. Allocate space for the next greater power of two * number of buckets */ - lnbuckets = (nelem - 1) / hctl->ffactor + 1; - - nbuckets = 1 << my_log2(lnbuckets); + nbuckets = next_pow2_int((nelem - 1) / hctl->ffactor + 1); /* * In a partitioned table, nbuckets must be at least equal to @@ -558,7 +564,7 @@ init_htab(HTAB *hashp, long nelem) * Figure number of directory segments needed, round up to a power of 2 */ nsegs = (nbuckets - 1) / hctl->ssize + 1; - nsegs = 1 << my_log2(nsegs); + nsegs = next_pow2_int(nsegs); /* * Make sure directory is big enough. If pre-allocated directory is too @@ -628,9 +634,9 @@ hash_estimate_size(long num_entries, Size entrysize) elementAllocCnt; /* estimate number of buckets wanted */ - nBuckets = 1L << my_log2((num_entries - 1) / DEF_FFACTOR + 1); + nBuckets = next_pow2_long((num_entries - 1) / DEF_FFACTOR + 1); /* # of segments needed for nBuckets */ - nSegments = 1L << my_log2((nBuckets - 1) / DEF_SEGSIZE + 1); + nSegments = next_pow2_long((nBuckets - 1) / DEF_SEGSIZE + 1); /* directory entries */ nDirEntries = DEF_DIRSIZE; while (nDirEntries < nSegments) @@ -671,9 +677,9 @@ hash_select_dirsize(long num_entries) nDirEntries; /* estimate number of buckets wanted */ - nBuckets = 1L << my_log2((num_entries - 1) / DEF_FFACTOR + 1); + nBuckets = next_pow2_long((num_entries - 1) / DEF_FFACTOR + 1); /* # of segments needed for nBuckets */ - nSegments = 1L << my_log2((nBuckets - 1) / DEF_SEGSIZE + 1); + nSegments = next_pow2_long((nBuckets - 1) / DEF_SEGSIZE + 1); /* directory entries */ nDirEntries = DEF_DIRSIZE; while (nDirEntries < nSegments) @@ -1408,11 +1414,32 @@ my_log2(long num) int i; long limit; + /* guard against too-large input, which would put us into infinite loop */ + if (num > LONG_MAX / 2) + num = LONG_MAX / 2; + for (i = 0, limit = 1; limit < num; i++, limit <<= 1) ; return i; } +/* calculate first power of 2 >= num, bounded to what will fit in a long */ +static long +next_pow2_long(long num) +{ + /* my_log2's internal range check is sufficient */ + return 1L << my_log2(num); +} + +/* calculate first power of 2 >= num, bounded to what will fit in an int */ +static int +next_pow2_int(long num) +{ + if (num > INT_MAX / 2) + num = INT_MAX / 2; + return 1 << my_log2(num); +} + /************************* SEQ SCAN TRACKING ************************/