This commit is contained in:
Croneter 2018-06-15 14:40:29 +02:00
parent 1a58967111
commit 51444111d2

View file

@ -27,8 +27,10 @@ try:
from ssl import SSLError from ssl import SSLError
HAVE_SSL = True HAVE_SSL = True
except ImportError: except ImportError:
# dummy class of SSLError for ssl none-support environment.
class SSLError(Exception): class SSLError(Exception):
"""
Dummy class of SSLError for ssl none-support environment.
"""
pass pass
HAVE_SSL = False HAVE_SSL = False
@ -50,7 +52,7 @@ import utils
############################################################################### ###############################################################################
log = logging.getLogger("PLEX."+__name__) LOG = logging.getLogger("PLEX." + __name__)
############################################################################### ###############################################################################
@ -95,28 +97,31 @@ class WebSocketConnectionClosedException(WebSocketException):
""" """
pass pass
class WebSocketTimeoutException(WebSocketException): class WebSocketTimeoutException(WebSocketException):
""" """
WebSocketTimeoutException will be raised at socket timeout during read/write data. WebSocketTimeoutException will be raised at socket timeout during read and
write data.
""" """
pass pass
default_timeout = None
traceEnabled = False DEFAULT_TIMEOUT = None
TRACE_ENABLED = False
def enableTrace(tracable): def enable_trace(tracable):
""" """
turn on/off the tracability. turn on/off the tracability.
tracable: boolean value. if set True, tracability is enabled. tracable: boolean value. if set True, tracability is enabled.
""" """
global traceEnabled global TRACE_ENABLED
traceEnabled = tracable TRACE_ENABLED = tracable
if tracable: if tracable:
if not log.handlers: if not LOG.handlers:
log.addHandler(logging.StreamHandler()) LOG.addHandler(logging.StreamHandler())
log.setLevel(logging.DEBUG) LOG.setLevel(logging.DEBUG)
def setdefaulttimeout(timeout): def setdefaulttimeout(timeout):
@ -125,15 +130,15 @@ def setdefaulttimeout(timeout):
timeout: default socket timeout time. This value is second. timeout: default socket timeout time. This value is second.
""" """
global default_timeout global DEFAULT_TIMEOUT
default_timeout = timeout DEFAULT_TIMEOUT = timeout
def getdefaulttimeout(): def getdefaulttimeout():
""" """
Return the global timeout setting(second) to connect. Return the global timeout setting(second) to connect.
""" """
return default_timeout return DEFAULT_TIMEOUT
def _parse_url(url): def _parse_url(url):
@ -185,7 +190,8 @@ def create_connection(url, timeout=None, **options):
Connect to url and return the WebSocket object. Connect to url and return the WebSocket object.
Passing optional timeout parameter will set the timeout on the socket. Passing optional timeout parameter will set the timeout on the socket.
If no timeout is supplied, the global default timeout setting returned by getdefauttimeout() is used. If no timeout is supplied, the global default timeout setting returned by
getdefauttimeout() is used.
You can customize using 'options'. You can customize using 'options'.
If you set "header" list object, you can set your own custom header. If you set "header" list object, you can set your own custom header.
@ -195,18 +201,20 @@ def create_connection(url, timeout=None, **options):
timeout: socket timeout time. This value is integer. timeout: socket timeout time. This value is integer.
if you set None for this value, it means "use default_timeout value" if you set None for this value, it means "use DEFAULT_TIMEOUT
value"
options: current support option is only "header". options: current support option is only "header".
if you set header as dict value, the custom HTTP headers are added. if you set header as dict value, the custom HTTP headers are added
""" """
sockopt = options.get("sockopt", []) sockopt = options.get("sockopt", [])
sslopt = options.get("sslopt", {}) sslopt = options.get("sslopt", {})
websock = WebSocket(sockopt=sockopt, sslopt=sslopt) websock = WebSocket(sockopt=sockopt, sslopt=sslopt)
websock.settimeout(timeout if timeout is not None else default_timeout) websock.settimeout(timeout if timeout is not None else DEFAULT_TIMEOUT)
websock.connect(url, **options) websock.connect(url, **options)
return websock return websock
_MAX_INTEGER = (1 << 32) - 1 _MAX_INTEGER = (1 << 32) - 1
_AVAILABLE_KEY_CHARS = range(0x21, 0x2f + 1) + range(0x3a, 0x7e + 1) _AVAILABLE_KEY_CHARS = range(0x21, 0x2f + 1) + range(0x3a, 0x7e + 1)
_MAX_CHAR_BYTE = (1 << 8) - 1 _MAX_CHAR_BYTE = (1 << 8) - 1
@ -220,10 +228,7 @@ def _create_sec_websocket_key():
return base64.encodestring(uid.bytes).strip() return base64.encodestring(uid.bytes).strip()
_HEADERS_TO_CHECK = { _HEADERS_TO_CHECK = {"upgrade": "websocket", "connection": "upgrade"}
"upgrade": "websocket",
"connection": "upgrade",
}
class ABNF(object): class ABNF(object):
@ -308,9 +313,9 @@ class ABNF(object):
if length >= ABNF.LENGTH_63: if length >= ABNF.LENGTH_63:
raise ValueError("data is too long") raise ValueError("data is too long")
frame_header = chr(self.fin << 7 frame_header = chr(self.fin << 7 |
| self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 |
| self.opcode) self.opcode)
if length < ABNF.LENGTH_7: if length < ABNF.LENGTH_7:
frame_header += chr(self.mask << 7 | length) frame_header += chr(self.mask << 7 | length)
elif length < ABNF.LENGTH_16: elif length < ABNF.LENGTH_16:
@ -395,6 +400,9 @@ class WebSocket(object):
self._cont_data = None self._cont_data = None
def fileno(self): def fileno(self):
"""
Returns sock.fileno()
"""
return self.sock.fileno() return self.sock.fileno()
def set_mask_key(self, func): def set_mask_key(self, func):
@ -438,7 +446,7 @@ class WebSocket(object):
timeout: socket timeout time. This value is integer. timeout: socket timeout time. This value is integer.
if you set None for this value, if you set None for this value,
it means "use default_timeout value" it means "use DEFAULT_TIMEOUT value"
options: current support option is only "header". options: current support option is only "header".
if you set header as dict value, if you set header as dict value,
@ -487,10 +495,10 @@ class WebSocket(object):
header_str = "\r\n".join(headers) header_str = "\r\n".join(headers)
self._send(header_str) self._send(header_str)
if traceEnabled: if TRACE_ENABLED:
log.debug("--- request header ---") LOG.debug("--- request header ---")
log.debug(header_str) LOG.debug(header_str)
log.debug("-----------------------") LOG.debug("-----------------------")
status, resp_headers = self._read_headers() status, resp_headers = self._read_headers()
if status != 101: if status != 101:
@ -526,16 +534,16 @@ class WebSocket(object):
def _read_headers(self): def _read_headers(self):
status = None status = None
headers = {} headers = {}
if traceEnabled: if TRACE_ENABLED:
log.debug("--- response header ---") LOG.debug("--- response header ---")
while True: while True:
line = self._recv_line() line = self._recv_line()
if line == "\r\n": if line == "\r\n":
break break
line = line.strip() line = line.strip()
if traceEnabled: if TRACE_ENABLED:
log.debug(line) LOG.debug(line)
if not status: if not status:
status_info = line.split(" ", 2) status_info = line.split(" ", 2)
status = int(status_info[1]) status = int(status_info[1])
@ -547,8 +555,8 @@ class WebSocket(object):
else: else:
raise WebSocketException("Invalid header") raise WebSocketException("Invalid header")
if traceEnabled: if TRACE_ENABLED:
log.debug("-----------------------") LOG.debug("-----------------------")
return status, headers return status, headers
@ -567,14 +575,17 @@ class WebSocket(object):
frame.get_mask_key = self.get_mask_key frame.get_mask_key = self.get_mask_key
data = frame.format() data = frame.format()
length = len(data) length = len(data)
if traceEnabled: if TRACE_ENABLED:
log.debug("send: " + repr(data)) LOG.debug("send: %s", repr(data))
while data: while data:
l = self._send(data) l = self._send(data)
data = data[l:] data = data[l:]
return length return length
def send_binary(self, payload): def send_binary(self, payload):
"""
send the payload
"""
return self.send(payload, ABNF.OPCODE_BINARY) return self.send(payload, ABNF.OPCODE_BINARY)
def ping(self, payload=""): def ping(self, payload=""):
@ -693,34 +704,10 @@ class WebSocket(object):
reason: the reason to close. This must be string. reason: the reason to close. This must be string.
""" """
try: try:
self.sock.shutdown(socket.SHUT_RDWR) self.sock.shutdown(socket.SHUT_RDWR)
except: except:
pass pass
'''
if self.connected:
if status < 0 or status >= ABNF.LENGTH_16:
raise ValueError("code is invalid range")
try:
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
timeout = self.sock.gettimeout()
self.sock.settimeout(3)
try:
frame = self.recv_frame()
if log.isEnabledFor(logging.ERROR):
recv_status = struct.unpack("!H", frame.data)[0]
if recv_status != STATUS_NORMAL:
log.error("close status: " + repr(recv_status))
except:
pass
self.sock.settimeout(timeout)
self.sock.shutdown(socket.SHUT_RDWR)
except:
pass
'''
self._closeInternal() self._closeInternal()
def _closeInternal(self): def _closeInternal(self):
@ -752,7 +739,6 @@ class WebSocket(object):
raise WebSocketConnectionClosedException() raise WebSocketConnectionClosedException()
return bytes_ return bytes_
def _recv_strict(self, bufsize): def _recv_strict(self, bufsize):
shortage = bufsize - sum(len(x) for x in self._recv_buffer) shortage = bufsize - sum(len(x) for x in self._recv_buffer)
while shortage > 0: while shortage > 0:
@ -767,7 +753,6 @@ class WebSocket(object):
self._recv_buffer = [unified[bufsize:]] self._recv_buffer = [unified[bufsize:]]
return unified[:bufsize] return unified[:bufsize]
def _recv_line(self): def _recv_line(self):
line = [] line = []
while True: while True:
@ -846,9 +831,11 @@ class WebSocketApp(object):
run event loop for WebSocket framework. run event loop for WebSocket framework.
This loop is infinite loop and is alive during websocket is available. This loop is infinite loop and is alive during websocket is available.
sockopt: values for socket.setsockopt. sockopt: values for socket.setsockopt.
sockopt must be tuple and each element is argument of sock.setscokopt. sockopt must be tuple and each element is argument of
sock.setscokopt.
sslopt: ssl socket optional dict. sslopt: ssl socket optional dict.
ping_interval: automatically send "ping" command every specified period(second) ping_interval: automatically send "ping" command every specified
period(second)
if set to 0, not send automatically. if set to 0, not send automatically.
""" """
if sockopt is None: if sockopt is None:
@ -861,26 +848,26 @@ class WebSocketApp(object):
self.keep_running = True self.keep_running = True
try: try:
self.sock = WebSocket(self.get_mask_key, sockopt=sockopt, sslopt=sslopt) self.sock = WebSocket(self.get_mask_key,
self.sock.settimeout(default_timeout) sockopt=sockopt,
sslopt=sslopt)
self.sock.settimeout(DEFAULT_TIMEOUT)
self.sock.connect(self.url, header=self.header) self.sock.connect(self.url, header=self.header)
self._callback(self.on_open) self._callback(self.on_open)
if ping_interval: if ping_interval:
thread = threading.Thread(target=self._send_ping, args=(ping_interval,)) thread = threading.Thread(target=self._send_ping,
args=(ping_interval,))
thread.setDaemon(True) thread.setDaemon(True)
thread.start() thread.start()
while self.keep_running: while self.keep_running:
try: try:
data = self.sock.recv() data = self.sock.recv()
if data is None or self.keep_running is False:
if data is None or self.keep_running == False:
break break
self._callback(self.on_message, data) self._callback(self.on_message, data)
except Exception, e: except Exception, e:
#print str(e.args[0])
if "timed out" not in e.args[0]: if "timed out" not in e.args[0]:
raise e raise e
@ -898,19 +885,18 @@ class WebSocketApp(object):
try: try:
callback(self, *args) callback(self, *args)
except Exception, e: except Exception, e:
log.error(e) LOG.error(e)
if True:#log.isEnabledFor(logging.DEBUG):
_, _, tb = sys.exc_info() _, _, tb = sys.exc_info()
traceback.print_tb(tb) traceback.print_tb(tb)
if __name__ == "__main__": if __name__ == "__main__":
enableTrace(True) enable_trace(True)
ws = create_connection("ws://echo.websocket.org/") WEBSOCKET = create_connection("ws://echo.websocket.org/")
print("Sending 'Hello, World'...") LOG.info("Sending 'Hello, World'...")
ws.send("Hello, World") WEBSOCKET.send("Hello, World")
print("Sent") LOG.info("Sent")
print("Receiving...") LOG.info("Receiving...")
result = ws.recv() RESULT = WEBSOCKET.recv()
print("Received '%s'" % result) LOG.info("Received '%s'", RESULT)
ws.close() WEBSOCKET.close()