# -*- test-case-name: mv3d.test.net.test_net -*-
# Copyright (C) 2006-2012 Mortal Coil Games
# See LICENSE for details.

"""
Service locations and the client-side connections used to reach them.

ServiceLoc is a URL-like pointer to a service.  The connection factories
(PBConFactory for Perspective Broker, JSONConFactory for JSON RPC over HTTP)
turn a location into an IServiceCon, whose unknown attributes become calls
on the local or remote service.
"""

from mv3d.util.classgen import getClass

try:
    import cjson
except ImportError:
    try:
        import json
    except ImportError:
        cjson = None
    else:
        class Json(object):
            """
            Stand-in exposing the cjson encode/decode API on top of the
            standard library json module.
            """

            def encode(self, data):
                return json.dumps(data)

            def decode(self, data):
                return json.loads(data)

        cjson = Json()

import random
import string
from hashlib import md5
import logging
from time import time, timezone
from urlparse import urlparse, urlunparse
from base64 import b64encode, b64decode

from zope.interface import Interface, implements #@UnresolvedImport

try:
    from Crypto.Cipher import AES
except ImportError:
    AES = None

from twisted.python.failure import Failure
from twisted.internet import reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue, _DefGen_Return
try:
    from twisted.internet.ssl import ClientContextFactory
except ImportError:
    ClientContextFactory = None
from twisted.cred import credentials
from twisted.cred.error import UnauthorizedLogin
from twisted.web.client import HTTPClientFactory
from twisted.spread import pb

from mv3d.net.security import AccessDenied
from mv3d.net.pb import Copyable


# Characters used for random padding.  "," is deliberately absent because
# padded payloads (see JSONServiceCon.getSessionKey) are split on commas.
readableCharacters = (string.digits + string.ascii_uppercase +
                      string.ascii_lowercase + " !@#$%^&*()-_=+[]{}")

# The padding produced here feeds session-key derivation, so draw it from
# the OS CSPRNG instead of the predictable module-level Mersenne Twister.
_sysRandom = random.SystemRandom()


def genRandomText(length):
    """
    Return `length` characters drawn at random from readableCharacters.
    """
    return "".join(_sysRandom.choice(readableCharacters)
                   for _ in range(length))


def padRandomText(data, blockSize):
    """
    Append random readable characters so len(result) is a multiple of
    blockSize.

    Between blockSize + 1 and 2 * blockSize characters are always added, so
    data that is already aligned still receives at least one full block.
    """
    rlen = blockSize - (len(data) % blockSize) + blockSize
    return data + genRandomText(rlen)


class LocalSession:
    """
    A session for an in-process caller; it is always authenticated.
    """
    authenticated = True
    authenticatedUser = None

    def __init__(self, user):
        self.authenticatedUser = user

    def getUsername(self):
        return self.authenticatedUser

    def getSessionID(self):
        return "localSession"


class ServiceLoc(Copyable):
    """
    A pointer to the location of a service, which may be local or remote.

    String form: protocol://user:password@host:port/path;params?query#frag
    The host "self" refers to this process and forces protocol "local".
    creds is kept as the list of ":"-separated credential parts.
    """
    protocol = None
    creds = None
    host = None
    port = None
    path = None
    query = None
    fragment = None
    params = None

    def __init__(self, value=None, protocol=None, creds=None, host=None,
                 port=None, path=None, query=None, fragment=None,
                 params=None):
        """
        Initialise from value (a string or another ServiceLoc; anything else
        is ignored), then apply every component argument that is not None.
        """
        if isinstance(value, (str, unicode)):
            self.fromString(value)
        elif isinstance(value, ServiceLoc):
            self.fromString(value.toString())
        overrides = (("protocol", protocol), ("creds", creds),
                     ("host", host), ("port", port), ("path", path),
                     ("query", query), ("fragment", fragment),
                     ("params", params))
        for attr, override in overrides:
            if override is not None:
                setattr(self, attr, override)

    def __cmp__(self, other):
        """
        Compare component by component after coercing other to a
        ServiceLoc (Python 2 comparison protocol; also used for dict keys).
        """
        try:
            other = ServiceLoc(other)
        except Exception:
            # Not parseable as a location; compare against it unchanged.
            pass
        for attr in ("protocol", "host", "creds", "port", "path", "query",
                     "fragment", "params"):
            mine = getattr(self, attr)
            theirs = getattr(other, attr)
            if mine < theirs:
                return -1
            if mine > theirs:
                return 1
        return 0

    def __hash__(self):
        return hash(self.toString())

    def __str__(self):
        return self.toString()

    def sameHost(self, other):
        """
        True when protocol, host and port match and the creds match or at
        least one side has none.
        """
        return (self.protocol == other.protocol and
                self.host == other.host and
                self.port == other.port and
                (self.creds == other.creds or
                 self.creds is None or other.creds is None))

    def stripToHost(self):
        """
        Return a copy holding only the connection info (protocol, host,
        port, creds).
        """
        sl = ServiceLoc()
        sl.protocol = self.protocol
        sl.host = self.host
        sl.port = self.port
        sl.creds = self.creds
        return sl

    def stripCreds(self):
        """
        Return a copy (via a string round trip) without creds.
        """
        loc = ServiceLoc(self)
        loc.creds = None
        return loc

    def copy(self):
        """
        Return a component-wise copy (the creds list itself is shared).
        """
        sl = ServiceLoc()
        for attr in ("protocol", "creds", "host", "port", "path", "query",
                     "fragment", "params"):
            setattr(sl, attr, getattr(self, attr))
        return sl

    def fromString(self, string):
        """
        Populate this location from a URL-like string and return self.

        Raises ValueError when the string cannot be parsed.
        """
        try:
            self.protocol = urlparse(string)[0]
            # urlparse only splits the netloc for schemes it knows, so
            # re-parse with "http" substituted for the real scheme.
            if "://" in string:
                parts = urlparse("http:" + string.split(":", 1)[1])
            else:
                parts = urlparse("http://" + string)
            loc = parts[1]
            if "@" in loc:
                creds, loc = loc.split("@")
                self.creds = creds.split(":")
            if ":" in loc:
                self.host, port = loc.split(":")
                self.port = int(port)
            else:
                self.host = loc
            if self.host == "self":
                self.protocol = "local"
            self.path, self.params, self.query, self.fragment = parts[2:]
        except ValueError as exc:
            raise ValueError("Invalid string "
                             "%r passed to ServiceLoc.fromString (%s)." % (
                                 string, str(exc)))
        return self

    def toString(self):
        """
        Convert to a URL string; a missing protocol is written as "pb".
        """
        netloc = ""
        if self.creds is not None:
            netloc += ":".join(self.creds) + "@"
        if self.host is not None:
            netloc += self.host
        if self.port is not None:
            netloc += ":" + str(self.port)
        return urlunparse((self.protocol or "pb", netloc, self.path or "",
                           self.params or "", self.query or "",
                           self.fragment or ""))

    def getConnection(self, conductor=None):
        """
        Return an (unconnected) IServiceCon subclass for this location.

        Raises ValueError for protocols other than "local" and "pb".
        """
        if self.protocol == "local":
            return LocalServiceCon(self, conductor)
        if self.protocol == "pb":
            # PBServiceCon's second argument is the PB connection, not a
            # conductor; the connection is attached later by PBConFactory.
            return PBServiceCon(self)
        raise ValueError("Don't know what type of connection for %s"
                         % self.protocol)


class IServiceCon:
    """
    A connection to a service

    Example usage:

    sc = ServiceCon("pb://mv3d.com:1999/sim", conductor=c)
    d = sc.getItem((0,1))
    --> connects to mv3d.com:1999
        gets "sim" service
        calls view_getItem((0,1)) on "sim" service

    sc = ServiceCon("self/sim", conductor=c)
    d = sc.getItem((0,1))
    --> gets "sim" service from the conductor
        calls getItem((0,1)) on it.

    sc = ServiceCon("https://mv3d.com/sim")
    d = sc.getItem((0,1))
    --> makes https request for https://user:pass@mv3d.com/sim/getItem
        args are POSTed so request.args["argv"] = [(0,1)]
        note: unlike over pb, the item will not be used as a cacheable
    """

    def __getattr__(self, key):
        """
        Turn unknown attribute lookups into calls: con.foo(x) is
        con.call("foo", x).

        Dunder names are refused.  This is an old-style class in Python 2,
        so protocol hooks (__nonzero__, __hash__, __eq__, __getstate__, ...)
        are also looked up through here; without the guard a truth test,
        hash, comparison or copy of a connection would call the service.
        """
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        if key in self.__dict__:
            return self.__dict__[key]

        def callIt(*a, **kw):
            return self.call(key, *a, **kw)
        return getattr(self.__class__, key, callIt)

    def getLocation(self):
        """
        Just return the service location
        """

    def disconnect(self):
        """
        Disconnect from the service
        """

    def isConnected(self):
        """
        Returns true if this connection is connected
        """

    def call(self, command, *a, **kw):
        """
        Execute a command on the service
        """

    def storeSessionData(self, key, value):
        """
        Store some session specific value
        """

    def retrieveSessionData(self, key):
        """
        Retrieve session specific data
        """

    def __repr__(self):
        """
        Class name plus object id, e.g. for log output.
        """
        return self.__class__.__name__ + "() @ %s" % hex(id(self))

    def __str__(self):
        """
        Same text as __repr__.
        """
        return self.__class__.__name__ + "() @ %s" % hex(id(self))

    def isLocal(self):
        """
        Returns true if this is a local connection
        """
        return False


class LocalServiceCon(IServiceCon):
    """
    A service connection to an in-process service obtained from a conductor.
    """
    service = None

    def __init__(self, location, conductor=None):
        if not isinstance(location, ServiceLoc):
            location = ServiceLoc(location)
        self.location = location
        self.conductor = conductor

    def getLocation(self):
        """
        Just return the service location
        """
        return self.location

    def connect(self, conductor=None):
        """
        Look up the service named by the first path segment on the
        conductor (optionally replacing the stored conductor first).
        """
        if conductor is not None:
            self.conductor = conductor
        self.service = self.conductor.getNamedService(
            self.location.path.split("/")[1])

    def disconnect(self):
        """
        Drop the reference to the service
        """
        self.service = None

    def isConnected(self):
        """
        Connected means we currently hold a service reference.
        """
        return self.service is not None

    def call(self, command, *a, **kw):
        """
        Call `command` on the service, connecting first if needed; the
        result is always wrapped in a Deferred.
        """
        if self.service is None:
            self.connect()
        return defer.maybeDeferred(getattr(self.service, command), *a, **kw)

    def isLocal(self):
        """
        Returns true if this is a local connection
        """
        return True


class ConnectionError(Exception):
    """
    Raised when there is a connection issue
    """


class PBServiceCon(IServiceCon):
    """
    A simple service connection that runs over PB.

    connection is the remote root reference; service is the remote service
    reference obtained from it.
    """
    connection = None
    service = None
    location = None

    def __init__(self, location, connection=None, service=None):
        self.location = ServiceLoc(location)
        self.connection = connection
        self.service = service
        assert self.location.protocol == "pb"

    def getLocation(self):
        """
        Just return the service location
        """
        return self.location

    def disconnect(self):
        """
        Forget the service and drop the underlying transport, returning
        whatever loseConnection() returns.
        """
        self.service = None
        if self.connection is not None:
            d = self.connection.broker.transport.loseConnection()
            self.connection = None
            return d

    def isConnected(self):
        """
        Returns true if we have a service and the broker is still connected
        """
        return (self.connection is not None and
                not self.connection.broker.disconnected and
                self.service is not None)

    def __eq__(self, other):
        # NOTE(review): any == comparison on a PBServiceCon raises; this looks
        # like a deliberate debugging tripwire.  Compare connections with `is`.
        raise Exception("WTF")

    def call(self, command, *a, **kw):
        """
        Proxy over to callRemote; raises ConnectionError when disconnected.
        """
        if not self.isConnected():
            raise ConnectionError("This connection has been disconnected")
        return self.service.callRemote(command, *a, **kw)


class IConnectionFactory(Interface):
    """
    A connection factory helps the app connect to resources over a specific
    protocol.
    """

    def configure(nm, cf): #@NoSelf
        """
        Configure this factory
        """

    def getProtocol(): #@NoSelf
        """
        Return the name of the protocol that we provide connections for
        """

    def getService(loc): #@NoSelf
        """
        Return a connected IServiceCon for the given location
        """

    def isServiceConnected(loc): #@NoSelf
        """
        Returns true if we have an open connection to this service.

        Some protocols may not support this and would therefore always
        return False
        """

    def hasConnectionTo(loc): #@NoSelf
        """
        Returns true if we have an open connection to this host/port
        """

    def countConnections(): #@NoSelf
        """
        Returns the number of connections
        """


class ClientController(pb.Viewable):
    """
    The object the server uses to control the client.

    It lives on the client side of a PB connection; its view_* methods are
    called remotely from the server side.
    """
    pingtime = None

    def __init__(self, parent):
        self.parent = parent
        self.parent.parent.log("Connected.", logging.INFO)

    def view_readyClass(self, _con, c):
        """
        Prepare one class, or a list of classes, for being transmitted from
        the server.
        """
        if isinstance(c, list):
            return defer.gatherResults(
                [defer.maybeDeferred(cc.readyJellyFor) for cc in c])
        return c.readyJellyFor()

    def view_ping(self, _con, t, pingtime):
        """
        Record the reported ping time and echo `t` back.
        """
        self.pingtime = pingtime
        return t

    def view_receiveEvent(self, _, event):
        """
        Receive an event from the server and rebroadcast it locally.
        """
        self.parent.parent.broadcastEvent(event)


class PBConFactory(object):
    """
    Factory for outgoing Perspective Broker connections.

    self.connections maps a ServiceLoc either to a list of Deferreds waiting
    on a connection still being set up, or, once available, to a tuple
    (PBServiceCon, serviceName, serviceInterface).
    """
    implements(IConnectionFactory)

    creds = None
    localAliases = None
    loginServices = None

    def __init__(self, parent):
        if AES is None:
            raise ImportError("Please install Python Crypto module.")
        self.connections = {}
        self.localAliases = []
        self.connectors = []
        self.parent = parent

    def getProtocol(self):
        return "pb"

    def configure(self, nm, cf):
        """
        Read defaultCredentials ("user:password"), localAliases and
        loginServices (comma-separated locations) from section nm of the
        config object cf (anything with has_option/get).
        """
        if cf.has_option(nm, "defaultCredentials"):
            self.creds = cf.get(nm, "defaultCredentials").split(":")
        if cf.has_option(nm, "localAliases"):
            la = cf.get(nm, "localAliases").split(",")
            self.localAliases = [ServiceLoc(l.strip()) for l in la]
        if cf.has_option(nm, "loginServices"):
            ls = cf.get(nm, "loginServices").split(",")
            # Strip whitespace like localAliases; a leading space would
            # otherwise corrupt the parsed scheme.
            self.loginServices = [ServiceLoc(l.strip()) for l in ls]

    def getLocalAlias(self, type=None):
        """
        Get a local alias (`type` is currently unused).
        """
        return self.localAliases[0]

    def genAuthToken(self, user, password):
        """
        Generate a client auth token: "<time>,<user>," padded with random
        bytes to a 16-byte multiple and AES-encrypted with a 32-byte key
        derived by repeating the password.

        Raises ValueError for an empty password.
        """
        if not password:
            raise ValueError("Cannot derive an auth key from an empty "
                             "password.")
        data = "%f,%s," % (time() + timezone, user)
        rlen = 16 - (len(data) % 16) + 16
        data += "".join(chr(_sysRandom.randint(0, 255))
                        for _ in range(rlen))
        pw = password
        if len(pw) < 32:
            pw = pw * (32 // len(pw) + 1)
        return AES.new(pw[:32]).encrypt(data)

    @inlineCallbacks
    def getTempPasswords(self, svc):
        """
        Get (pw1, pw2, uid) temporary passwords from a login service using
        the creds of svc (falling back to the default credentials).
        """
        if svc.creds is None:
            svc.creds = self.creds
            if svc.creds is None:
                raise ConnectionError("No credentials supplied!")
        service = yield self.parent.getOneService(self.loginServices)
        if not service.isLocal():
            key = yield service.getSessionKey(*svc.creds)
            encData = yield service.createPasswords()
            pw1, pw2, uid, _rnd = AES.new(key).decrypt(
                b64decode(encData)).split(",")
            returnValue((pw1, pw2, uid))
        data = yield service.createPasswords(LocalSession(svc.creds[0]))
        returnValue(tuple(data.split(",")[:3]))

    @staticmethod
    def _splitServiceLoc(serviceLoc):
        """
        Return (serviceName, serviceInterface): a #fragment selects the
        service by interface name, otherwise the first path segment names it.
        """
        if serviceLoc.fragment:
            return None, serviceLoc.fragment
        return serviceLoc.path.split("/")[1], None

    @inlineCallbacks
    def _lookupRemoteService(self, remote, serviceName, serviceInt):
        """
        Ask a PB root reference for a service by name or by interface.

        Fires with (serviceName, serviceInterface, serviceReference).
        """
        if serviceName is not None:
            serviceInfo = yield remote.callRemote("getService", serviceName)
            # handle older servers that don't return the interface
            if isinstance(serviceInfo, (tuple, list)):
                serviceInt, service = serviceInfo
            else:
                service = serviceInfo
                serviceInt = None
        else:
            serviceName, service = yield remote.callRemote(
                "getServiceByInterface", serviceInt)
        returnValue((serviceName, serviceInt, service))

    def _fireWaiters(self, serviceLoc, con, serviceName, serviceInt):
        """
        Record con as the connection for serviceLoc, then fire any Deferreds
        that were waiting on it.  The entry is replaced first so a waiter
        that calls getService again sees the finished connection.
        """
        waiters = self.connections.get(serviceLoc)
        self.connections[serviceLoc] = (con, serviceName, serviceInt)
        if isinstance(waiters, list):
            for d in waiters:
                d.callback(con)

    def _failWaiters(self, serviceLoc, failure):
        """
        Forget the entry for serviceLoc and errback any Deferreds that were
        waiting on it.
        """
        entry = self.connections.pop(serviceLoc, None)
        if isinstance(entry, list):
            for d in entry:
                d.errback(failure)

    def _gotService(self, serviceLoc, svc, sc, serviceName,
                    serviceInterface):
        """
        Wrap svc (reached through sc's connection) in a PBServiceCon, record
        it for serviceLoc and wake up waiters.
        """
        if svc is None:
            raise ConnectionError("%s was none?" % serviceLoc)
        con = PBServiceCon(serviceLoc, sc.connection, svc)
        self._fireWaiters(serviceLoc, con, serviceName, serviceInterface)
        return con

    @inlineCallbacks
    def _getServiceIfLocal(self, serviceLoc):
        """
        Return the in-process service when serviceLoc matches one of our
        local aliases, otherwise None.
        """
        serviceName, serviceInt = self._splitServiceLoc(serviceLoc)
        for local in self.localAliases or ():
            if not local.sameHost(serviceLoc):
                continue
            sl2 = serviceLoc.copy()
            sl2.protocol = "local"
            sl2.host = "self"
            if serviceName is not None:
                newService = yield self.parent.getService(sl2)
            else:
                newService = yield self.parent.getLocalService(
                    getClass(serviceInt))
            returnValue(newService)

    @inlineCallbacks
    def _getSameHostService(self, serviceLoc):
        """
        Return the service through an existing (or in-progress) connection
        to the same host, or None.  Dead connections found on the way are
        dropped.
        """
        serviceName, serviceInt = self._splitServiceLoc(serviceLoc)
        for k, c in list(self.connections.items()):
            if not k.sameHost(serviceLoc):
                continue
            if isinstance(c, list):
                # A connection to this host is still being set up:
                # piggyback on it, making others wait on serviceLoc.
                d = defer.Deferred()
                c.append(d)
                self.connections[serviceLoc] = []
                try:
                    conn = yield d
                    serviceName, serviceInt, service = (
                        yield self._lookupRemoteService(
                            conn.connection, serviceName, serviceInt))
                    con = self._gotService(serviceLoc, service, conn,
                                           serviceName, serviceInt)
                except:
                    # Never leave an empty waiter list behind: later
                    # getService calls would block on it forever.
                    failure = Failure()
                    self._failWaiters(serviceLoc, failure)
                    failure.raiseException()
                returnValue(con)
            if c[0].isConnected():
                self.connections[serviceLoc] = []
                try:
                    serviceName, serviceInt, service = (
                        yield self._lookupRemoteService(
                            c[0].connection, serviceName, serviceInt))
                    con = self._gotService(serviceLoc, service, c[0],
                                           serviceName, serviceInt)
                except:
                    failure = Failure()
                    self._failWaiters(serviceLoc, failure)
                    # Also drop the shared host connection that failed.
                    self.connections.pop(k, None)
                    failure.raiseException()
                returnValue(con)
            # Dead connection: forget it and keep looking.
            self.connections.pop(k, None)

    def _getMatchingService(self, serviceLoc):
        """
        Return an established connection on the same host whose service
        matches serviceLoc by name or by implemented interface, else None.
        """
        for checkLoc, data in self.connections.items():
            if isinstance(data, list):
                continue
            if not serviceLoc.sameHost(checkLoc):
                continue
            connection, serviceName, serviceInt = data
            path = serviceLoc.path.split("/")
            if len(path) > 1 and path[1] == serviceName:
                return connection
            if serviceLoc.fragment:
                try:
                    scls = getClass(serviceInt)
                    intcls = getClass(serviceLoc.fragment)
                except ImportError:
                    # TODO: MDH - this is totally not going to work on the
                    # client.
                    continue
                if intcls.implementedBy(scls):
                    return connection

    @inlineCallbacks
    def getService(self, serviceLoc):
        """
        Get the service at serviceLoc, preferring, in order: a local alias,
        a connection already in progress or open for this exact location, a
        matching open connection, a connection to the same host, and finally
        a brand new PB connection.
        """
        serviceLoc = ServiceLoc(serviceLoc)
        assert serviceLoc.protocol == "pb"

        service = yield self._getServiceIfLocal(serviceLoc)
        if service is not None:
            returnValue(service)

        existing = self.connections.get(serviceLoc)
        if isinstance(existing, list):
            # Someone is already connecting to this location; wait for it.
            d = defer.Deferred()
            existing.append(d)
            service = yield d
            returnValue(service)
        if existing is not None and existing[0].isConnected():
            returnValue(existing[0])

        service = self._getMatchingService(serviceLoc)
        if service is not None:
            returnValue(service)

        service = yield self._getSameHostService(serviceLoc)
        if service is not None:
            returnValue(service)

        con = yield self._connectService(serviceLoc)
        if self.creds is None:
            self.creds = serviceLoc.creds
        returnValue(con)

    @inlineCallbacks
    def _buildCredentials(self, serviceLoc):
        """
        Build PB login credentials for serviceLoc.

        The first cred part may be "user" or "user,authMethod"; the default
        method "loginService" first fetches temporary passwords.  Raises
        ConnectionError when neither the location nor the factory has creds.
        """
        sl = serviceLoc.copy()
        sl.creds = sl.creds or self.creds
        if not sl.creds:
            raise ConnectionError("No credentials supplied!")
        credSplit = sl.creds[0].split(",")
        user = credSplit[0]
        if len(credSplit) > 1:
            authMethod = credSplit[1]
        else:
            authMethod = "loginService"
        if authMethod == "loginService":
            sl.creds = (user, sl.creds[1])
            pw1, _pw2, uid = yield self.getTempPasswords(sl)
            returnValue(credentials.UsernamePassword(
                ",".join([user, authMethod, str(uid)]), pw1))
        returnValue(credentials.UsernamePassword(
            ",".join([user, authMethod]), sl.creds[1]))

    @inlineCallbacks
    def _connectService(self, serviceLoc):
        """
        Open a new PB connection for serviceLoc, log in, fetch the service
        and register it.  Waiters are errbacked on any failure.
        """
        serviceName, serviceInt = self._splitServiceLoc(serviceLoc)
        self.connections[serviceLoc] = []
        try:
            creds = yield self._buildCredentials(serviceLoc)
        except:
            failure = Failure()
            self._failWaiters(serviceLoc, failure)
            failure.raiseException()

        factory = pb.PBClientFactory()
        connector = None
        try:
            connector = reactor.connectTCP(serviceLoc.host, #@UndefinedVariable
                                           serviceLoc.port, factory)
            # mdh TODO: this is busted in python 2.6
            # see http://twistedmatrix.com/trac/ticket/4520
            dfrd = factory.login(creds, client=ClientController(self))
            self.connectors.append(connector)
            try:
                remote = yield dfrd
            finally:
                self.connectors.remove(connector)
            con = PBServiceCon(serviceLoc, remote)
            serviceName, serviceInt, svc = yield self._lookupRemoteService(
                remote, serviceName, serviceInt)
            con.service = svc
            self._fireWaiters(serviceLoc, con, serviceName, serviceInt)

            def removeIt(_):
                # Only drop the entry if it still refers to this connection.
                entry = self.connections.get(serviceLoc)
                if isinstance(entry, tuple) and entry[0] is con:
                    del self.connections[serviceLoc]
            remote.notifyOnDisconnect(removeIt)
        except:
            failure = Failure()
            self._failWaiters(serviceLoc, failure)
            if connector is not None:
                connector.disconnect()
            failure.raiseException()
        returnValue(con)

    def isServiceConnected(self, serviceLoc):
        """
        Returns true if we have an established, live connection to this
        exact service location.
        """
        entry = self.connections.get(ServiceLoc(serviceLoc))
        return (entry is not None and not isinstance(entry, list) and
                entry[0].isConnected())

    def hasConnectionTo(self, loc):
        """
        Returns true if we have an open connection to this host/port
        (connections still being set up are not counted).
        """
        loc = ServiceLoc(loc)
        for k, c in self.connections.items():
            if isinstance(c, list):
                continue
            if k.sameHost(loc) and c[0].isConnected():
                return True
        return False

    def garbageCollect(self):
        """
        Remove established connections that are no longer connected.
        """
        dead = [k for k, c in self.connections.items()
                if not isinstance(c, list) and not c[0].isConnected()]
        for k in dead:
            del self.connections[k]

    def countConnections(self):
        """
        Return a count of the connection entries (including pending ones)
        """
        return len(self.connections)

    def disconnectAll(self):
        """
        Disconnect every connection, errback every waiter, abort connection
        attempts in progress, and return a Deferred gathering any Deferreds
        returned by the disconnects.
        """
        dd = []
        for k in list(self.connections.keys()):
            # Pop before notifying so callbacks that reconnect are not
            # clobbered afterwards.
            c = self.connections.pop(k, None)
            if c is None:
                continue
            if isinstance(c, list):
                for waiter in c:
                    waiter.errback(ConnectionError("disconnectAll called"))
            else:
                d = c[0].disconnect()
                if isinstance(d, defer.Deferred):
                    dd.append(d)
        # Connectors are not Deferreds; abort them instead of gathering them.
        # Iterate a copy: aborting can remove entries from self.connectors.
        for connector in list(self.connectors):
            connector.disconnect()
        return defer.gatherResults(dd)


class JSONServiceCon(IServiceCon):
    """
    A connection to a JSON RPC over http service
    """
    loggedIn = False
    contextFactory = None
    nextId = 0
    sessionData = None

    def __init__(self, location, cookies=None):
        if cjson is None:
            raise ImportError("No cjson available, please install it.")
        if ClientContextFactory is None:
            raise ImportError("ClientContextFactory from Twisted is missing.")
        self.location = ServiceLoc(location)
        self.cookies = cookies or {}
        self.sessionData = {}

    def getLocation(self):
        """
        Just return the service location
        """
        return self.location

    def disconnect(self):
        """
        Disconnect from the service (no-op: requests are stateless)
        """

    def isConnected(self):
        """
        Connected means logIn() has succeeded.
        """
        return bool(self.loggedIn)

    def httpGet(self, sl, fields=None, files=None):
        """
        Send an HTTP request to sl (creds stripped) and return a Deferred
        firing with the body.  A POST is sent when fields or files are given,
        otherwise a GET.  Cookies set by the server are merged into
        self.cookies.
        """
        method = "GET"
        headers = None
        postdata = None
        loc = ServiceLoc(sl)
        loc.creds = None
        if fields is not None or files is not None:
            method = "POST"
            ctype, postdata = self.formatPOST(fields, files)
            headers = {"Content-Type": ctype}
        factory = HTTPClientFactory(loc.toString(), postdata=postdata,
                                    timeout=30, agent="MV3D",
                                    cookies=self.cookies, headers=headers,
                                    method=method)
        if loc.protocol == "https":
            # Locations without an explicit port use the scheme default.
            port = loc.port if loc.port is not None else 443
            if self.contextFactory is None:
                self.contextFactory = ClientContextFactory()
            reactor.connectSSL(loc.host, port, factory, #@UndefinedVariable
                               self.contextFactory)
        else:
            port = loc.port if loc.port is not None else 80
            reactor.connectTCP(loc.host, port, factory) #@UndefinedVariable
        d = factory.deferred

        def done(result):
            self.cookies.update(factory.cookies)
            return result
        return d.addCallback(done)

    def formatPOST(self, fields=None, files=None):
        """
        Format multipart/form-data POST data; return (contentType, body).

        fields maps names to values; files maps names to (filename, content).
        This should be moved to somewhere more generic if it is ever needed
        elsewhere.
        General idea from:
        http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/146306
        """
        assert fields is not None or files is not None
        boundary = "----------MV3D Boundry Data HERE ----------"
        lines = []
        if fields is not None:
            for k, v in fields.items():
                lines.append('--' + boundary)
                lines.append('Content-Disposition: form-data; name="%s"' % k)
                lines.append('')
                lines.append(str(v))
        if files is not None:
            for k, (filename, value) in files.items():
                lines.append('--' + boundary)
                lines.append('Content-Disposition: form-data; name="%s"; '
                             'filename="%s"' % (k, filename))
                lines.append('Content-Type: application/octet-stream')
                lines.append('')
                # The file content, not the (filename, content) tuple.
                lines.append(str(value))
        lines.append('--' + boundary + "--")
        lines.append('')
        data = "\r\n".join(lines)
        ctype = 'multipart/form-data; boundary=%s' % boundary
        return ctype, data

    def logIn(self):
        """
        Log in to the service by POSTing username/password to a
        "__login__" URL; fires True on success.  Without creds the
        connection is simply marked as logged in.
        """
        if self.location.creds is None:
            self.loggedIn = True
            return defer.succeed(True)
        user, pw = self.location.creds
        nl = self.location.copy()
        nl.creds = None
        path = nl.path or ""
        # Insert "__login__" before the last path segment (before the
        # trailing empty segment when the path ends with "/").
        n = -2 if path.endswith("/") else -1
        sp = path.split("/")
        nl.path = "/".join(sp[:n] + ["__login__"] + sp[n:])
        d = self.httpGet(nl, dict(username=user, password=pw))

        def loggedIn(result):
            if result != "Login accepted.":
                raise UnauthorizedLogin("Log in failed")
            self.loggedIn = True
            return True

        def check404(e):
            # Failures carrying a non-404 status propagate unchanged;
            # everything else is reported as a failed login.
            if hasattr(e.value, "status") and e.value.status != "404":
                e.raiseException()
            raise UnauthorizedLogin("Log in failed")
        return d.addCallback(loggedIn).addErrback(check404)

    def storeSessionData(self, key, value):
        """
        Store some session specific value
        """
        self.sessionData[key] = value

    def retrieveSessionData(self, key, default=None):
        """
        Retrieve session specific data
        """
        return self.sessionData.get(key, default)

    def getSessionKey(self, username, password):
        """
        Negotiate (or reuse) a session key with the service and verify that
        the service really knows `password`.

        Fires with the key, or fails with AccessDenied when the server's
        reply does not echo back our timestamp and username.
        """
        assert not self.isLocal()
        key = self.retrieveSessionData("key")
        if key is not None:
            return defer.succeed(key)

        def gotChallenge(challenge):
            m = md5(challenge)
            m.update(password)
            skey = m.hexdigest()
            data = ",".join([str(time() + timezone), username]) + ","
            data = padRandomText(data, 16)
            tm1, un1, rnd = data.split(",")
            sessionKey = md5(rnd + password).hexdigest()
            self.storeSessionData("key", sessionKey)
            cryptData = b64encode(AES.new(skey).encrypt(data))
            d = self.authenticateSession(username, cryptData)

            def gotReply(reply):
                plain = AES.new(sessionKey).decrypt(b64decode(reply))
                tm2, un2, _rnd = plain.split(",")
                if float(tm1) != float(tm2) or un1 != un2:
                    self.sessionData.pop("key", None)
                    raise AccessDenied("Server failed verification!")
                return sessionKey

            def error(failure):
                # pop(): gotReply may already have removed the key, and a
                # KeyError here would mask the original failure.
                self.sessionData.pop("key", None)
                return failure
            return d.addCallback(gotReply).addErrback(error)

        return self.challenge().addCallback(gotChallenge)

    @staticmethod
    def _remoteException(clsn, msg):
        """
        Rebuild an exception instance from a server-reported dotted class
        name, falling back to a plain Exception when the class cannot be
        resolved or constructed.

        NOTE(review): the module named by the server is imported here, so a
        hostile server can trigger imports of arbitrary installed modules.
        """
        emod, _, ecls = clsn.rpartition(".")
        try:
            excClass = getattr(__import__(emod, globals(), locals(), [ecls]),
                               ecls)
        except (ImportError, AttributeError, ValueError):
            excClass = None
        try:
            isException = issubclass(excClass, BaseException)
        except TypeError:
            isException = False
        if isException:
            try:
                return excClass(msg)
            except Exception:
                pass
        return Exception("%s: %s" % (clsn, msg))

    def call(self, command, *a, **kw):
        """
        Execute `command` on the service via a JSON RPC POST, logging in
        first if needed.  A server-reported error becomes a Failure wrapping
        the reconstructed exception.
        """
        theId = self.nextId
        self.nextId += 1
        request = cjson.encode(dict(id=theId, method=command,
                                    params=dict(args=a, kwargs=kw)))
        if self.loggedIn:
            d = self.httpGet(self.location, dict(data=request))
        else:
            d = self.logIn()
            d.addCallback(lambda _: self.httpGet(self.location,
                                                 dict(data=request)))

        def gotReply(reply):
            reply = cjson.decode(reply)
            if reply["error"] is not None:
                clsn, msg, _tb = reply["error"]
                # Failure(exc_value): the instance carries its own type.
                return Failure(self._remoteException(clsn, msg))
            assert reply["id"] == theId
            return reply["result"]
        return d.addCallback(gotReply)


class JSONConFactory(object):
    """
    A connection factory for JSON RPC over HTTP requests
    """
    implements(IConnectionFactory)

    protocol = "https"
    localAliases = None
    creds = None

    def __init__(self, parent):
        self.parent = parent
        self.connections = {}

    def configure(self, nm, cf):
        """
        Read protocol, defaultCredentials ("user:password") and
        localAliases (comma-separated) from section nm of config object cf.
        """
        if cf.has_option(nm, "protocol"):
            self.protocol = cf.get(nm, "protocol")
        if cf.has_option(nm, "defaultCredentials"):
            self.creds = cf.get(nm, "defaultCredentials").split(":")
        if cf.has_option(nm, "localAliases"):
            la = cf.get(nm, "localAliases").split(",")
            self.localAliases = [ServiceLoc(l.strip()) for l in la]

    def getProtocol(self):
        """
        Return the name of the protocol that we provide connections for
        """
        return self.protocol

    def _withDefaultCreds(self, serviceLoc):
        """
        Return a copy of serviceLoc with the default creds filled in; this
        is the key used in self.connections.
        """
        sl = serviceLoc.copy()
        sl.creds = sl.creds or self.creds
        return sl

    def getService(self, loc):
        """
        Return a Deferred firing with a logged-in IServiceCon for loc (or the
        local service when loc matches a local alias).
        """
        serviceLoc = ServiceLoc(loc)
        assert serviceLoc.protocol == self.protocol
        # check if this should really be local
        if self.localAliases is not None:
            for alias in self.localAliases:
                if alias.sameHost(serviceLoc):
                    sl2 = serviceLoc.copy()
                    sl2.protocol = "local"
                    sl2.host = "self"
                    return self.parent.getService(sl2)
        # Look up under the same key we store under (default creds applied).
        sl = self._withDefaultCreds(serviceLoc)
        existing = self.connections.get(sl)
        if existing is not None and existing.isConnected():
            return defer.succeed(existing)
        con = JSONServiceCon(sl)
        self.connections[sl] = con
        # log in now so that errors are reported now, not when
        # you try to use the service.
        d = con.logIn()

        def error(e):
            if self.connections.get(sl) is con:
                del self.connections[sl]
            return e
        d.addErrback(error)
        d.addCallback(lambda _: con)
        return d

    def isServiceConnected(self, loc):
        """
        Returns true if we have an open connection to this service.

        Some protocols may not support this and would therefore always
        return False
        """
        con = self.connections.get(self._withDefaultCreds(ServiceLoc(loc)))
        return con is not None and con.isConnected()

    def hasConnectionTo(self, loc):
        """
        Returns true if we have an open connection to this host/port
        """
        loc = ServiceLoc(loc)
        for k, c in self.connections.items():
            if k.sameHost(loc) and c.isConnected():
                return True
        return False

    def countConnections(self):
        """
        Returns the number of open connections
        """
        return len(self.connections)

    def disconnectAll(self):
        """
        Disconnect all connected clients and forget them
        """
        for client in list(self.connections.values()):
            client.disconnect()
        self.connections.clear()


pb.setUnjellyableForClass(ServiceLoc, ServiceLoc)