ia64/xen-unstable

view tools/python/xen/xend/server/channel.py @ 6552:a9873d384da4

Merge.
author adsharma@los-vmm.sc.intel.com
date Thu Aug 25 12:24:48 2005 -0700 (2005-08-25)
parents 112d44270733 fa0754a9f64f
children dfaf788ab18c
line source
1 #============================================================================
2 # This library is free software; you can redistribute it and/or
3 # modify it under the terms of version 2.1 of the GNU Lesser General Public
4 # License as published by the Free Software Foundation.
5 #
6 # This library is distributed in the hope that it will be useful,
7 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
9 # Lesser General Public License for more details.
10 #
11 # You should have received a copy of the GNU Lesser General Public
12 # License along with this library; if not, write to the Free Software
13 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
14 #============================================================================
15 # Copyright (C) 2004, 2005 Mike Wray <mike.wray@hp.com>
16 #============================================================================
18 import threading
19 import select
21 import xen.lowlevel.xc; xc = xen.lowlevel.xc.new()
22 from xen.lowlevel import xu
24 from xen.xend.XendLogging import log
26 from messages import *
# Set non-zero to print debugging traces (to stdout) from the channel code.
DEBUG = 0
# Default number of seconds Channel.requestResponse() waits for a reply
# when it is called without an explicit timeout.
RESPONSE_TIMEOUT = 20.0
32 class EventChannel(dict):
33 """An event channel between domains.
34 """
36 def interdomain(cls, dom1, dom2, port1=0, port2=0):
37 """Create an event channel between domains.
39 @return EventChannel (None on error)
40 """
41 v = xc.evtchn_bind_interdomain(dom1=dom1, dom2=dom2,
42 port1=port1, port2=port2)
43 if v:
44 v = cls(dom1, dom2, v)
45 return v
47 interdomain = classmethod(interdomain)
49 def restoreFromDB(cls, db, dom1, dom2, port1=0, port2=0):
50 """Create an event channel using db info if available.
51 Inverse to saveToDB().
53 @param db db
54 @param dom1
55 @param dom2
56 @param port1
57 @param port2
58 """
59 try:
60 dom1 = int(db['dom1'].getData())
61 except: pass
62 try:
63 dom2 = int(db['dom2'].getData())
64 except: pass
65 try:
66 port1 = int(db['port1'].getData())
67 except: pass
68 try:
69 port2 = int(db['port2'].getData())
70 except: pass
71 evtchn = cls.interdomain(dom1, dom2, port1=port1, port2=port2)
72 return evtchn
74 restoreFromDB = classmethod(restoreFromDB)
76 def __init__(self, dom1, dom2, d):
77 d['dom1'] = dom1
78 d['dom2'] = dom2
79 self.update(d)
80 self.dom1 = dom1
81 self.dom2 = dom2
82 self.port1 = d.get('port1')
83 self.port2 = d.get('port2')
85 def close(self):
86 """Close the event channel.
87 """
88 def evtchn_close(dom, port):
89 try:
90 xc.evtchn_close(dom=dom, port=port)
91 except Exception, ex:
92 pass
94 if DEBUG:
95 print 'EventChannel>close>', self
96 evtchn_close(self.dom1, self.port1)
97 evtchn_close(self.dom2, self.port2)
99 def saveToDB(self, db, save=False):
100 """Save the event channel to the db so it can be restored later,
101 using restoreFromDB() on the class.
103 @param db db
104 """
105 db['dom1'] = str(self.dom1)
106 db['dom2'] = str(self.dom2)
107 db['port1'] = str(self.port1)
108 db['port2'] = str(self.port2)
109 db.saveDB(save=save)
111 def sxpr(self):
112 return ['event-channel',
113 ['dom1', self.dom1 ],
114 ['port1', self.port1 ],
115 ['dom2', self.dom2 ],
116 ['port2', self.port2 ]
117 ]
119 def __repr__(self):
120 return ("<EventChannel dom1:%d:%d dom2:%d:%d>"
121 % (self.dom1, self.port1, self.dom2, self.port2))
def eventChannel(dom1, dom2, port1=0, port2=0):
    """Bind an interdomain event channel between dom1 and dom2.

    Convenience wrapper that delegates to EventChannel.interdomain().

    @param dom1 first domain
    @param dom2 second domain
    @param port1 port requested in dom1
    @param port2 port requested in dom2
    @return EventChannel (None on error)
    """
    evtchn = EventChannel.interdomain(dom1, dom2,
                                      port1=port1, port2=port2)
    return evtchn
def eventChannelClose(evtchn):
    """Close evtchn if there is one; a false value (e.g. None) is ignored.

    @param evtchn event channel to close, or None
    """
    if evtchn:
        evtchn.close()
136 class ChannelFactory:
137 """Factory for creating control channels.
138 Maintains a table of channels.
139 """
141 """ Channels indexed by index. """
142 channels = None
144 thread = None
146 notifier = None
148 """Map of ports to the virq they signal."""
149 virqPorts = None
151 def __init__(self):
152 """Constructor - do not use. Use the channelFactory function."""
153 self.channels = {}
154 self.virqPorts = {}
155 self.notifier = xu.notifier()
156 # Register interest in virqs.
157 self.bind_virq(xen.lowlevel.xc.VIRQ_DOM_EXC)
158 self.virqHandler = None
160 def bind_virq(self, virq):
161 port = self.notifier.bind_virq(virq)
162 self.virqPorts[port] = virq
163 log.info("Virq %s on port %s", virq, port)
165 def start(self):
166 """Fork a thread to read messages.
167 """
168 if self.thread: return
169 self.thread = threading.Thread(name="ChannelFactory",
170 target=self.main)
171 self.thread.setDaemon(True)
172 self.thread.start()
174 def stop(self):
175 """Signal the thread to stop.
176 """
177 self.thread = None
179 def main(self):
180 """Main routine for the thread.
181 Reads the notifier and dispatches to channels.
182 """
183 while True:
184 if self.thread == None: return
185 port = self.notifier.read()
186 if port:
187 virq = self.virqPorts.get(port)
188 if virq is not None:
189 self.virqReceived(virq)
190 else:
191 self.msgReceived(port)
192 else:
193 select.select([self.notifier], [], [], 1.0)
195 def msgReceived(self, port):
196 # We run the message handlers in their own threads.
197 # Note we use keyword args to lambda to save the values -
198 # otherwise lambda will use the variables, which will get
199 # assigned by the loop and the lambda will get the changed values.
200 received = 0
201 for chan in self.channels.values():
202 if self.thread == None: return
203 msg = chan.readResponse()
204 if msg:
205 received += 1
206 chan.responseReceived(msg)
207 for chan in self.channels.values():
208 if self.thread == None: return
209 msg = chan.readRequest()
210 if msg:
211 received += 1
212 self.runInThread(lambda chan=chan, msg=msg: chan.requestReceived(msg))
213 if port and received == 0:
214 log.warning("Port %s notified, but no messages found", port)
216 def runInThread(self, thunk):
217 thread = threading.Thread(target = thunk)
218 thread.setDaemon(True)
219 thread.start()
221 def setVirqHandler(self, virqHandler):
222 self.virqHandler = virqHandler
224 def virqReceived(self, virq):
225 if DEBUG:
226 print 'virqReceived>', virq
227 if not self.virqHandler: return
228 self.runInThread(lambda virq=virq: self.virqHandler(virq))
230 def newChannel(self, dom, local_port, remote_port):
231 """Create a new channel.
232 """
233 return self.addChannel(Channel(self, dom, local_port, remote_port))
235 def addChannel(self, channel):
236 """Add a channel.
237 """
238 self.channels[channel.getKey()] = channel
239 return channel
241 def delChannel(self, channel):
242 """Remove the channel.
243 """
244 key = channel.getKey()
245 if key in self.channels:
246 del self.channels[key]
248 def getChannel(self, dom, local_port, remote_port):
249 """Get the channel with the given domain and ports (if any).
250 """
251 key = (dom, local_port, remote_port)
252 return self.channels.get(key)
254 def findChannel(self, dom, local_port=0, remote_port=0):
255 """Find a channel. Ports given as zero are wildcards.
257 dom domain
259 returns channel
260 """
261 chan = self.getChannel(dom, local_port, remote_port)
262 if chan: return chan
263 if local_port and remote_port:
264 return None
265 for c in self.channels.values():
266 if c.dom != dom: continue
267 if local_port and local_port != c.getLocalPort(): continue
268 if remote_port and remote_port != c.getRemotePort(): continue
269 return c
270 return None
272 def openChannel(self, dom, local_port=0, remote_port=0):
273 chan = self.findChannel(dom, local_port=local_port,
274 remote_port=remote_port)
275 if chan:
276 return chan
277 chan = self.newChannel(dom, local_port, remote_port)
278 return chan
281 def createPort(self, dom, local_port=0, remote_port=0):
282 """Create a port for a channel to the given domain.
283 If only the domain is specified, a new channel with new port ids is
284 created. If one port id is specified and the given port id is in use,
285 the other port id is filled. If one port id is specified and the
286 given port id is not in use, a new channel is created with one port
287 id equal to the given id and a new id for the other end. If both
288 port ids are specified, a port is reconnected using the given port
289 ids.
291 @param dom: domain
292 @param local: local port id to use
293 @type local: int
294 @param remote: remote port id to use
295 @type remote: int
296 @return: port object
297 """
298 return xu.port(dom, local_port=local_port, remote_port=remote_port)
300 def restoreFromDB(self, db, dom, local, remote):
301 """Create a channel using ports restored from the db (if available).
302 Otherwise use the given ports. This is the inverse operation to
303 saveToDB() on a channel.
305 @param db db
306 @param dom domain the channel connects to
307 @param local default local port
308 @param remote default remote port
309 """
310 try:
311 local_port = int(db['local_port'])
312 except:
313 local_port = local
314 try:
315 remote_port = int(db['remote_port'])
316 except:
317 remote_port = remote
318 try:
319 chan = self.openChannel(dom, local_port, remote_port)
320 except:
321 return None
322 return chan
def channelFactory():
    """Singleton constructor for the channel factory.
    Use this instead of the class constructor.

    @return the module-wide ChannelFactory, created on first use
    """
    global inst
    try:
        return inst
    except NameError:
        # First call: the module-level 'inst' does not exist yet.
        inst = ChannelFactory()
        return inst
335 class Channel:
336 """Control channel to a domain.
337 Maintains a list of device handlers to dispatch requests to, based
338 on the request type.
339 """
341 def __init__(self, factory, dom, local_port, remote_port):
342 self.factory = factory
343 self.dom = int(dom)
344 # Registered device handlers.
345 self.devs = []
346 # Handlers indexed by the message types they handle.
347 self.devs_by_type = {}
348 self.port = self.factory.createPort(self.dom,
349 local_port=local_port,
350 remote_port=remote_port)
351 self.closed = False
352 # Queue of waiters for responses to requests.
353 self.queue = ResponseQueue(self)
354 # Make sure the port will deliver all the messages.
355 self.port.register(TYPE_WILDCARD)
357 def saveToDB(self, db, save=False):
358 """Save the channel ports to the db so the channel can be restored later,
359 using restoreFromDB() on the factory.
361 @param db db
362 """
363 if self.closed: return
364 db['local_port'] = str(self.getLocalPort())
365 db['remote_port'] = str(self.getRemotePort())
366 db.saveDB(save=save)
368 def getKey(self):
369 """Get the channel key.
370 """
371 return (self.dom, self.getLocalPort(), self.getRemotePort())
373 def sxpr(self):
374 val = ['channel']
375 val.append(['domain', self.dom])
376 if self.port:
377 val.append(['local_port', self.port.local_port])
378 val.append(['remote_port', self.port.remote_port])
379 return val
381 def close(self):
382 """Close the channel.
383 """
384 if DEBUG:
385 print 'Channel>close>', self
386 if self.closed: return
387 self.closed = True
388 self.factory.delChannel(self)
389 for d in self.devs[:]:
390 d.lostChannel(self)
391 self.devs = []
392 self.devs_by_type = {}
393 if self.port:
394 self.port.close()
395 #self.port = None
397 def getDomain(self):
398 return self.dom
400 def getLocalPort(self):
401 """Get the local port.
403 @return: local port
404 @rtype: int
405 """
406 if self.closed: return -1
407 return self.port.local_port
409 def getRemotePort(self):
410 """Get the remote port.
412 @return: remote port
413 @rtype: int
414 """
415 if self.closed: return -1
416 return self.port.remote_port
418 def __repr__(self):
419 return ('<Channel dom=%d ports=%d:%d>'
420 % (self.dom,
421 self.getLocalPort(),
422 self.getRemotePort()))
425 def registerDevice(self, types, dev):
426 """Register a device message handler.
428 @param types: message types handled
429 @type types: array of ints
430 @param dev: device handler
431 """
432 if self.closed: return
433 self.devs.append(dev)
434 for ty in types:
435 self.devs_by_type[ty] = dev
437 def deregisterDevice(self, dev):
438 """Remove the registration for a device handler.
440 @param dev: device handler
441 """
442 if dev in self.devs:
443 self.devs.remove(dev)
444 types = [ ty for (ty, d) in self.devs_by_type.items() if d == dev ]
445 for ty in types:
446 del self.devs_by_type[ty]
448 def getDevice(self, type):
449 """Get the handler for a message type.
451 @param type: message type
452 @type type: int
453 @return: controller or None
454 @rtype: device handler
455 """
456 return self.devs_by_type.get(type)
458 def requestReceived(self, msg):
459 """A request has been received on the channel.
460 Disptach it to the device handlers.
461 Called from the channel factory thread.
462 """
463 if DEBUG:
464 print 'Channel>requestReceived>', self,
465 printMsg(msg)
466 (ty, subty) = getMessageType(msg)
467 responded = False
468 dev = self.getDevice(ty)
469 if dev:
470 responded = dev.requestReceived(msg, ty, subty)
471 elif DEBUG:
472 print "Channel>requestReceived> No device handler", self,
473 printMsg(msg)
474 else:
475 pass
476 if not responded:
477 self.writeResponse(msg)
479 def writeRequest(self, msg):
480 """Write a request to the channel.
481 """
482 if DEBUG:
483 print 'Channel>writeRequest>', self,
484 printMsg(msg, all=True)
485 if self.closed: return -1
486 self.port.write_request(msg)
487 return 1
489 def writeResponse(self, msg):
490 """Write a response to the channel.
491 """
492 if DEBUG:
493 print 'Channel>writeResponse>', self,
494 printMsg(msg, all=True)
495 if self.port:
496 self.port.write_response(msg)
497 return 1
499 def readRequest(self):
500 """Read a request from the channel.
501 Called internally.
502 """
503 if self.closed:
504 val = None
505 else:
506 val = self.port.read_request()
507 return val
509 def readResponse(self):
510 """Read a response from the channel.
511 Called internally.
512 """
513 if self.closed:
514 val = None
515 else:
516 val = self.port.read_response()
517 if DEBUG and val:
518 print 'Channel>readResponse>', self,
519 printMsg(val, all=True)
520 return val
522 def requestResponse(self, msg, timeout=None):
523 """Write a request and wait for a response.
524 Raises IOError on timeout.
526 @param msg request message
527 @param timeout timeout (0 is forever)
528 @return response message
529 """
530 if self.closed:
531 raise IOError("closed")
532 if self.closed:
533 return None
534 if timeout is None:
535 timeout = RESPONSE_TIMEOUT
536 elif timeout <= 0:
537 timeout = None
538 return self.queue.call(msg, timeout)
540 def responseReceived(self, msg):
541 """A response has been received, look for a waiter to
542 give it to.
543 Called internally.
544 """
545 if DEBUG:
546 print 'Channel>responseReceived>', self,
547 printMsg(msg)
548 self.queue.response(getMessageId(msg), msg)
550 def virq(self):
551 self.factory.virq()
class Response:
    """One waiter slot in a ResponseQueue.

    The requesting thread blocks in wait(); whoever receives the reply
    hands it over (or cancels the wait) through response().
    """

    def __init__(self, mid):
        self.mid = mid
        self.msg = None
        self.ready = threading.Event()

    def response(self, msg):
        """Deliver msg to the waiting thread and wake it.
        Any false msg (such as None) cancels the wait instead: wait()
        will then raise IOError("wait canceled").
        """
        if not msg:
            # A negative id marks the slot as cancelled.
            self.mid = -1
        else:
            self.msg = msg
        self.ready.set()

    def wait(self, timeout):
        """Block for at most 'timeout' seconds and return the reply.
        Raises IOError if the wait was cancelled or nothing arrived.
        """
        self.ready.wait(timeout)
        cancelled = self.mid < 0
        if cancelled:
            raise IOError("wait canceled")
        reply = self.msg
        if reply is None:
            raise IOError("response timeout")
        return reply
class ResponseQueue:
    """Response queue. Manages waiters for responses to messages.

    Waiters are Response objects keyed by message id. self.lock guards
    that table, because call() and response() may run in different
    threads.
    """

    def __init__(self, channel):
        self.channel = channel
        self.lock = threading.Lock()
        self.responses = {}

    def add(self, mid):
        """Create, register and return a waiter for message id mid.
        Callers in this class hold self.lock around this.
        """
        r = Response(mid)
        self.responses[mid] = r
        return r

    def get(self, mid):
        """Return the waiter for mid, or None."""
        return self.responses.get(mid)

    def remove(self, mid):
        """Unregister and return the waiter for mid, or None.
        Callers in this class hold self.lock around this.
        """
        r = self.responses.get(mid)
        if r:
            del self.responses[mid]
        return r

    def response(self, mid, msg):
        """Process a response - signals any waiter that a response
        has arrived. Responses with no waiter are dropped.
        """
        # Acquire outside the try: if acquire() itself failed, the
        # finally clause would release a lock this thread does not hold.
        self.lock.acquire()
        try:
            r = self.remove(mid)
        finally:
            self.lock.release()
        if r:
            r.response(msg)

    def call(self, msg, timeout):
        """Send the message and wait for 'timeout' seconds for a response.
        Returns the response.
        Raises IOError on timeout.
        """
        mid = getMessageId(msg)
        self.lock.acquire()
        try:
            r = self.add(mid)
        finally:
            self.lock.release()
        try:
            self.channel.writeRequest(msg)
            return r.wait(timeout)
        finally:
            # If no response consumed the waiter (timeout, cancel or a
            # failed write), drop it so the table does not grow forever.
            # Only remove our own entry: a later call may reuse the id.
            self.lock.acquire()
            try:
                if self.responses.get(mid) is r:
                    del self.responses[mid]
            finally:
                self.lock.release()