D12043: test-http-bad-server: factor code dealing with "write" into the new object

marmoute (Pierre-Yves David) phabricator at mercurial-scm.org
Mon Jan 24 14:46:19 UTC 2022


marmoute created this revision.
Herald added a reviewer: hg-reviewers.
Herald added a subscriber: mercurial-patches.

REVISION SUMMARY
  This will make sure both `sendall` and `write` do the same processing and make
  it simpler to update that processing in the future.

REPOSITORY
  rHG Mercurial

BRANCH
  default

REVISION DETAIL
  https://phab.mercurial-scm.org/D12043

AFFECTED FILES
  tests/testlib/badserverext.py

CHANGE DETAILS

diff --git a/tests/testlib/badserverext.py b/tests/testlib/badserverext.py
--- a/tests/testlib/badserverext.py
+++ b/tests/testlib/badserverext.py
@@ -70,6 +70,11 @@
         self._all_close_after_recv_bytes = close_after_recv_bytes
         self._all_close_after_send_bytes = close_after_send_bytes
 
+        self.target_recv_bytes = None
+        self.remaining_recv_bytes = None
+        self.target_send_bytes = None
+        self.remaining_send_bytes = None
+
     def start_next_request(self):
         """move to the next set of close condition"""
         if self._all_close_after_recv_bytes:
@@ -93,6 +98,54 @@
             return True
         return False
 
+    def forward_write(self, obj, method, data, *args, **kwargs):
+        """forward a write call, truncating data to the remaining send budget
+
+        When the budget runs out, the socket is closed and an exception raised
+        """
+        remaining = self.remaining_send_bytes
+
+        orig = object.__getattribute__(obj, '_orig')
+        bmethod = method.encode('ascii')
+        func = getattr(orig, method)
+        # No byte limit on this operation. Call original function.
+        if not remaining:
+            result = func(data, *args, **kwargs)
+            obj._writelog(b'%s(%d) -> %s' % (bmethod, len(data), data))
+            return result
+
+        remaining = max(0, remaining)
+
+        if remaining > 0:
+            if remaining < len(data):
+                newdata = data[0:remaining]
+            else:
+                newdata = data
+
+            remaining -= len(newdata)
+
+            obj._writelog(
+                b'%s(%d from %d) -> (%d) %s'
+                % (
+                    bmethod,
+                    len(newdata),
+                    len(data),
+                    remaining,
+                    newdata,
+                )
+            )
+
+            result = func(newdata, *args, **kwargs)
+
+        self.remaining_send_bytes = remaining
+
+        if remaining <= 0:
+            obj._writelog(b'write limit reached; closing socket')
+            object.__getattribute__(obj, '_cond_close')()
+            raise Exception('connection closed after sending N bytes')
+
+        return result
+
 
 # We can't adjust __class__ on a socket instance. So we define a proxy type.
 class socketproxy(object):
@@ -131,37 +184,11 @@
         return fileobjectproxy(f, logfp, cond)
 
     def sendall(self, data, flags=0):
-        remaining = object.__getattribute__(self, '_cond').remaining_send_bytes
-
-        # No read limit. Call original function.
-        if not remaining:
-            result = object.__getattribute__(self, '_orig').sendall(data, flags)
-            self._writelog(b'sendall(%d) -> %s' % (len(data), data))
-            return result
-
-        if len(data) > remaining:
-            newdata = data[0:remaining]
-        else:
-            newdata = data
-
-        remaining -= len(newdata)
+        cond = object.__getattribute__(self, '_cond')
+        return cond.forward_write(self, 'sendall', data, flags)
 
-        result = object.__getattribute__(self, '_orig').sendall(newdata, flags)
-
-        self._writelog(
-            b'sendall(%d from %d) -> (%d) %s'
-            % (len(newdata), len(data), remaining, newdata)
-        )
-
-        object.__getattribute__(self, '_cond').remaining_send_bytes = remaining
-
-        if remaining <= 0:
-            self._writelog(b'write limit reached; closing socket')
-            object.__getattribute__(self, '_orig').shutdown(socket.SHUT_RDWR)
-
-            raise Exception('connection closed after sending N bytes')
-
-        return result
+    def _cond_close(self):
+        object.__getattribute__(self, '_orig').shutdown(socket.SHUT_RDWR)
 
 
 # We can't adjust __class__ on socket._fileobject, so define a proxy.
@@ -174,7 +201,14 @@
         object.__setattr__(self, '_cond', condition_tracked)
 
     def __getattribute__(self, name):
-        if name in ('_close', 'read', 'readline', 'write', '_writelog'):
+        if name in (
+            '_close',
+            'read',
+            'readline',
+            'write',
+            '_writelog',
+            '_cond_close',
+        ):
             return object.__getattribute__(self, name)
 
         return getattr(object.__getattribute__(self, '_orig'), name)
@@ -280,37 +314,11 @@
         return result
 
     def write(self, data):
-        remaining = object.__getattribute__(self, '_cond').remaining_send_bytes
-
-        # No byte limit on this operation. Call original function.
-        if not remaining:
-            result = object.__getattribute__(self, '_orig').write(data)
-            self._writelog(b'write(%d) -> %s' % (len(data), data))
-            return result
-
-        if len(data) > remaining:
-            newdata = data[0:remaining]
-        else:
-            newdata = data
-
-        remaining -= len(newdata)
+        cond = object.__getattribute__(self, '_cond')
+        return cond.forward_write(self, 'write', data)
 
-        result = object.__getattribute__(self, '_orig').write(newdata)
-
-        self._writelog(
-            b'write(%d from %d) -> (%d) %s'
-            % (len(newdata), len(data), remaining, newdata)
-        )
-
-        object.__getattribute__(self, '_cond').remaining_send_bytes = remaining
-
-        if remaining <= 0:
-            self._writelog(b'write limit reached; closing socket')
-            self._close()
-
-            raise Exception('connection closed after sending N bytes')
-
-        return result
+    def _cond_close(self):
+        self._close()
 
 
 def process_config(value):



To: marmoute, #hg-reviewers
Cc: mercurial-patches, mercurial-devel


More information about the Mercurial-devel mailing list