Merge pull request #1490 from croneter/py3-update-websockets

Update websocket client to 1.0.0
This commit is contained in:
croneter 2021-05-24 20:30:01 +02:00 committed by GitHub
commit 6d566c6cd2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 1008 additions and 558 deletions

View file

@ -4,7 +4,6 @@
<import addon="xbmc.python" version="3.0.0"/> <import addon="xbmc.python" version="3.0.0"/>
<import addon="script.module.requests" version="2.22.0+matrix.1" /> <import addon="script.module.requests" version="2.22.0+matrix.1" />
<import addon="script.module.defusedxml" version="0.6.0+matrix.1"/> <import addon="script.module.defusedxml" version="0.6.0+matrix.1"/>
<import addon="script.module.six" />
<import addon="plugin.video.plexkodiconnect.movies" version="3.0.0" /> <import addon="plugin.video.plexkodiconnect.movies" version="3.0.0" />
<import addon="plugin.video.plexkodiconnect.tvshows" version="3.0.0" /> <import addon="plugin.video.plexkodiconnect.tvshows" version="3.0.0" />
<import addon="metadata.themoviedb.org.python" version="1.3.1+matrix.1" /> <import addon="metadata.themoviedb.org.python" version="1.3.1+matrix.1" />

View file

@ -25,4 +25,4 @@ from ._exceptions import *
from ._logging import * from ._logging import *
from ._socket import * from ._socket import *
__version__ = "0.58.0" __version__ = "1.0.0"

View file

@ -26,17 +26,12 @@ import array
import os import os
import struct import struct
import six
from ._exceptions import * from ._exceptions import *
from ._utils import validate_utf8 from ._utils import validate_utf8
from threading import Lock from threading import Lock
try: try:
if six.PY3:
import numpy import numpy
else:
numpy = None
except ImportError: except ImportError:
numpy = None numpy = None
@ -53,10 +48,7 @@ except ImportError:
for i in range(len(_d)): for i in range(len(_d)):
_d[i] ^= _m[i % 4] _d[i] ^= _m[i % 4]
if six.PY3:
return _d.tobytes() return _d.tobytes()
else:
return _d.tostring()
__all__ = [ __all__ = [
@ -181,8 +173,7 @@ class ABNF(object):
if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]): if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
raise WebSocketProtocolException("Invalid close frame.") raise WebSocketProtocolException("Invalid close frame.")
code = 256 * \ code = 256 * self.data[0] + self.data[1]
six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2])
if not self._is_valid_close_status(code): if not self._is_valid_close_status(code):
raise WebSocketProtocolException("Invalid close opcode.") raise WebSocketProtocolException("Invalid close opcode.")
@ -211,7 +202,7 @@ class ABNF(object):
fin: <type> fin: <type>
fin flag. if set to 0, create continue fragmentation. fin flag. if set to 0, create continue fragmentation.
""" """
if opcode == ABNF.OPCODE_TEXT and isinstance(data, six.text_type): if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
data = data.encode("utf-8") data = data.encode("utf-8")
# mask must be set if send data from client # mask must be set if send data from client
return ABNF(fin, 0, 0, 0, opcode, 1, data) return ABNF(fin, 0, 0, 0, opcode, 1, data)
@ -228,19 +219,16 @@ class ABNF(object):
if length >= ABNF.LENGTH_63: if length >= ABNF.LENGTH_63:
raise ValueError("data is too long") raise ValueError("data is too long")
frame_header = chr(self.fin << 7 frame_header = chr(self.fin << 7 |
| self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 |
| self.opcode) self.opcode).encode('latin-1')
if length < ABNF.LENGTH_7: if length < ABNF.LENGTH_7:
frame_header += chr(self.mask << 7 | length) frame_header += chr(self.mask << 7 | length).encode('latin-1')
frame_header = six.b(frame_header)
elif length < ABNF.LENGTH_16: elif length < ABNF.LENGTH_16:
frame_header += chr(self.mask << 7 | 0x7e) frame_header += chr(self.mask << 7 | 0x7e).encode('latin-1')
frame_header = six.b(frame_header)
frame_header += struct.pack("!H", length) frame_header += struct.pack("!H", length)
else: else:
frame_header += chr(self.mask << 7 | 0x7f) frame_header += chr(self.mask << 7 | 0x7f).encode('latin-1')
frame_header = six.b(frame_header)
frame_header += struct.pack("!Q", length) frame_header += struct.pack("!Q", length)
if not self.mask: if not self.mask:
@ -252,7 +240,7 @@ class ABNF(object):
def _get_masked(self, mask_key): def _get_masked(self, mask_key):
s = ABNF.mask(mask_key, self.data) s = ABNF.mask(mask_key, self.data)
if isinstance(mask_key, six.text_type): if isinstance(mask_key, str):
mask_key = mask_key.encode('utf-8') mask_key = mask_key.encode('utf-8')
return mask_key + s return mask_key + s
@ -265,34 +253,32 @@ class ABNF(object):
Parameters Parameters
---------- ----------
mask_key: <type> mask_key: <type>
4 byte string(byte). 4 byte string.
data: <type> data: <type>
data to mask/unmask. data to mask/unmask.
""" """
if data is None: if data is None:
data = "" data = ""
if isinstance(mask_key, six.text_type): if isinstance(mask_key, str):
mask_key = six.b(mask_key) mask_key = mask_key.encode('latin-1')
if isinstance(data, six.text_type): if isinstance(data, str):
data = six.b(data) data = data.encode('latin-1')
if numpy: if numpy:
origlen = len(data) origlen = len(data)
_mask_key = mask_key[3] << 24 | mask_key[2] << 16 | mask_key[1] << 8 | mask_key[0] _mask_key = mask_key[3] << 24 | mask_key[2] << 16 | mask_key[1] << 8 | mask_key[0]
# We need data to be a multiple of four... # We need data to be a multiple of four...
data += bytes(" " * (4 - (len(data) % 4)), "us-ascii") data += b' ' * (4 - (len(data) % 4))
a = numpy.frombuffer(data, dtype="uint32") a = numpy.frombuffer(data, dtype="uint32")
masked = numpy.bitwise_xor(a, [_mask_key]).astype("uint32") masked = numpy.bitwise_xor(a, [_mask_key]).astype("uint32")
if len(data) > origlen: if len(data) > origlen:
return masked.tobytes()[:origlen] return masked.tobytes()[:origlen]
return masked.tobytes() return masked.tobytes()
else: else:
_m = array.array("B", mask_key) return _mask(array.array("B", mask_key), array.array("B", data))
_d = array.array("B", data)
return _mask(_m, _d)
class frame_buffer(object): class frame_buffer(object):
@ -319,20 +305,12 @@ class frame_buffer(object):
def recv_header(self): def recv_header(self):
header = self.recv_strict(2) header = self.recv_strict(2)
b1 = header[0] b1 = header[0]
if six.PY2:
b1 = ord(b1)
fin = b1 >> 7 & 1 fin = b1 >> 7 & 1
rsv1 = b1 >> 6 & 1 rsv1 = b1 >> 6 & 1
rsv2 = b1 >> 5 & 1 rsv2 = b1 >> 5 & 1
rsv3 = b1 >> 4 & 1 rsv3 = b1 >> 4 & 1
opcode = b1 & 0xf opcode = b1 & 0xf
b2 = header[1] b2 = header[1]
if six.PY2:
b2 = ord(b2)
has_mask = b2 >> 7 & 1 has_mask = b2 >> 7 & 1
length_bits = b2 & 0x7f length_bits = b2 & 0x7f
@ -408,7 +386,7 @@ class frame_buffer(object):
self.recv_buffer.append(bytes_) self.recv_buffer.append(bytes_)
shortage -= len(bytes_) shortage -= len(bytes_)
unified = six.b("").join(self.recv_buffer) unified = bytes("", 'utf-8').join(self.recv_buffer)
if shortage == 0: if shortage == 0:
self.recv_buffer = [] self.recv_buffer = []

View file

@ -22,15 +22,11 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
""" """
import inspect import selectors
import select
import sys import sys
import threading import threading
import time import time
import traceback import traceback
import six
from ._abnf import ABNF from ._abnf import ABNF
from ._core import WebSocket, getdefaulttimeout from ._core import WebSocket, getdefaulttimeout
from ._exceptions import * from ._exceptions import *
@ -39,6 +35,7 @@ from . import _logging
__all__ = ["WebSocketApp"] __all__ = ["WebSocketApp"]
class Dispatcher: class Dispatcher:
""" """
Dispatcher Dispatcher
@ -49,12 +46,16 @@ class Dispatcher:
def read(self, sock, read_callback, check_callback): def read(self, sock, read_callback, check_callback):
while self.app.keep_running: while self.app.keep_running:
r, w, e = select.select( sel = selectors.DefaultSelector()
(self.app.sock.sock, ), (), (), self.ping_timeout) sel.register(self.app.sock.sock, selectors.EVENT_READ)
r = sel.select(self.ping_timeout)
if r: if r:
if not read_callback(): if not read_callback():
break break
check_callback() check_callback()
sel.close()
class SSLDispatcher: class SSLDispatcher:
""" """
@ -77,8 +78,14 @@ class SSLDispatcher:
if sock.pending(): if sock.pending():
return [sock,] return [sock,]
r, w, e = select.select((sock, ), (), (), self.ping_timeout) sel = selectors.DefaultSelector()
return r sel.register(sock, selectors.EVENT_READ)
r = sel.select(self.ping_timeout)
sel.close()
if len(r) > 0:
return r[0][0]
class WebSocketApp(object): class WebSocketApp(object):
@ -190,18 +197,19 @@ class WebSocketApp(object):
self.sock.close(**kwargs) self.sock.close(**kwargs)
self.sock = None self.sock = None
def _send_ping(self, interval, event): def _send_ping(self, interval, event, payload):
while not event.wait(interval): while not event.wait(interval):
self.last_ping_tm = time.time() self.last_ping_tm = time.time()
if self.sock: if self.sock:
try: try:
self.sock.ping() self.sock.ping(payload)
except Exception as ex: except Exception as ex:
_logging.warning("send_ping routine terminated: {}".format(ex)) _logging.warning("send_ping routine terminated: {}".format(ex))
break break
def run_forever(self, sockopt=None, sslopt=None, def run_forever(self, sockopt=None, sslopt=None,
ping_interval=0, ping_timeout=None, ping_interval=0, ping_timeout=None,
ping_payload="",
http_proxy_host=None, http_proxy_port=None, http_proxy_host=None, http_proxy_port=None,
http_no_proxy=None, http_proxy_auth=None, http_no_proxy=None, http_proxy_auth=None,
skip_utf8_validation=False, skip_utf8_validation=False,
@ -226,6 +234,8 @@ class WebSocketApp(object):
if set to 0, not send automatically. if set to 0, not send automatically.
ping_timeout: int or float ping_timeout: int or float
timeout (in seconds) if the pong message is not received. timeout (in seconds) if the pong message is not received.
ping_payload: str
payload message to send with each ping.
http_proxy_host: <type> http_proxy_host: <type>
http proxy host name. http proxy host name.
http_proxy_port: <type> http_proxy_port: <type>
@ -250,7 +260,9 @@ class WebSocketApp(object):
""" """
if ping_timeout is not None and ping_timeout <= 0: if ping_timeout is not None and ping_timeout <= 0:
ping_timeout = None raise WebSocketException("Ensure ping_timeout > 0")
if ping_interval is not None and ping_interval < 0:
raise WebSocketException("Ensure ping_interval >= 0")
if ping_timeout and ping_interval and ping_interval <= ping_timeout: if ping_timeout and ping_interval and ping_interval <= ping_timeout:
raise WebSocketException("Ensure ping_interval > ping_timeout") raise WebSocketException("Ensure ping_interval > ping_timeout")
if not sockopt: if not sockopt:
@ -271,15 +283,16 @@ class WebSocketApp(object):
If close_frame is set, we will invoke the on_close handler with the If close_frame is set, we will invoke the on_close handler with the
statusCode and reason from there. statusCode and reason from there.
""" """
if thread and thread.is_alive(): if thread and thread.is_alive():
event.set() event.set()
thread.join() thread.join()
self.keep_running = False self.keep_running = False
if self.sock: if self.sock:
self.sock.close() self.sock.close()
close_args = self._get_close_args( close_status_code, close_reason = self._get_close_args(
close_frame.data if close_frame else None) close_frame if close_frame else None)
self._callback(self.on_close, *close_args) self._callback(self.on_close, close_status_code, close_reason)
self.sock = None self.sock = None
try: try:
@ -304,8 +317,8 @@ class WebSocketApp(object):
if ping_interval: if ping_interval:
event = threading.Event() event = threading.Event()
thread = threading.Thread( thread = threading.Thread(
target=self._send_ping, args=(ping_interval, event)) target=self._send_ping, args=(ping_interval, event, ping_payload))
thread.setDaemon(True) thread.daemon = True
thread.start() thread.start()
def read(): def read():
@ -327,7 +340,7 @@ class WebSocketApp(object):
frame.data, frame.fin) frame.data, frame.fin)
else: else:
data = frame.data data = frame.data
if six.PY3 and op_code == ABNF.OPCODE_TEXT: if op_code == ABNF.OPCODE_TEXT:
data = data.decode("utf-8") data = data.decode("utf-8")
self._callback(self.on_data, data, frame.opcode, True) self._callback(self.on_data, data, frame.opcode, True)
self._callback(self.on_message, data) self._callback(self.on_message, data)
@ -340,9 +353,9 @@ class WebSocketApp(object):
has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0 has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0
has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout
if (self.last_ping_tm if (self.last_ping_tm and
and has_timeout_expired has_timeout_expired and
and (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)): (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)):
raise WebSocketTimeoutException("ping/pong timed out") raise WebSocketTimeoutException("ping/pong timed out")
return True return True
@ -362,24 +375,23 @@ class WebSocketApp(object):
return Dispatcher(self, timeout) return Dispatcher(self, timeout)
def _get_close_args(self, data): def _get_close_args(self, close_frame):
""" """
_get_close_args extracts the code, reason from the close body _get_close_args extracts the close code and reason from the close body
if they exists, and if the self.on_close except three arguments if it exists (RFC6455 says WebSocket Connection Close Code is optional)
""" """
# if the on_close callback is "old", just return empty list # Need to catch the case where close_frame is None
if sys.version_info < (3, 0): # Otherwise the following if statement causes an error
if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3: if not self.on_close or not close_frame:
return [] return [None, None]
# Extract close frame status code
if close_frame.data and len(close_frame.data) >= 2:
close_status_code = 256 * close_frame.data[0] + close_frame.data[1]
reason = close_frame.data[2:].decode('utf-8')
return [close_status_code, reason]
else: else:
if not self.on_close or len(inspect.getfullargspec(self.on_close).args) != 3: # Most likely reached this because len(close_frame_data.data) < 2
return []
if data and len(data) >= 2:
code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2])
reason = data[2:].decode('utf-8')
return [code, reason]
return [None, None] return [None, None]
def _callback(self, callback, *args): def _callback(self, callback, *args):

View file

@ -22,10 +22,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
""" """
try: import http.cookies
import Cookie
except:
import http.cookies as Cookie
class SimpleCookieJar(object): class SimpleCookieJar(object):
@ -34,26 +31,20 @@ class SimpleCookieJar(object):
def add(self, set_cookie): def add(self, set_cookie):
if set_cookie: if set_cookie:
try: simpleCookie = http.cookies.SimpleCookie(set_cookie)
simpleCookie = Cookie.SimpleCookie(set_cookie)
except:
simpleCookie = Cookie.SimpleCookie(set_cookie.encode('ascii', 'ignore'))
for k, v in simpleCookie.items(): for k, v in simpleCookie.items():
domain = v.get("domain") domain = v.get("domain")
if domain: if domain:
if not domain.startswith("."): if not domain.startswith("."):
domain = "." + domain domain = "." + domain
cookie = self.jar.get(domain) if self.jar.get(domain) else Cookie.SimpleCookie() cookie = self.jar.get(domain) if self.jar.get(domain) else http.cookies.SimpleCookie()
cookie.update(simpleCookie) cookie.update(simpleCookie)
self.jar[domain.lower()] = cookie self.jar[domain.lower()] = cookie
def set(self, set_cookie): def set(self, set_cookie):
if set_cookie: if set_cookie:
try: simpleCookie = http.cookies.SimpleCookie(set_cookie)
simpleCookie = Cookie.SimpleCookie(set_cookie)
except:
simpleCookie = Cookie.SimpleCookie(set_cookie.encode('ascii', 'ignore'))
for k, v in simpleCookie.items(): for k, v in simpleCookie.items():
domain = v.get("domain") domain = v.get("domain")
@ -72,5 +63,7 @@ class SimpleCookieJar(object):
if host.endswith(domain) or host == domain[1:]: if host.endswith(domain) or host == domain[1:]:
cookies.append(self.jar.get(domain)) cookies.append(self.jar.get(domain))
return "; ".join(filter(None, ["%s=%s" % (k, v.value) for cookie in filter(None, sorted(cookies)) for k, v in return "; ".join(filter(
sorted(cookie.items())])) None, sorted(
["%s=%s" % (k, v.value) for cookie in filter(None, cookies) for k, v in cookie.items()]
)))

View file

@ -1,4 +1,3 @@
from __future__ import print_function
""" """
_core.py _core.py
==================================== ====================================
@ -30,8 +29,6 @@ import struct
import threading import threading
import time import time
import six
# websocket modules # websocket modules
from ._abnf import * from ._abnf import *
from ._exceptions import * from ._exceptions import *
@ -44,6 +41,7 @@ from ._utils import *
__all__ = ['WebSocket', 'create_connection'] __all__ = ['WebSocket', 'create_connection']
class WebSocket(object): class WebSocket(object):
""" """
Low level WebSocket interface. Low level WebSocket interface.
@ -225,6 +223,9 @@ class WebSocket(object):
cookie value. cookie value.
- origin: str - origin: str
custom origin url. custom origin url.
- connection: str
custom connection header value.
default value "Upgrade" set in _handshake.py
- suppress_origin: bool - suppress_origin: bool
suppress outputting origin header. suppress outputting origin header.
- host: str - host: str
@ -270,11 +271,11 @@ class WebSocket(object):
Parameters Parameters
---------- ----------
payload: <type> payload: str
Payload must be utf-8 string or unicode, Payload must be utf-8 string or unicode,
if the opcode is OPCODE_TEXT. if the opcode is OPCODE_TEXT.
Otherwise, it must be string(byte array) Otherwise, it must be string(byte array)
opcode: <type> opcode: int
operation code to send. Please see OPCODE_XXX. operation code to send. Please see OPCODE_XXX.
""" """
@ -295,7 +296,7 @@ class WebSocket(object):
Parameters Parameters
---------- ----------
frame: <type> frame: ABNF frame
frame data created by ABNF.create_frame frame data created by ABNF.create_frame
""" """
if self.get_mask_key: if self.get_mask_key:
@ -303,8 +304,8 @@ class WebSocket(object):
data = frame.format() data = frame.format()
length = len(data) length = len(data)
if (isEnabledForTrace()): if (isEnabledForTrace()):
trace("send: " + repr(data)) trace("++Sent raw: " + repr(data))
trace("++Sent decoded: " + frame.__str__())
with self.lock: with self.lock:
while data: while data:
l = self._send(data) l = self._send(data)
@ -321,10 +322,10 @@ class WebSocket(object):
Parameters Parameters
---------- ----------
payload: <type> payload: str
data payload to send server. data payload to send server.
""" """
if isinstance(payload, six.text_type): if isinstance(payload, str):
payload = payload.encode("utf-8") payload = payload.encode("utf-8")
self.send(payload, ABNF.OPCODE_PING) self.send(payload, ABNF.OPCODE_PING)
@ -334,10 +335,10 @@ class WebSocket(object):
Parameters Parameters
---------- ----------
payload: <type> payload: str
data payload to send server. data payload to send server.
""" """
if isinstance(payload, six.text_type): if isinstance(payload, str):
payload = payload.encode("utf-8") payload = payload.encode("utf-8")
self.send(payload, ABNF.OPCODE_PONG) self.send(payload, ABNF.OPCODE_PONG)
@ -351,7 +352,7 @@ class WebSocket(object):
""" """
with self.readlock: with self.readlock:
opcode, data = self.recv_data() opcode, data = self.recv_data()
if six.PY3 and opcode == ABNF.OPCODE_TEXT: if opcode == ABNF.OPCODE_TEXT:
return data.decode("utf-8") return data.decode("utf-8")
elif opcode == ABNF.OPCODE_TEXT or opcode == ABNF.OPCODE_BINARY: elif opcode == ABNF.OPCODE_TEXT or opcode == ABNF.OPCODE_BINARY:
return data return data
@ -393,6 +394,9 @@ class WebSocket(object):
""" """
while True: while True:
frame = self.recv_frame() frame = self.recv_frame()
if (isEnabledForTrace()):
trace("++Rcv raw: " + repr(frame.format()))
trace("++Rcv decoded: " + frame.__str__())
if not frame: if not frame:
# handle error: # handle error:
# 'NoneType' object has no attribute 'opcode' # 'NoneType' object has no attribute 'opcode'
@ -430,7 +434,7 @@ class WebSocket(object):
""" """
return self.frame_buffer.recv_frame() return self.frame_buffer.recv_frame()
def send_close(self, status=STATUS_NORMAL, reason=six.b("")): def send_close(self, status=STATUS_NORMAL, reason=bytes('', encoding='utf-8')):
""" """
Send close data to the server. Send close data to the server.
@ -446,16 +450,16 @@ class WebSocket(object):
self.connected = False self.connected = False
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE) self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
def close(self, status=STATUS_NORMAL, reason=six.b(""), timeout=3): def close(self, status=STATUS_NORMAL, reason=bytes('', encoding='utf-8'), timeout=3):
""" """
Close Websocket object Close Websocket object
Parameters Parameters
---------- ----------
status: <type> status: int
status code to send. see STATUS_XXX. status code to send. see STATUS_XXX.
reason: <type> reason: bytes
the reason to close. This must be string. the reason to close.
timeout: int or float timeout: int or float
timeout until receive a close frame. timeout until receive a close frame.
If None, it will wait forever until receive a close frame. If None, it will wait forever until receive a close frame.
@ -466,8 +470,7 @@ class WebSocket(object):
try: try:
self.connected = False self.connected = False
self.send(struct.pack('!H', status) + self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
reason, ABNF.OPCODE_CLOSE)
sock_timeout = self.sock.gettimeout() sock_timeout = self.sock.gettimeout()
self.sock.settimeout(timeout) self.sock.settimeout(timeout)
start_time = time.time() start_time = time.time()
@ -487,8 +490,10 @@ class WebSocket(object):
break break
self.sock.settimeout(sock_timeout) self.sock.settimeout(sock_timeout)
self.sock.shutdown(socket.SHUT_RDWR) self.sock.shutdown(socket.SHUT_RDWR)
except: except OSError: # This happens often on Mac
pass pass
except:
raise
self.shutdown() self.shutdown()

View file

@ -23,6 +23,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
""" """
class WebSocketException(Exception): class WebSocketException(Exception):
""" """
WebSocket exception class. WebSocket exception class.

View file

@ -21,36 +21,16 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import hashlib import hashlib
import hmac import hmac
import os import os
from base64 import encodebytes as base64encode
import six from http import client as HTTPStatus
from ._cookiejar import SimpleCookieJar from ._cookiejar import SimpleCookieJar
from ._exceptions import * from ._exceptions import *
from ._http import * from ._http import *
from ._logging import * from ._logging import *
from ._socket import * from ._socket import *
if hasattr(six, 'PY3') and six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
if hasattr(six, 'PY3') and six.PY3:
if hasattr(six, 'PY34') and six.PY34:
from http import client as HTTPStatus
else:
from http import HTTPStatus
else:
import httplib as HTTPStatus
__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"] __all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
def compare_digest(s1, s2):
return s1 == s2
# websocket supported version. # websocket supported version.
VERSION = 13 VERSION = 13
@ -93,6 +73,7 @@ def _pack_hostname(hostname):
return hostname return hostname
def _get_handshake_headers(resource, host, port, options): def _get_handshake_headers(resource, host, port, options):
headers = [ headers = [
"GET %s HTTP/1.1" % resource, "GET %s HTTP/1.1" % resource,
@ -116,16 +97,16 @@ def _get_handshake_headers(resource, host, port, options):
key = _create_sec_websocket_key() key = _create_sec_websocket_key()
# Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
if not 'header' in options or 'Sec-WebSocket-Key' not in options['header']: if 'header' not in options or 'Sec-WebSocket-Key' not in options['header']:
key = _create_sec_websocket_key() key = _create_sec_websocket_key()
headers.append("Sec-WebSocket-Key: %s" % key) headers.append("Sec-WebSocket-Key: %s" % key)
else: else:
key = options['header']['Sec-WebSocket-Key'] key = options['header']['Sec-WebSocket-Key']
if not 'header' in options or 'Sec-WebSocket-Version' not in options['header']: if 'header' not in options or 'Sec-WebSocket-Version' not in options['header']:
headers.append("Sec-WebSocket-Version: %s" % VERSION) headers.append("Sec-WebSocket-Version: %s" % VERSION)
if not 'connection' in options or options['connection'] is None: if 'connection' not in options or options['connection'] is None:
headers.append('Connection: Upgrade') headers.append('Connection: Upgrade')
else: else:
headers.append(options['connection']) headers.append(options['connection'])
@ -177,8 +158,8 @@ def _validate(headers, key, subprotocols):
r = headers.get(k, None) r = headers.get(k, None)
if not r: if not r:
return False, None return False, None
r = r.lower() r = [x.strip().lower() for x in r.split(',')]
if v != r: if v not in r:
return False, None return False, None
if subprotocols: if subprotocols:
@ -193,12 +174,12 @@ def _validate(headers, key, subprotocols):
return False, None return False, None
result = result.lower() result = result.lower()
if isinstance(result, six.text_type): if isinstance(result, str):
result = result.encode('utf-8') result = result.encode('utf-8')
value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8') value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
hashed = base64encode(hashlib.sha1(value).digest()).strip().lower() hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
success = compare_digest(hashed, result) success = hmac.compare_digest(hashed, result)
if success: if success:
return True, subproto return True, subproto

View file

@ -23,18 +23,13 @@ import os
import socket import socket
import sys import sys
import six
from ._exceptions import * from ._exceptions import *
from ._logging import * from ._logging import *
from ._socket import* from ._socket import*
from ._ssl_compat import * from ._ssl_compat import *
from ._url import * from ._url import *
if six.PY3:
from base64 import encodebytes as base64encode from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
__all__ = ["proxy_info", "connect", "read_headers"] __all__ = ["proxy_info", "connect", "read_headers"]
@ -47,6 +42,7 @@ except:
pass pass
HAS_PYSOCKS = False HAS_PYSOCKS = False
class proxy_info(object): class proxy_info(object):
def __init__(self, **options): def __init__(self, **options):
@ -91,10 +87,9 @@ def _open_proxied_socket(url, options, proxy):
socket_options=DEFAULT_SOCKET_OPTION + options.sockopt socket_options=DEFAULT_SOCKET_OPTION + options.sockopt
) )
if is_secure: if is_secure and HAVE_SSL:
if HAVE_SSL:
sock = _ssl_socket(sock, options.sslopt, hostname) sock = _ssl_socket(sock, options.sslopt, hostname)
else: elif is_secure:
raise WebSocketException("SSL not available.") raise WebSocketException("SSL not available.")
return sock, (hostname, port, resource) return sock, (hostname, port, resource)
@ -189,6 +184,8 @@ def _open_socket(addrinfo_list, sockopt, timeout):
err = error err = error
continue continue
else: else:
if sock:
sock.close()
raise error raise error
else: else:
break break
@ -202,10 +199,6 @@ def _open_socket(addrinfo_list, sockopt, timeout):
return sock return sock
def _can_use_sni():
return six.PY2 and sys.version_info >= (2, 7, 9) or sys.version_info >= (3, 2)
def _wrap_sni_socket(sock, sslopt, hostname, check_hostname): def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23)) context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23))
@ -249,8 +242,7 @@ def _ssl_socket(sock, user_sslopt, hostname):
certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE') certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE')
if certPath and os.path.isfile(certPath) \ if certPath and os.path.isfile(certPath) \
and user_sslopt.get('ca_certs', None) is None \ and user_sslopt.get('ca_certs', None) is None:
and user_sslopt.get('ca_cert', None) is None:
sslopt['ca_certs'] = certPath sslopt['ca_certs'] = certPath
elif certPath and os.path.isdir(certPath) \ elif certPath and os.path.isdir(certPath) \
and user_sslopt.get('ca_cert_path', None) is None: and user_sslopt.get('ca_cert_path', None) is None:
@ -258,12 +250,7 @@ def _ssl_socket(sock, user_sslopt, hostname):
check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop( check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop(
'check_hostname', True) 'check_hostname', True)
if _can_use_sni():
sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname) sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname)
else:
sslopt.pop('check_hostname', True)
sock = ssl.wrap_socket(sock, **sslopt)
if not HAVE_CONTEXT_CHECK_HOSTNAME and check_hostname: if not HAVE_CONTEXT_CHECK_HOSTNAME and check_hostname:
match_hostname(sock.getpeercert(), hostname) match_hostname(sock.getpeercert(), hostname)
@ -273,7 +260,9 @@ def _ssl_socket(sock, user_sslopt, hostname):
def _tunnel(sock, host, port, auth): def _tunnel(sock, host, port, auth):
debug("Connecting proxy...") debug("Connecting proxy...")
connect_header = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port) connect_header = "CONNECT %s:%d HTTP/1.1\r\n" % (host, port)
connect_header += "Host: %s:%d\r\n" % (host, port)
# TODO: support digest auth. # TODO: support digest auth.
if auth and auth[0]: if auth and auth[0]:
auth_str = auth[0] auth_str = auth[0]
@ -320,6 +309,9 @@ def read_headers(sock):
kv = line.split(":", 1) kv = line.split(":", 1)
if len(kv) == 2: if len(kv) == 2:
key, value = kv key, value = kv
if key.lower() == "set-cookie" and headers.get("set-cookie"):
headers["set-cookie"] = headers.get("set-cookie") + "; " + value.strip()
else:
headers[key.lower()] = value.strip() headers[key.lower()] = value.strip()
else: else:
raise WebSocketException("Invalid header") raise WebSocketException("Invalid header")

View file

@ -55,6 +55,7 @@ def enableTrace(traceable, handler = logging.StreamHandler()):
_logger.addHandler(handler) _logger.addHandler(handler)
_logger.setLevel(logging.DEBUG) _logger.setLevel(logging.DEBUG)
def dump(title, message): def dump(title, message):
if _traceEnabled: if _traceEnabled:
_logger.debug("--- " + title + " ---") _logger.debug("--- " + title + " ---")
@ -86,5 +87,6 @@ def isEnabledForError():
def isEnabledForDebug(): def isEnabledForDebug():
return _logger.isEnabledFor(logging.DEBUG) return _logger.isEnabledFor(logging.DEBUG)
def isEnabledForTrace(): def isEnabledForTrace():
return _traceEnabled return _traceEnabled

View file

@ -23,12 +23,9 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
""" """
import errno import errno
import select import selectors
import socket import socket
import six
import sys
from ._exceptions import * from ._exceptions import *
from ._ssl_compat import * from ._ssl_compat import *
from ._utils import * from ._utils import *
@ -102,7 +99,12 @@ def recv(sock, bufsize):
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
raise raise
r, w, e = select.select((sock, ), (), (), sock.gettimeout()) sel = selectors.DefaultSelector()
sel.register(sock, selectors.EVENT_READ)
r = sel.select(sock.gettimeout())
sel.close()
if r: if r:
return sock.recv(bufsize) return sock.recv(bufsize)
@ -133,13 +135,13 @@ def recv_line(sock):
while True: while True:
c = recv(sock, 1) c = recv(sock, 1)
line.append(c) line.append(c)
if c == six.b("\n"): if c == b'\n':
break break
return six.b("").join(line) return b''.join(line)
def send(sock, data): def send(sock, data):
if isinstance(data, six.text_type): if isinstance(data, str):
data = data.encode('utf-8') data = data.encode('utf-8')
if not sock: if not sock:
@ -157,7 +159,12 @@ def send(sock, data):
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
raise raise
r, w, e = select.select((), (sock, ), (), sock.gettimeout()) sel = selectors.DefaultSelector()
sel.register(sock, selectors.EVENT_WRITE)
w = sel.select(sock.gettimeout())
sel.close()
if w: if w:
return sock.send(data) return sock.send(data)

View file

@ -25,20 +25,14 @@ try:
from ssl import SSLError from ssl import SSLError
from ssl import SSLWantReadError from ssl import SSLWantReadError
from ssl import SSLWantWriteError from ssl import SSLWantWriteError
HAVE_CONTEXT_CHECK_HOSTNAME = False
if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'): if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'):
HAVE_CONTEXT_CHECK_HOSTNAME = True HAVE_CONTEXT_CHECK_HOSTNAME = True
else:
HAVE_CONTEXT_CHECK_HOSTNAME = False
if hasattr(ssl, "match_hostname"):
from ssl import match_hostname
else:
from backports.ssl_match_hostname import match_hostname
__all__.append("match_hostname")
__all__.append("HAVE_CONTEXT_CHECK_HOSTNAME")
__all__.append("HAVE_CONTEXT_CHECK_HOSTNAME")
HAVE_SSL = True HAVE_SSL = True
except ImportError: except ImportError:
# dummy class of SSLError for ssl none-support environment. # dummy class of SSLError for environment without ssl support
class SSLError(Exception): class SSLError(Exception):
pass pass
@ -48,6 +42,5 @@ except ImportError:
class SSLWantWriteError(Exception): class SSLWantWriteError(Exception):
pass pass
ssl = lambda: None ssl = None
HAVE_SSL = False HAVE_SSL = False

View file

@ -26,7 +26,7 @@ import os
import socket import socket
import struct import struct
from six.moves.urllib.parse import urlparse from urllib.parse import urlparse
__all__ = ["parse_url", "get_proxy_info"] __all__ = ["parse_url", "get_proxy_info"]
@ -47,7 +47,7 @@ def parse_url(url):
scheme, url = url.split(":", 1) scheme, url = url.split(":", 1)
parsed = urlparse(url, scheme="ws") parsed = urlparse(url, scheme="http")
if parsed.hostname: if parsed.hostname:
hostname = parsed.hostname hostname = parsed.hostname
else: else:
@ -99,10 +99,12 @@ def _is_subnet_address(hostname):
def _is_address_in_network(ip, net): def _is_address_in_network(ip, net):
ipaddr = struct.unpack('I', socket.inet_aton(ip))[0] ipaddr = struct.unpack('!I', socket.inet_aton(ip))[0]
netaddr, bits = net.split('/') netaddr, netmask = net.split('/')
netmask = struct.unpack('I', socket.inet_aton(netaddr))[0] & ((2 << int(bits) - 1) - 1) netaddr = struct.unpack('!I', socket.inet_aton(netaddr))[0]
return ipaddr & netmask == netmask
netmask = (0xFFFFFFFF << (32 - int(netmask))) & 0xFFFFFFFF
return ipaddr & netmask == netaddr
def _is_no_proxy_host(hostname, no_proxy): def _is_no_proxy_host(hostname, no_proxy):
@ -113,11 +115,15 @@ def _is_no_proxy_host(hostname, no_proxy):
if not no_proxy: if not no_proxy:
no_proxy = DEFAULT_NO_PROXY_HOST no_proxy = DEFAULT_NO_PROXY_HOST
if '*' in no_proxy:
return True
if hostname in no_proxy: if hostname in no_proxy:
return True return True
elif _is_ip_address(hostname): if _is_ip_address(hostname):
return any([_is_address_in_network(hostname, subnet) for subnet in no_proxy if _is_subnet_address(subnet)]) return any([_is_address_in_network(hostname, subnet) for subnet in no_proxy if _is_subnet_address(subnet)])
for domain in [domain for domain in no_proxy if domain.startswith('.')]:
if hostname.endswith(domain):
return True
return False return False

View file

@ -18,8 +18,6 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
""" """
import six
__all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"] __all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"]
@ -80,8 +78,6 @@ except ImportError:
state = _UTF8_ACCEPT state = _UTF8_ACCEPT
codep = 0 codep = 0
for i in utfbytes: for i in utfbytes:
if six.PY2:
i = ord(i)
state, codep = _decode(state, codep, i) state, codep = _decode(state, codep, i)
if state == _UTF8_REJECT: if state == _UTF8_REJECT:
return False return False

View file

@ -0,0 +1,7 @@
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade, Keep-Alive
Upgrade: WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
Set-Cookie: Token=ABCDE
some_header: something

View file

@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
#
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import os
import websocket as ws
from websocket._abnf import *
import sys
import unittest
sys.path[0:0] = [""]
class ABNFTest(unittest.TestCase):
def testInit(self):
a = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING)
self.assertEqual(a.fin, 0)
self.assertEqual(a.rsv1, 0)
self.assertEqual(a.rsv2, 0)
self.assertEqual(a.rsv3, 0)
self.assertEqual(a.opcode, 9)
self.assertEqual(a.data, '')
a_bad = ABNF(0,1,0,0, opcode=77)
self.assertEqual(a_bad.rsv1, 1)
self.assertEqual(a_bad.opcode, 77)
def testValidate(self):
a_invalid_ping = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING)
self.assertRaises(ws._exceptions.WebSocketProtocolException, a_invalid_ping.validate, skip_utf8_validation=False)
a_bad_rsv_value = ABNF(0,1,0,0, opcode=ABNF.OPCODE_TEXT)
self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_rsv_value.validate, skip_utf8_validation=False)
a_bad_opcode = ABNF(0,0,0,0, opcode=77)
self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_opcode.validate, skip_utf8_validation=False)
a_bad_close_frame = ABNF(0,0,0,0, opcode=ABNF.OPCODE_CLOSE, data=b'\x01')
self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_close_frame.validate, skip_utf8_validation=False)
a_bad_close_frame_2 = ABNF(0,0,0,0, opcode=ABNF.OPCODE_CLOSE, data=b'\x01\x8a\xaa\xff\xdd')
self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_close_frame_2.validate, skip_utf8_validation=False)
a_bad_close_frame_3 = ABNF(0,0,0,0, opcode=ABNF.OPCODE_CLOSE, data=b'\x03\xe7')
self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_close_frame_3.validate, skip_utf8_validation=True)
def testMask(self):
abnf_none_data = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING, mask=1, data=None)
bytes_val = bytes("aaaa", 'utf-8')
self.assertEqual(abnf_none_data._get_masked(bytes_val), bytes_val)
abnf_str_data = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING, mask=1, data="a")
self.assertEqual(abnf_str_data._get_masked(bytes_val), b'aaaa\x00')
def testFormat(self):
abnf_bad_rsv_bits = ABNF(2,0,0,0, opcode=ABNF.OPCODE_TEXT)
self.assertRaises(ValueError, abnf_bad_rsv_bits.format)
abnf_bad_opcode = ABNF(0,0,0,0, opcode=5)
self.assertRaises(ValueError, abnf_bad_opcode.format)
abnf_length_10 = ABNF(0,0,0,0, opcode=ABNF.OPCODE_TEXT, data="abcdefghij")
self.assertEqual(b'\x01', abnf_length_10.format()[0].to_bytes(1, 'big'))
self.assertEqual(b'\x8a', abnf_length_10.format()[1].to_bytes(1, 'big'))
self.assertEqual("fin=0 opcode=1 data=abcdefghij", abnf_length_10.__str__())
abnf_length_20 = ABNF(0,0,0,0, opcode=ABNF.OPCODE_BINARY, data="abcdefghijabcdefghij")
self.assertEqual(b'\x02', abnf_length_20.format()[0].to_bytes(1, 'big'))
self.assertEqual(b'\x94', abnf_length_20.format()[1].to_bytes(1, 'big'))
abnf_no_mask = ABNF(0,0,0,0, opcode=ABNF.OPCODE_TEXT, mask=0, data=b'\x01\x8a\xcc')
self.assertEqual(b'\x01\x03\x01\x8a\xcc', abnf_no_mask.format())
def testFrameBuffer(self):
fb = frame_buffer(0, True)
self.assertEqual(fb.recv, 0)
self.assertEqual(fb.skip_utf8_validation, True)
fb.clear
self.assertEqual(fb.header, None)
self.assertEqual(fb.length, None)
self.assertEqual(fb.mask, None)
self.assertEqual(fb.has_mask(), False)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,176 @@
# -*- coding: utf-8 -*-
#
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import os
import os.path
import websocket as ws
import sys
import ssl
import unittest
sys.path[0:0] = [""]
# Skip test to access the internet.
TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1'
TRACEABLE = True
class WebSocketAppTest(unittest.TestCase):
class NotSetYet(object):
""" A marker class for signalling that a value hasn't been set yet.
"""
def setUp(self):
ws.enableTrace(TRACEABLE)
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
def tearDown(self):
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testKeepRunning(self):
""" A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context).
"""
def on_open(self, *args, **kwargs):
""" Set the keep_running flag for later inspection and immediately
close the connection.
"""
WebSocketAppTest.keep_running_open = self.keep_running
self.close()
def on_close(self, *args, **kwargs):
""" Set the keep_running flag for the test to use.
"""
WebSocketAppTest.keep_running_close = self.keep_running
self.send("connection should be closed here")
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close)
app.run_forever()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSockMaskKey(self):
""" A WebSocketApp should forward the received mask_key function down
to the actual socket.
"""
def my_mask_key_func():
return "\x00\x00\x00\x00"
app = ws.WebSocketApp('wss://stream.meetup.com/2/rsvps', get_mask_key=my_mask_key_func)
# if numpy is installed, this assertion fail
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
self.assertEqual(id(app.get_mask_key), id(my_mask_key_func))
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testInvalidPingIntervalPingTimeout(self):
""" Test exception handling if ping_interval < ping_timeout
"""
def on_ping(app, msg):
print("Got a ping!")
app.close()
def on_pong(app, msg):
print("Got a pong! No need to respond")
app.close()
app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1', on_ping=on_ping, on_pong=on_pong)
self.assertRaises(ws.WebSocketException, app.run_forever, ping_interval=1, ping_timeout=2, sslopt={"cert_reqs": ssl.CERT_NONE})
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testPingInterval(self):
""" Test WebSocketApp proper ping functionality
"""
def on_ping(app, msg):
print("Got a ping!")
app.close()
def on_pong(app, msg):
print("Got a pong! No need to respond")
app.close()
app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1', on_ping=on_ping, on_pong=on_pong)
app.run_forever(ping_interval=2, ping_timeout=1, sslopt={"cert_reqs": ssl.CERT_NONE})
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testOpcodeClose(self):
""" Test WebSocketApp close opcode
"""
app = ws.WebSocketApp('wss://tsock.us1.twilio.com/v3/wsconnect')
app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testOpcodeBinary(self):
""" Test WebSocketApp binary opcode
"""
app = ws.WebSocketApp('streaming.vn.teslamotors.com/streaming/')
app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testBadPingInterval(self):
""" A WebSocketApp handling of negative ping_interval
"""
app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1')
self.assertRaises(ws.WebSocketException, app.run_forever, ping_interval=-5, sslopt={"cert_reqs": ssl.CERT_NONE})
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testBadPingTimeout(self):
""" A WebSocketApp handling of negative ping_timeout
"""
app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1')
self.assertRaises(ws.WebSocketException, app.run_forever, ping_timeout=-3, sslopt={"cert_reqs": ssl.CERT_NONE})
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testCloseStatusCode(self):
""" Test extraction of close frame status code and close reason in WebSocketApp
"""
def on_close(wsapp, close_status_code, close_msg):
print("on_close reached")
app = ws.WebSocketApp('wss://tsock.us1.twilio.com/v3/wsconnect', on_close=on_close)
closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b'\x03\xe8no-init-from-client')
self.assertEqual([1000, 'no-init-from-client'], app._get_close_args(closeframe))
closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b'')
self.assertEqual([None, None], app._get_close_args(closeframe))
app2 = ws.WebSocketApp('wss://tsock.us1.twilio.com/v3/wsconnect')
closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b'')
self.assertEqual([None, None], app2._get_close_args(closeframe))
self.assertRaises(ws.WebSocketConnectionClosedException, app.send, data="test if connection is closed")
if __name__ == "__main__":
unittest.main()

View file

@ -26,11 +26,6 @@ import unittest
from websocket._cookiejar import SimpleCookieJar from websocket._cookiejar import SimpleCookieJar
try:
import Cookie
except:
import http.cookies as Cookie
class CookieJarTest(unittest.TestCase): class CookieJarTest(unittest.TestCase):
def testAdd(self): def testAdd(self):
@ -54,6 +49,7 @@ class CookieJarTest(unittest.TestCase):
cookie_jar = SimpleCookieJar() cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc") cookie_jar.add("a=b; c=d; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d") self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
self.assertEqual(cookie_jar.get(None), "")
cookie_jar = SimpleCookieJar() cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc") cookie_jar.add("a=b; c=d; domain=abc")

View file

@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
#
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import os
import os.path
import websocket as ws
from websocket._http import proxy_info, read_headers, _open_proxied_socket, _tunnel, _get_addrinfo_list, connect
import sys
import unittest
import ssl
import websocket
import socks
import socket
sys.path[0:0] = [""]
# Skip test to access the internet.
TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1'
class SockMock(object):
def __init__(self):
self.data = []
self.sent = []
def add_packet(self, data):
self.data.append(data)
def gettimeout(self):
return None
def recv(self, bufsize):
if self.data:
e = self.data.pop(0)
if isinstance(e, Exception):
raise e
if len(e) > bufsize:
self.data.insert(0, e[bufsize:])
return e[:bufsize]
def send(self, data):
self.sent.append(data)
return len(data)
def close(self):
pass
class HeaderSockMock(SockMock):
def __init__(self, fname):
SockMock.__init__(self)
path = os.path.join(os.path.dirname(__file__), fname)
with open(path, "rb") as f:
self.add_packet(f.read())
class OptsList():
def __init__(self):
self.timeout = 1
self.sockopt = []
class HttpTest(unittest.TestCase):
def testReadHeader(self):
status, header, status_message = read_headers(HeaderSockMock("data/header01.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade")
# header02.txt is intentionally malformed
self.assertRaises(ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt"))
def testTunnel(self):
self.assertRaises(ws.WebSocketProxyException, _tunnel, HeaderSockMock("data/header01.txt"), "example.com", 80, ("username", "password"))
self.assertRaises(ws.WebSocketProxyException, _tunnel, HeaderSockMock("data/header02.txt"), "example.com", 80, ("username", "password"))
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testConnect(self):
# Not currently testing an actual proxy connection, so just check whether TypeError is raised. This requires internet for a DNS lookup
self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host=None, http_proxy_port=None, proxy_type=None))
self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="http"))
self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks4"))
self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks5h"))
self.assertRaises(TypeError, _get_addrinfo_list, None, 80, True, proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"))
self.assertRaises(TypeError, _get_addrinfo_list, None, 80, True, proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"))
self.assertRaises(socks.ProxyConnectionError, connect, "wss://example.com", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port=8080, proxy_type="socks4"), None)
self.assertRaises(socket.timeout, connect, "wss://google.com", OptsList(), proxy_info(http_proxy_host="8.8.8.8", http_proxy_port=8080, proxy_type="http"), None)
self.assertEqual(
connect("wss://google.com", OptsList(), proxy_info(http_proxy_host="8.8.8.8", http_proxy_port=8080, proxy_type="http"), True),
(True, ("google.com", 443, "/")))
# The following test fails on Mac OS with a gaierror, not an OverflowError
# self.assertRaises(OverflowError, connect, "wss://example.com", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port=99999, proxy_type="socks4", timeout=2), False)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSSLopt(self):
ssloptions = {
"cert_reqs": ssl.CERT_NONE,
"check_hostname": False,
"ssl_version": ssl.PROTOCOL_SSLv23,
"ciphers": "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:\
TLS_AES_128_GCM_SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:\
ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:\
ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:\
DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:\
ECDHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES128-GCM-SHA256:\
ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:\
DHE-RSA-AES256-SHA256:ECDHE-ECDSA-AES128-SHA256:\
ECDHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA256:\
ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA",
"ecdh_curve": "prime256v1"
}
ws_ssl1 = websocket.WebSocket(sslopt=ssloptions)
ws_ssl1.connect("wss://api.bitfinex.com/ws/2")
ws_ssl1.send("Hello")
ws_ssl1.close()
ws_ssl2 = websocket.WebSocket(sslopt={"check_hostname": True})
ws_ssl2.connect("wss://api.bitfinex.com/ws/2")
ws_ssl2.close
def testProxyInfo(self):
self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http").type, "http")
self.assertRaises(ValueError, proxy_info, http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="badval")
self.assertEqual(proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="http").host, "example.com")
self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http").port, "8080")
self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http").auth, None)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,301 @@
# -*- coding: utf-8 -*-
#
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import sys
import os
import unittest
sys.path[0:0] = [""]
from websocket._url import get_proxy_info, parse_url, _is_address_in_network, _is_no_proxy_host
class UrlTest(unittest.TestCase):
def test_address_in_network(self):
self.assertTrue(_is_address_in_network('127.0.0.1', '127.0.0.0/8'))
self.assertTrue(_is_address_in_network('127.1.0.1', '127.0.0.0/8'))
self.assertFalse(_is_address_in_network('127.1.0.1', '127.0.0.0/24'))
def testParseUrl(self):
p = parse_url("ws://www.example.com/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com/r/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("wss://www.example.com:8080/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
p = parse_url("wss://www.example.com:8080/r?key=value")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r?key=value")
self.assertEqual(p[3], True)
self.assertRaises(ValueError, parse_url, "http://www.example.com/r")
p = parse_url("ws://[2a03:4000:123:83::3]/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://[2a03:4000:123:83::3]:8080/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("wss://[2a03:4000:123:83::3]/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 443)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
p = parse_url("wss://[2a03:4000:123:83::3]:8080/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
class IsNoProxyHostTest(unittest.TestCase):
def setUp(self):
self.no_proxy = os.environ.get("no_proxy", None)
if "no_proxy" in os.environ:
del os.environ["no_proxy"]
def tearDown(self):
if self.no_proxy:
os.environ["no_proxy"] = self.no_proxy
elif "no_proxy" in os.environ:
del os.environ["no_proxy"]
def testMatchAll(self):
self.assertTrue(_is_no_proxy_host("any.websocket.org", ['*']))
self.assertTrue(_is_no_proxy_host("192.168.0.1", ['*']))
self.assertTrue(_is_no_proxy_host("any.websocket.org", ['other.websocket.org', '*']))
os.environ['no_proxy'] = '*'
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
self.assertTrue(_is_no_proxy_host("192.168.0.1", None))
os.environ['no_proxy'] = 'other.websocket.org, *'
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
def testIpAddress(self):
self.assertTrue(_is_no_proxy_host("127.0.0.1", ['127.0.0.1']))
self.assertFalse(_is_no_proxy_host("127.0.0.2", ['127.0.0.1']))
self.assertTrue(_is_no_proxy_host("127.0.0.1", ['other.websocket.org', '127.0.0.1']))
self.assertFalse(_is_no_proxy_host("127.0.0.2", ['other.websocket.org', '127.0.0.1']))
os.environ['no_proxy'] = '127.0.0.1'
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertFalse(_is_no_proxy_host("127.0.0.2", None))
os.environ['no_proxy'] = 'other.websocket.org, 127.0.0.1'
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertFalse(_is_no_proxy_host("127.0.0.2", None))
def testIpAddressInRange(self):
self.assertTrue(_is_no_proxy_host("127.0.0.1", ['127.0.0.0/8']))
self.assertTrue(_is_no_proxy_host("127.0.0.2", ['127.0.0.0/8']))
self.assertFalse(_is_no_proxy_host("127.1.0.1", ['127.0.0.0/24']))
os.environ['no_proxy'] = '127.0.0.0/8'
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertTrue(_is_no_proxy_host("127.0.0.2", None))
os.environ['no_proxy'] = '127.0.0.0/24'
self.assertFalse(_is_no_proxy_host("127.1.0.1", None))
def testHostnameMatch(self):
self.assertTrue(_is_no_proxy_host("my.websocket.org", ['my.websocket.org']))
self.assertTrue(_is_no_proxy_host("my.websocket.org", ['other.websocket.org', 'my.websocket.org']))
self.assertFalse(_is_no_proxy_host("my.websocket.org", ['other.websocket.org']))
os.environ['no_proxy'] = 'my.websocket.org'
self.assertTrue(_is_no_proxy_host("my.websocket.org", None))
self.assertFalse(_is_no_proxy_host("other.websocket.org", None))
os.environ['no_proxy'] = 'other.websocket.org, my.websocket.org'
self.assertTrue(_is_no_proxy_host("my.websocket.org", None))
def testHostnameMatchDomain(self):
self.assertTrue(_is_no_proxy_host("any.websocket.org", ['.websocket.org']))
self.assertTrue(_is_no_proxy_host("my.other.websocket.org", ['.websocket.org']))
self.assertTrue(_is_no_proxy_host("any.websocket.org", ['my.websocket.org', '.websocket.org']))
self.assertFalse(_is_no_proxy_host("any.websocket.com", ['.websocket.org']))
os.environ['no_proxy'] = '.websocket.org'
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
self.assertTrue(_is_no_proxy_host("my.other.websocket.org", None))
self.assertFalse(_is_no_proxy_host("any.websocket.com", None))
os.environ['no_proxy'] = 'my.websocket.org, .websocket.org'
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
class ProxyInfoTest(unittest.TestCase):
def setUp(self):
self.http_proxy = os.environ.get("http_proxy", None)
self.https_proxy = os.environ.get("https_proxy", None)
self.no_proxy = os.environ.get("no_proxy", None)
if "http_proxy" in os.environ:
del os.environ["http_proxy"]
if "https_proxy" in os.environ:
del os.environ["https_proxy"]
if "no_proxy" in os.environ:
del os.environ["no_proxy"]
def tearDown(self):
if self.http_proxy:
os.environ["http_proxy"] = self.http_proxy
elif "http_proxy" in os.environ:
del os.environ["http_proxy"]
if self.https_proxy:
os.environ["https_proxy"] = self.https_proxy
elif "https_proxy" in os.environ:
del os.environ["https_proxy"]
if self.no_proxy:
os.environ["no_proxy"] = self.no_proxy
elif "no_proxy" in os.environ:
del os.environ["no_proxy"]
def testProxyFromArgs(self):
self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost"), ("localhost", 0, None))
self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128),
("localhost", 3128, None))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost"), ("localhost", 0, None))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128),
("localhost", 3128, None))
self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_auth=("a", "b")),
("localhost", 0, ("a", "b")))
self.assertEqual(
get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128, proxy_auth=("a", "b")),
("localhost", 3128, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_auth=("a", "b")),
("localhost", 0, ("a", "b")))
self.assertEqual(
get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, proxy_auth=("a", "b")),
("localhost", 3128, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128,
no_proxy=["example.com"], proxy_auth=("a", "b")),
("localhost", 3128, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128,
no_proxy=["echo.websocket.org"], proxy_auth=("a", "b")),
(None, 0, None))
def testProxyFromEnv(self):
os.environ["http_proxy"] = "http://localhost/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, None))
os.environ["http_proxy"] = "http://localhost:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, None))
os.environ["http_proxy"] = "http://localhost/"
os.environ["https_proxy"] = "http://localhost2/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, None))
os.environ["http_proxy"] = "http://localhost:3128/"
os.environ["https_proxy"] = "http://localhost2:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, None))
os.environ["http_proxy"] = "http://localhost/"
os.environ["https_proxy"] = "http://localhost2/"
self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", None, None))
os.environ["http_proxy"] = "http://localhost:3128/"
os.environ["https_proxy"] = "http://localhost2:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", 3128, None))
os.environ["http_proxy"] = "http://a:b@localhost/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", None, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", 3128, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
os.environ["no_proxy"] = "example1.com,example2.com"
self.assertEqual(get_proxy_info("example.1.com", True), ("localhost2", None, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "example1.com,example2.com, echo.websocket.org"
self.assertEqual(get_proxy_info("echo.websocket.org", True), (None, 0, None))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "example1.com,example2.com, .websocket.org"
self.assertEqual(get_proxy_info("echo.websocket.org", True), (None, 0, None))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "127.0.0.0/8, 192.168.0.0/16"
self.assertEqual(get_proxy_info("127.0.0.1", False), (None, 0, None))
self.assertEqual(get_proxy_info("192.168.1.1", False), (None, 0, None))
if __name__ == "__main__":
unittest.main()

View file

@ -27,32 +27,20 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
import os import os
import os.path import os.path
import socket import socket
import six
# websocket-client
import websocket as ws import websocket as ws
from websocket._handshake import _create_sec_websocket_key, \ from websocket._handshake import _create_sec_websocket_key, \
_validate as _validate_header _validate as _validate_header
from websocket._http import read_headers from websocket._http import read_headers
from websocket._url import get_proxy_info, parse_url
from websocket._utils import validate_utf8 from websocket._utils import validate_utf8
if six.PY3:
from base64 import decodebytes as base64decode from base64 import decodebytes as base64decode
else:
from base64 import decodestring as base64decode
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest import unittest
try: try:
import ssl
from ssl import SSLError from ssl import SSLError
except ImportError: except ImportError:
# dummy class of SSLError for ssl none-support environment. # dummy class of SSLError for ssl none-support environment.
@ -61,9 +49,6 @@ except ImportError:
# Skip test to access the internet. # Skip test to access the internet.
TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1' TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1'
# Skip Secure WebSocket test.
TEST_SECURE_WS = True
TRACEABLE = True TRACEABLE = True
@ -121,102 +106,24 @@ class WebSocketTest(unittest.TestCase):
self.assertEqual(ws.getdefaulttimeout(), 10) self.assertEqual(ws.getdefaulttimeout(), 10)
ws.setdefaulttimeout(None) ws.setdefaulttimeout(None)
def testParseUrl(self):
p = parse_url("ws://www.example.com/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com/r/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("wss://www.example.com:8080/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
p = parse_url("wss://www.example.com:8080/r?key=value")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r?key=value")
self.assertEqual(p[3], True)
self.assertRaises(ValueError, parse_url, "http://www.example.com/r")
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
return
p = parse_url("ws://[2a03:4000:123:83::3]/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://[2a03:4000:123:83::3]:8080/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("wss://[2a03:4000:123:83::3]/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 443)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
p = parse_url("wss://[2a03:4000:123:83::3]:8080/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
def testWSKey(self): def testWSKey(self):
key = _create_sec_websocket_key() key = _create_sec_websocket_key()
self.assertTrue(key != 24) self.assertTrue(key != 24)
self.assertTrue(six.u("¥n") not in key) self.assertTrue(str("¥n") not in key)
def testNonce(self):
""" WebSocket key should be a random 16-byte nonce.
"""
key = _create_sec_websocket_key()
nonce = base64decode(key.encode("utf-8"))
self.assertEqual(16, len(nonce))
def testWsUtils(self): def testWsUtils(self):
key = "c6b8hTg4EeGb2gQMztV1/g==" key = "c6b8hTg4EeGb2gQMztV1/g=="
required_header = { required_header = {
"upgrade": "websocket", "upgrade": "websocket",
"connection": "upgrade", "connection": "upgrade",
"sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=", "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0="}
}
self.assertEqual(_validate_header(required_header, key, None), (True, None)) self.assertEqual(_validate_header(required_header, key, None), (True, None))
header = required_header.copy() header = required_header.copy()
@ -240,6 +147,7 @@ class WebSocketTest(unittest.TestCase):
header = required_header.copy() header = required_header.copy()
header["sec-websocket-protocol"] = "sub1" header["sec-websocket-protocol"] = "sub1"
self.assertEqual(_validate_header(header, key, ["sub1", "sub2"]), (True, "sub1")) self.assertEqual(_validate_header(header, key, ["sub1", "sub2"]), (True, "sub1"))
# This case will print out a logging error using the error() function, but that is expected
self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None)) self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None))
header = required_header.copy() header = required_header.copy()
@ -247,6 +155,7 @@ class WebSocketTest(unittest.TestCase):
self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1")) self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1"))
header = required_header.copy() header = required_header.copy()
# This case will print out a logging error using the error() function, but that is expected
self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None)) self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None))
def testReadHeader(self): def testReadHeader(self):
@ -254,6 +163,10 @@ class WebSocketTest(unittest.TestCase):
self.assertEqual(status, 101) self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade") self.assertEqual(header["connection"], "Upgrade")
status, header, status_message = read_headers(HeaderSockMock("data/header03.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade, Keep-Alive")
HeaderSockMock("data/header02.txt") HeaderSockMock("data/header02.txt")
self.assertRaises(ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt")) self.assertRaises(ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt"))
@ -263,26 +176,26 @@ class WebSocketTest(unittest.TestCase):
sock.set_mask_key(create_mask_key) sock.set_mask_key(create_mask_key)
s = sock.sock = HeaderSockMock("data/header01.txt") s = sock.sock = HeaderSockMock("data/header01.txt")
sock.send("Hello") sock.send("Hello")
self.assertEqual(s.sent[0], six.b("\x81\x85abcd)\x07\x0f\x08\x0e")) self.assertEqual(s.sent[0], b'\x81\x85abcd)\x07\x0f\x08\x0e')
sock.send("こんにちは") sock.send("こんにちは")
self.assertEqual(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")) self.assertEqual(s.sent[1], b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc')
sock.send(u"こんにちは") # sock.send("x" * 5000)
self.assertEqual(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")) # self.assertEqual(s.sent[1], b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
sock.send("x" * 127) self.assertEqual(sock.send_binary(b'1111111111101'), 19)
def testRecv(self): def testRecv(self):
# TODO: add longer frame data # TODO: add longer frame data
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
something = six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc") something = b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc'
s.add_packet(something) s.add_packet(something)
data = sock.recv() data = sock.recv()
self.assertEqual(data, "こんにちは") self.assertEqual(data, "こんにちは")
s.add_packet(six.b("\x81\x85abcd)\x07\x0f\x08\x0e")) s.add_packet(b'\x81\x85abcd)\x07\x0f\x08\x0e')
data = sock.recv() data = sock.recv()
self.assertEqual(data, "Hello") self.assertEqual(data, "Hello")
@ -302,32 +215,28 @@ class WebSocketTest(unittest.TestCase):
def testInternalRecvStrict(self): def testInternalRecvStrict(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
s.add_packet(six.b("foo")) s.add_packet(b'foo')
s.add_packet(socket.timeout()) s.add_packet(socket.timeout())
s.add_packet(six.b("bar")) s.add_packet(b'bar')
# s.add_packet(SSLError("The read operation timed out")) # s.add_packet(SSLError("The read operation timed out"))
s.add_packet(six.b("baz")) s.add_packet(b'baz')
with self.assertRaises(ws.WebSocketTimeoutException): with self.assertRaises(ws.WebSocketTimeoutException):
sock.frame_buffer.recv_strict(9) sock.frame_buffer.recv_strict(9)
# if six.PY2:
# with self.assertRaises(ws.WebSocketTimeoutException):
# data = sock._recv_strict(9)
# else:
# with self.assertRaises(SSLError): # with self.assertRaises(SSLError):
# data = sock._recv_strict(9) # data = sock._recv_strict(9)
data = sock.frame_buffer.recv_strict(9) data = sock.frame_buffer.recv_strict(9)
self.assertEqual(data, six.b("foobarbaz")) self.assertEqual(data, b'foobarbaz')
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.frame_buffer.recv_strict(1) sock.frame_buffer.recv_strict(1)
def testRecvTimeout(self): def testRecvTimeout(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
s.add_packet(six.b("\x81")) s.add_packet(b'\x81')
s.add_packet(socket.timeout()) s.add_packet(socket.timeout())
s.add_packet(six.b("\x8dabcd\x29\x07\x0f\x08\x0e")) s.add_packet(b'\x8dabcd\x29\x07\x0f\x08\x0e')
s.add_packet(socket.timeout()) s.add_packet(socket.timeout())
s.add_packet(six.b("\x4e\x43\x33\x0e\x10\x0f\x00\x40")) s.add_packet(b'\x4e\x43\x33\x0e\x10\x0f\x00\x40')
with self.assertRaises(ws.WebSocketTimeoutException): with self.assertRaises(ws.WebSocketTimeoutException):
sock.recv() sock.recv()
with self.assertRaises(ws.WebSocketTimeoutException): with self.assertRaises(ws.WebSocketTimeoutException):
@ -341,9 +250,9 @@ class WebSocketTest(unittest.TestCase):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is " # OPCODE=TEXT, FIN=0, MSG="Brevity is "
s.add_packet(six.b("\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")) s.add_packet(b'\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
# OPCODE=CONT, FIN=1, MSG="the soul of wit" # OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")) s.add_packet(b'\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17')
data = sock.recv() data = sock.recv()
self.assertEqual(data, "Brevity is the soul of wit") self.assertEqual(data, "Brevity is the soul of wit")
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
@ -353,21 +262,21 @@ class WebSocketTest(unittest.TestCase):
sock = ws.WebSocket(fire_cont_frame=True) sock = ws.WebSocket(fire_cont_frame=True)
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is " # OPCODE=TEXT, FIN=0, MSG="Brevity is "
s.add_packet(six.b("\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")) s.add_packet(b'\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
# OPCODE=CONT, FIN=0, MSG="Brevity is " # OPCODE=CONT, FIN=0, MSG="Brevity is "
s.add_packet(six.b("\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")) s.add_packet(b'\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
# OPCODE=CONT, FIN=1, MSG="the soul of wit" # OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")) s.add_packet(b'\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17')
_, data = sock.recv_data() _, data = sock.recv_data()
self.assertEqual(data, six.b("Brevity is ")) self.assertEqual(data, b'Brevity is ')
_, data = sock.recv_data() _, data = sock.recv_data()
self.assertEqual(data, six.b("Brevity is ")) self.assertEqual(data, b'Brevity is ')
_, data = sock.recv_data() _, data = sock.recv_data()
self.assertEqual(data, six.b("the soul of wit")) self.assertEqual(data, b'the soul of wit')
# OPCODE=CONT, FIN=0, MSG="Brevity is " # OPCODE=CONT, FIN=0, MSG="Brevity is "
s.add_packet(six.b("\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")) s.add_packet(b'\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
with self.assertRaises(ws.WebSocketException): with self.assertRaises(ws.WebSocketException):
sock.recv_data() sock.recv_data()
@ -377,15 +286,13 @@ class WebSocketTest(unittest.TestCase):
def testClose(self): def testClose(self):
sock = ws.WebSocket() sock = ws.WebSocket()
sock.sock = SockMock()
sock.connected = True sock.connected = True
sock.close() self.assertRaises(ws._exceptions.WebSocketConnectionClosedException, sock.close)
self.assertEqual(sock.connected, False)
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
sock.connected = True sock.connected = True
s.add_packet(six.b('\x88\x80\x17\x98p\x84')) s.add_packet(b'\x88\x80\x17\x98p\x84')
sock.recv() sock.recv()
self.assertEqual(sock.connected, False) self.assertEqual(sock.connected, False)
@ -393,20 +300,18 @@ class WebSocketTest(unittest.TestCase):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=CONT, FIN=1, MSG="the soul of wit" # OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")) s.add_packet(b'\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17')
self.assertRaises(ws.WebSocketException, sock.recv) self.assertRaises(ws.WebSocketException, sock.recv)
def testRecvWithProlongedFragmentation(self): def testRecvWithProlongedFragmentation(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, " # OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15" s.add_packet(b'\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC')
"\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC"))
# OPCODE=CONT, FIN=0, MSG="dear friends, " # OPCODE=CONT, FIN=0, MSG="dear friends, "
s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07" s.add_packet(b'\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07\x17MB')
"\x17MB"))
# OPCODE=CONT, FIN=1, MSG="once more" # OPCODE=CONT, FIN=1, MSG="once more"
s.add_packet(six.b("\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04")) s.add_packet(b'\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04')
data = sock.recv() data = sock.recv()
self.assertEqual( self.assertEqual(
data, data,
@ -419,19 +324,18 @@ class WebSocketTest(unittest.TestCase):
sock.set_mask_key(create_mask_key) sock.set_mask_key(create_mask_key)
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Too much " # OPCODE=TEXT, FIN=0, MSG="Too much "
s.add_packet(six.b("\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA")) s.add_packet(b'\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA')
# OPCODE=PING, FIN=1, MSG="Please PONG this" # OPCODE=PING, FIN=1, MSG="Please PONG this"
s.add_packet(six.b("\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")) s.add_packet(b'\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17')
# OPCODE=CONT, FIN=1, MSG="of a good thing" # OPCODE=CONT, FIN=1, MSG="of a good thing"
s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c" s.add_packet(b'\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c\x08\x0c\x04')
"\x08\x0c\x04"))
data = sock.recv() data = sock.recv()
self.assertEqual(data, "Too much of a good thing") self.assertEqual(data, "Too much of a good thing")
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv() sock.recv()
self.assertEqual( self.assertEqual(
s.sent[0], s.sent[0],
six.b("\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")) b'\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17')
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testWebSocket(self): def testWebSocket(self):
@ -441,9 +345,10 @@ class WebSocketTest(unittest.TestCase):
result = s.recv() result = s.recv()
self.assertEqual(result, "Hello, World") self.assertEqual(result, "Hello, World")
s.send(u"こにゃにゃちは、世界") s.send("こにゃにゃちは、世界")
result = s.recv() result = s.recv()
self.assertEqual(result, "こにゃにゃちは、世界") self.assertEqual(result, "こにゃにゃちは、世界")
self.assertRaises(ValueError, s.send_close, -1, "")
s.close() s.close()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
@ -455,22 +360,17 @@ class WebSocketTest(unittest.TestCase):
s.close() s.close()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
@unittest.skipUnless(TEST_SECURE_WS, "wss://echo.websocket.org doesn't work well.")
def testSecureWebSocket(self): def testSecureWebSocket(self):
if 1:
import ssl import ssl
s = ws.create_connection("wss://echo.websocket.org/") s = ws.create_connection("wss://api.bitfinex.com/ws/2")
self.assertNotEqual(s, None) self.assertNotEqual(s, None)
self.assertTrue(isinstance(s.sock, ssl.SSLSocket)) self.assertTrue(isinstance(s.sock, ssl.SSLSocket))
s.send("Hello, World") self.assertEqual(s.getstatus(), 101)
result = s.recv() self.assertNotEqual(s.getheaders(), None)
self.assertEqual(result, "Hello, World") s.settimeout(10)
s.send(u"こにゃにゃちは、世界") self.assertEqual(s.gettimeout(), 10)
result = s.recv() self.assertEqual(s.getsubprotocol(), None)
self.assertEqual(result, "こにゃにゃちは、世界") s.abort()
s.close()
#except:
# pass
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testWebSocketWithCustomHeader(self): def testWebSocketWithCustomHeader(self):
@ -480,6 +380,7 @@ class WebSocketTest(unittest.TestCase):
s.send("Hello, World") s.send("Hello, World")
result = s.recv() result = s.recv()
self.assertEqual(result, "Hello, World") self.assertEqual(result, "Hello, World")
self.assertRaises(ValueError, s.close, -1, "")
s.close() s.close()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
@ -490,87 +391,6 @@ class WebSocketTest(unittest.TestCase):
self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello") self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello")
self.assertRaises(ws.WebSocketConnectionClosedException, s.recv) self.assertRaises(ws.WebSocketConnectionClosedException, s.recv)
def testNonce(self):
""" WebSocket key should be a random 16-byte nonce.
"""
key = _create_sec_websocket_key()
nonce = base64decode(key.encode("utf-8"))
self.assertEqual(16, len(nonce))
class WebSocketAppTest(unittest.TestCase):
class NotSetYet(object):
""" A marker class for signalling that a value hasn't been set yet.
"""
def setUp(self):
ws.enableTrace(TRACEABLE)
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
def tearDown(self):
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testKeepRunning(self):
""" A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context).
"""
def on_open(self, *args, **kwargs):
""" Set the keep_running flag for later inspection and immediately
close the connection.
"""
WebSocketAppTest.keep_running_open = self.keep_running
self.close()
def on_close(self, *args, **kwargs):
""" Set the keep_running flag for the test to use.
"""
WebSocketAppTest.keep_running_close = self.keep_running
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close)
app.run_forever()
# if numpy is installed, this assertion fail
# self.assertFalse(isinstance(WebSocketAppTest.keep_running_open,
# WebSocketAppTest.NotSetYet))
# self.assertFalse(isinstance(WebSocketAppTest.keep_running_close,
# WebSocketAppTest.NotSetYet))
# self.assertEqual(True, WebSocketAppTest.keep_running_open)
# self.assertEqual(False, WebSocketAppTest.keep_running_close)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSockMaskKey(self):
""" A WebSocketApp should forward the received mask_key function down
to the actual socket.
"""
def my_mask_key_func():
pass
def on_open(self, *args, **kwargs):
""" Set the value so the test can use it later on and immediately
close the connection.
"""
WebSocketAppTest.get_mask_key_id = id(self.get_mask_key)
self.close()
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, get_mask_key=my_mask_key_func)
app.run_forever()
# if numpu is installed, this assertion fail
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
# self.assertEqual(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func))
class SockOptTest(unittest.TestCase): class SockOptTest(unittest.TestCase):
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
@ -583,108 +403,49 @@ class SockOptTest(unittest.TestCase):
class UtilsTest(unittest.TestCase): class UtilsTest(unittest.TestCase):
def testUtf8Validator(self): def testUtf8Validator(self):
state = validate_utf8(six.b('\xf0\x90\x80\x80')) state = validate_utf8(b'\xf0\x90\x80\x80')
self.assertEqual(state, True) self.assertEqual(state, True)
state = validate_utf8(six.b('\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited')) state = validate_utf8(b'\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited')
self.assertEqual(state, False) self.assertEqual(state, False)
state = validate_utf8(six.b('')) state = validate_utf8(b'')
self.assertEqual(state, True) self.assertEqual(state, True)
class ProxyInfoTest(unittest.TestCase): class HandshakeTest(unittest.TestCase):
def setUp(self): @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
self.http_proxy = os.environ.get("http_proxy", None) def test_http_SSL(self):
self.https_proxy = os.environ.get("https_proxy", None) websock1 = ws.WebSocket(sslopt={"cert_chain": ssl.get_default_verify_paths().capath})
if "http_proxy" in os.environ: self.assertRaises(ValueError,
del os.environ["http_proxy"] websock1.connect, "wss://api.bitfinex.com/ws/2")
if "https_proxy" in os.environ: websock2 = ws.WebSocket(sslopt={"certfile": "myNonexistentCertFile"})
del os.environ["https_proxy"] self.assertRaises(FileNotFoundError,
websock2.connect, "wss://api.bitfinex.com/ws/2")
def tearDown(self): @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
if self.http_proxy: def testManualHeaders(self):
os.environ["http_proxy"] = self.http_proxy websock3 = ws.WebSocket(sslopt={"cert_reqs": ssl.CERT_NONE,
elif "http_proxy" in os.environ: "ca_certs": ssl.get_default_verify_paths().capath,
del os.environ["http_proxy"] "ca_cert_path": ssl.get_default_verify_paths().openssl_cafile})
self.assertRaises(ws._exceptions.WebSocketBadStatusException,
websock3.connect, "wss://api.bitfinex.com/ws/2", cookie="chocolate",
origin="testing_websockets.com",
host="echo.websocket.org/websocket-client-test",
subprotocols=["testproto"],
connection="Upgrade",
header={"CustomHeader1":"123",
"Cookie":"TestValue",
"Sec-WebSocket-Key":"k9kFAUWNAMmf5OEMfTlOEA==",
"Sec-WebSocket-Protocol":"newprotocol"})
if self.https_proxy: def testIPv6(self):
os.environ["https_proxy"] = self.https_proxy websock2 = ws.WebSocket()
elif "https_proxy" in os.environ: self.assertRaises(ValueError, websock2.connect, "2001:4860:4860::8888")
del os.environ["https_proxy"]
def testProxyFromArgs(self): def testBadURLs(self):
self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost"), ("localhost", 0, None)) websock3 = ws.WebSocket()
self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128), ("localhost", 3128, None)) self.assertRaises(ValueError, websock3.connect, "ws//example.com")
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost"), ("localhost", 0, None)) self.assertRaises(ws.WebSocketAddressException, websock3.connect, "ws://example")
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128), ("localhost", 3128, None)) self.assertRaises(ValueError, websock3.connect, "example.com")
self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_auth=("a", "b")),
("localhost", 0, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128, proxy_auth=("a", "b")),
("localhost", 3128, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_auth=("a", "b")),
("localhost", 0, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, proxy_auth=("a", "b")),
("localhost", 3128, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, no_proxy=["example.com"], proxy_auth=("a", "b")),
("localhost", 3128, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, no_proxy=["echo.websocket.org"], proxy_auth=("a", "b")),
(None, 0, None))
def testProxyFromEnv(self):
os.environ["http_proxy"] = "http://localhost/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, None))
os.environ["http_proxy"] = "http://localhost:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, None))
os.environ["http_proxy"] = "http://localhost/"
os.environ["https_proxy"] = "http://localhost2/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, None))
os.environ["http_proxy"] = "http://localhost:3128/"
os.environ["https_proxy"] = "http://localhost2:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, None))
os.environ["http_proxy"] = "http://localhost/"
os.environ["https_proxy"] = "http://localhost2/"
self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", None, None))
os.environ["http_proxy"] = "http://localhost:3128/"
os.environ["https_proxy"] = "http://localhost2:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", 3128, None))
os.environ["http_proxy"] = "http://a:b@localhost/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", None, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", 3128, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
os.environ["no_proxy"] = "example1.com,example2.com"
self.assertEqual(get_proxy_info("example.1.com", True), ("localhost2", None, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "example1.com,example2.com, echo.websocket.org"
self.assertEqual(get_proxy_info("echo.websocket.org", True), (None, 0, None))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "127.0.0.0/8, 192.168.0.0/16"
self.assertEqual(get_proxy_info("127.0.0.1", False), (None, 0, None))
self.assertEqual(get_proxy_info("192.168.1.1", False), (None, 0, None))
if __name__ == "__main__": if __name__ == "__main__":