From 7b405b3e04779fc0a026c9c6ac3e06194948b253 Mon Sep 17 00:00:00 2001
From: Tom Lane <tgl@sss.pgh.pa.us>
Date: Tue, 23 Aug 2016 09:39:54 -0400
Subject: [PATCH] Refactor some network.c code to create
 cidr_set_masklen_internal().

Merge several copies of "copy an inet value and adjust the mask length"
code to create a single, conveniently C-callable function.  This function
is exported for future use by inet SPGiST support, but it's good cleanup
anyway since we had three slightly-different-for-no-good-reason copies.

(Extracted from a larger patch, to separate new code from refactoring
of old code)

Emre Hasegeli
---
 src/backend/utils/adt/network.c | 109 +++++++++-----------------------
 src/include/utils/inet.h        |   5 +-
 2 files changed, 33 insertions(+), 81 deletions(-)

diff --git a/src/backend/utils/adt/network.c b/src/backend/utils/adt/network.c
index 1f8469a2cbc..3f6987af048 100644
--- a/src/backend/utils/adt/network.c
+++ b/src/backend/utils/adt/network.c
@@ -268,11 +268,7 @@ Datum
 inet_to_cidr(PG_FUNCTION_ARGS)
 {
 	inet	   *src = PG_GETARG_INET_PP(0);
-	inet	   *dst;
 	int			bits;
-	int			byte;
-	int			nbits;
-	int			maxbytes;
 
 	bits = ip_bits(src);
 
@@ -280,29 +276,7 @@ inet_to_cidr(PG_FUNCTION_ARGS)
 	if ((bits < 0) || (bits > ip_maxbits(src)))
 		elog(ERROR, "invalid inet bit length: %d", bits);
 
-	/* clone the original data */
-	dst = (inet *) palloc(VARSIZE_ANY(src));
-	memcpy(dst, src, VARSIZE_ANY(src));
-
-	/* zero out any bits to the right of the netmask */
-	byte = bits / 8;
-
-	nbits = bits % 8;
-	/* clear the first byte, this might be a partial byte */
-	if (nbits != 0)
-	{
-		ip_addr(dst)[byte] &= ~(0xFF >> nbits);
-		byte++;
-	}
-	/* clear remaining bytes */
-	maxbytes = ip_addrsize(dst);
-	while (byte < maxbytes)
-	{
-		ip_addr(dst)[byte] = 0;
-		byte++;
-	}
-
-	PG_RETURN_INET_P(dst);
+	PG_RETURN_INET_P(cidr_set_masklen_internal(src, bits));
 }
 
 Datum
@@ -334,10 +308,6 @@ cidr_set_masklen(PG_FUNCTION_ARGS)
 {
 	inet	   *src = PG_GETARG_INET_PP(0);
 	int			bits = PG_GETARG_INT32(1);
-	inet	   *dst;
-	int			byte;
-	int			nbits;
-	int			maxbytes;
 
 	if (bits == -1)
 		bits = ip_maxbits(src);
@@ -347,31 +317,36 @@ cidr_set_masklen(PG_FUNCTION_ARGS)
 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
 				 errmsg("invalid mask length: %d", bits)));
 
-	/* clone the original data */
-	dst = (inet *) palloc(VARSIZE_ANY(src));
-	memcpy(dst, src, VARSIZE_ANY(src));
+	PG_RETURN_INET_P(cidr_set_masklen_internal(src, bits));
+}
 
-	ip_bits(dst) = bits;
+/*
+ * Copy src and set mask length to 'bits' (which must be valid for the family)
+ */
+inet *
+cidr_set_masklen_internal(const inet *src, int bits)
+{
+	inet	   *dst = (inet *) palloc0(sizeof(inet));
 
-	/* zero out any bits to the right of the new netmask */
-	byte = bits / 8;
+	ip_family(dst) = ip_family(src);
+	ip_bits(dst) = bits;
 
-	nbits = bits % 8;
-	/* clear the first byte, this might be a partial byte */
-	if (nbits != 0)
+	if (bits > 0)
 	{
-		ip_addr(dst)[byte] &= ~(0xFF >> nbits);
-		byte++;
-	}
-	/* clear remaining bytes */
-	maxbytes = ip_addrsize(dst);
-	while (byte < maxbytes)
-	{
-		ip_addr(dst)[byte] = 0;
-		byte++;
+		Assert(bits <= ip_maxbits(dst));
+
+		/* Clone appropriate bytes of the address, leaving the rest 0 */
+		memcpy(ip_addr(dst), ip_addr(src), (bits + 7) / 8);
+
+		/* Clear any unwanted bits in the last partial byte */
+		if (bits % 8)
+			ip_addr(dst)[bits / 8] &= ~(0xFF >> (bits % 8));
 	}
 
-	PG_RETURN_INET_P(dst);
+	/* Set varlena header correctly */
+	SET_INET_VARSIZE(dst);
+
+	return dst;
 }
 
 /*
@@ -719,11 +694,7 @@ network_broadcast(PG_FUNCTION_ARGS)
 	/* make sure any unused bits are zeroed */
 	dst = (inet *) palloc0(sizeof(inet));
 
-	if (ip_family(ip) == PGSQL_AF_INET)
-		maxbytes = 4;
-	else
-		maxbytes = 16;
-
+	maxbytes = ip_addrsize(ip);
 	bits = ip_bits(ip);
 	a = ip_addr(ip);
 	b = ip_addr(dst);
@@ -853,11 +824,7 @@ network_hostmask(PG_FUNCTION_ARGS)
 	/* make sure any unused bits are zeroed */
 	dst = (inet *) palloc0(sizeof(inet));
 
-	if (ip_family(ip) == PGSQL_AF_INET)
-		maxbytes = 4;
-	else
-		maxbytes = 16;
-
+	maxbytes = ip_addrsize(ip);
 	bits = ip_maxbits(ip) - ip_bits(ip);
 	b = ip_addr(dst);
 
@@ -907,8 +874,7 @@ Datum
 inet_merge(PG_FUNCTION_ARGS)
 {
 	inet	   *a1 = PG_GETARG_INET_PP(0),
-			   *a2 = PG_GETARG_INET_PP(1),
-			   *result;
+			   *a2 = PG_GETARG_INET_PP(1);
 	int			commonbits;
 
 	if (ip_family(a1) != ip_family(a2))
@@ -919,24 +885,7 @@ inet_merge(PG_FUNCTION_ARGS)
 	commonbits = bitncommon(ip_addr(a1), ip_addr(a2),
 							Min(ip_bits(a1), ip_bits(a2)));
 
-	/* Make sure any unused bits are zeroed. */
-	result = (inet *) palloc0(sizeof(inet));
-
-	ip_family(result) = ip_family(a1);
-	ip_bits(result) = commonbits;
-
-	/* Clone appropriate bytes of the address. */
-	if (commonbits > 0)
-		memcpy(ip_addr(result), ip_addr(a1), (commonbits + 7) / 8);
-
-	/* Clean any unwanted bits in the last partial byte. */
-	if (commonbits % 8 != 0)
-		ip_addr(result)[commonbits / 8] &= ~(0xFF >> (commonbits % 8));
-
-	/* Set varlena header correctly. */
-	SET_INET_VARSIZE(result);
-
-	PG_RETURN_INET_P(result);
+	PG_RETURN_INET_P(cidr_set_masklen_internal(a1, commonbits));
 }
 
 /*
diff --git a/src/include/utils/inet.h b/src/include/utils/inet.h
index 2fe3ca8c3c8..dfa0b9f7113 100644
--- a/src/include/utils/inet.h
+++ b/src/include/utils/inet.h
@@ -28,10 +28,12 @@ typedef struct
 } inet_struct;
 
 /*
+ * We use these values for the "family" field.
+ *
  * Referencing all of the non-AF_INET types to AF_INET lets us work on
  * machines which may not have the appropriate address family (like
  * inet6 addresses when AF_INET6 isn't present) but doesn't cause a
- * dump/reload requirement.  Existing databases used AF_INET for the family
+ * dump/reload requirement.  Pre-7.4 databases used AF_INET for the family
  * type on disk.
  */
 #define PGSQL_AF_INET	(AF_INET + 0)
@@ -117,6 +119,7 @@ typedef struct macaddr
 /*
  * Support functions in network.c
  */
+extern inet *cidr_set_masklen_internal(const inet *src, int bits);
 extern int	bitncmp(const unsigned char *l, const unsigned char *r, int n);
 extern int	bitncommon(const unsigned char *l, const unsigned char *r, int n);
 
-- 
GitLab