#!/usr/bin/python3
#
# Tiny SAT>IP client which activates multicast streaming.
#
# Create /etc/sysconfig/multicast file with syntax:
#
# <mport>:<satip-params>
#
# You can modify this file when the program is active and
# reload the configuration using the SIGHUP signal.
#
# Example:
#
# # Channel 1 on multicast port 5554, using first tuner (fe=1)
# 5554:fe=1&src=1&freq=12525&sr=27500&\
#      msys=dvbs&mtype=qpsk&pol=v&fec=34&\
#      pids=0,1,16,17,18,52,57,100,104,165,299,1029,3002,3003
#
# The SAT>IP parameters should match the SAT>IP specification:
#   http://www.satip.info/resources
#
# Note: The multicast address must be specified through minisatip
# option -r <ipv4_mcast> or --remote-rtp <ipv4_mcast>. By default,
# this address is 239.255.255.250.
#

import os
import signal
import functools
import asyncio
import syslog

VERSION = '0.1'

# Runtime settings depend on whether /etc/init.d/axe-settings exists
# (presumably the marker for the target device -- TODO confirm):
#   present: talk to the local server, read the system-wide config file
#            and log through syslog;
#   absent:  talk to 'gssbox1', read ./multicast and log to stdout.
if os.path.exists('/etc/init.d/axe-settings'):
  SERVER, CONFIG, SYSLOG = '127.0.0.1:554', '/etc/sysconfig/multicast', True
else:
  SERVER, CONFIG, SYSLOG = 'gssbox1:554', 'multicast', False

#
# Logging helper
#

def log(msg):
  """Send msg to syslog when SYSLOG is set, else print it prefixed with 'LOG:'."""
  if not SYSLOG:
    print('LOG:', msg)
    return
  syslog.syslog(msg)

#
# Basic RTSP Protocol class
#

class RTSP_Protocol(asyncio.Protocol):
  """asyncio protocol driving one SAT>IP multicast session over RTSP.

  The conversation is: SETUP (sent as soon as the connection is made),
  PLAY (sent once the SETUP reply has been parsed), then one OPTIONS
  request every ``timeout / 2`` seconds (presumably the session
  keep-alive -- TODO confirm against the SAT>IP specification).
  ``action`` names the request that is sent next:
  'SETUP' -> 'PLAY' -> 'OPTIONS'.  It stays None when the parameters
  were rejected.

  NOTE: every reply is assumed to arrive in a single data_received()
  call; replies split over several reads are not reassembled.

  Deriving from asyncio.Protocol supplies the optional callbacks
  (pause_writing/resume_writing) that this class does not implement.
  """

  def __init__(self, client, params):
    """Parse ``params`` ('<mport>:<satip-params>') and attach to ``client``.

    client -- owning object; must provide ``loop``, ``connection_lost()``
              and ``fatal()``; its ``protocol`` attribute is set to self.
    params -- configuration entry; on a parse error the problem is logged
              and ``client.fatal()`` is called.
    """
    # Every attribute is assigned before anything can fail: client.fatal()
    # calls back into self.fatal(), which reads self.transport.
    self.client = client
    self.mport = 0
    self.params = ''
    self.action = None
    self.timeout = 30
    self.session = ''
    self.streamID = 0
    self.transport = None
    self.cseq = 0
    client.protocol = self
    try:
      mport, satip = params.split(':')[:2]
      # Clear bit 0: the Transport header announces the pair mport, mport+1
      self.mport = int(mport) & ~1
      self.params = satip.strip()
      self.action = 'SETUP'
    except (AttributeError, TypeError, ValueError):
      log('Invalid parameters: %s' % params)
      client.fatal()

  def _base_url(self):
    """Return 'rtsp://<host>:<port>' built from the transport's peer name."""
    peer = self.transport.get_extra_info('peername')
    return 'rtsp://%s:%d' % (peer[0], peer[1])

  def connection_made(self, transport):
    """Send the SETUP request; drop the connection if params were invalid."""
    if self.action != 'SETUP':
      # __init__ rejected the parameters: do not keep an idle connection.
      # self.transport stays None, so connection_lost() will not trigger
      # a reconnect.
      transport.close()
      return
    self.transport = transport
    url = '%s/?%s' % (self._base_url(), self.params)
    log('%d (open): %s' % (self.mport, url))
    msg = 'SETUP ' + url + ' RTSP/1.0\r\n'
    msg += 'Transport: RTP/AVP;multicast;port=%d-%d\r\n' % (self.mport, self.mport + 1)
    msg += 'CSeq: 1\r\n\r\n'
    transport.write(msg.encode())
    self.action = 'PLAY'

  def data_received(self, data):
    """Parse one RTSP reply and advance the SETUP/PLAY/OPTIONS sequence.

    Any reply that is not '200 OK', or that cannot be parsed, closes the
    connection through retry().
    """
    if not self.transport:
      return
    try:
      lines = data.decode('utf-8').splitlines()
    except UnicodeDecodeError:
      return self.retry()
    if not lines or lines[0] != 'RTSP/1.0 200 OK':
      return self.retry()
    try:
      streamID = 0
      session = ''
      timeout = 0
      for line in lines[1:]:
        if line == '':
          break  # end of the header section
        name, sep, value = line.partition(':')
        if not sep:
          continue
        name = name.lower()
        value = value.strip()
        if name == 'cseq':
          int(value)  # value is unused, but a malformed CSeq rejects the reply
        elif name == 'com.ses.streamid':
          streamID = int(value)
        elif name == 'session':
          # "Session: <id>[;timeout=<seconds>]"
          parts = value.split(';')
          session = parts[0]
          for opt in parts[1:]:
            if opt.startswith('timeout='):
              timeout = int(opt[8:])
    except ValueError:
      return self.retry()
    # Only non-empty / non-zero values replace what is already known
    if timeout:
      self.timeout = timeout
    if session:
      self.session = session
    if streamID:
      self.streamID = streamID
    if self.action == 'PLAY':
      log('%d (setup): Session %s, streamID %d' % (self.mport, self.session, self.streamID))
      self.action = 'OPTIONS'
      msg = 'PLAY %s/stream=%d RTSP/1.0\r\n' % (self._base_url(), self.streamID)
      msg += 'Session: %s\r\n' % self.session
      msg += 'CSeq: 2\r\n\r\n'
      self.transport.write(msg.encode())
      self.cseq = 2
      # Use the client's own loop rather than a module-level global
      self.client.loop.call_later(self.timeout / 2, self.options)
    elif self.action == 'OPTIONS' and self.cseq == 2:
      # cseq is still 2 only until the first OPTIONS request goes out,
      # so this is the reply to PLAY
      log('%d (play): OK' % self.mport)

  def options(self):
    """Send an OPTIONS request and re-arm the timer for the next one."""
    if not self.transport:
      return  # connection is gone; let the timer chain end here
    if self.action == 'OPTIONS':
      self.cseq += 1
      msg = 'OPTIONS %s/stream=%d RTSP/1.0\r\n' % (self._base_url(), self.streamID)
      msg += 'Session: %s\r\n' % self.session
      msg += 'CSeq: %d\r\n\r\n' % self.cseq
      self.transport.write(msg.encode())
      self.client.loop.call_later(self.timeout / 2, self.options)

  def eof_received(self):
    """Treat EOF from the peer as a lost connection."""
    self.connection_lost(None)

  def connection_lost(self, exc):
    """Notify the client once; later calls find transport unset and do nothing."""
    if self.transport:
      log('%d (connection lost)' % self.mport)
      self.client.connection_lost()
      self.transport = None

  def retry(self):
    """Close the connection; connection_lost() then informs the client."""
    if self.transport:
      self.transport.close()

  def fatal(self):
    """Shut the session down for good and propagate to the client."""
    if self.transport:
      self.transport.close()
      # Forget the transport so that the connection_lost() callback that
      # follows close() does not ask the client to reconnect.
      self.transport = None
    self.client.protocol = None
    self.client.fatal()

#
# Basic RTSP Client class
#

class RTSP_Client:
  """One configured multicast stream: owns the connection to SERVER.

  A fresh RTSP_Protocol is created for every connection; whenever the
  connection is lost (or cannot be established) a new attempt is
  scheduled 30 seconds later, until fatal() marks the client dead.
  """

  def __init__(self, clients, params):
    """Register with the ``clients`` collection and start connecting.

    clients -- owner; must provide ``loop`` and ``remove_client()``.
    params  -- raw configuration entry ('<mport>:<satip-params>').
    """
    self.clients = clients
    self.loop = clients.loop
    self.protocol = None
    self.params = params
    self.dead = False
    self.deleteme = False
    log('New multicast client: %s' % (params))
    self.create_connection()

  def create_connection(self):
    """Start an asynchronous connection attempt to SERVER."""
    if self.dead:
      # A reconnect timer may fire after the client has been removed
      return
    host, _, port = SERVER.partition(':')
    coro = self.loop.create_connection(lambda: RTSP_Protocol(self, self.params),
                                       host, int(port) if port else 554)
    # asyncio.async() no longer exists ('async' is a reserved word)
    task = self.loop.create_task(coro)
    task.add_done_callback(self._connect_done)

  def _connect_done(self, task):
    """Log a failed connection attempt and schedule the usual retry."""
    if task.cancelled():
      return
    exc = task.exception()
    if exc is not None:
      log('Connection to %s failed: %s' % (SERVER, exc))
      self.connection_lost()

  def connection_lost(self):
    """Forget the protocol and, unless dead, reconnect in 30 seconds."""
    self.protocol = None
    if self.dead:
      return
    self.loop.call_later(30, self.create_connection)

  def fatal(self):
    """Permanently stop this client and remove it from the collection."""
    if self.dead:
      # fatal() is re-entered via protocol.fatal() and remove_client()
      return
    log("Remove multicast client: %s" % self.params)
    self.dead = True
    if self.protocol:
      self.protocol.fatal()
    self.clients.remove_client(self)

#
# RTSP clients
#

class RTSP_Clients:
  """The set of active RTSP_Client objects, kept in sync with the config file."""

  def __init__(self, loop):
    """Load the configuration and reload it whenever SIGHUP is received."""
    self.loop = loop
    self.clients = []
    self.parse_config_file()
    loop.add_signal_handler(signal.SIGHUP, self.parse_config_file)

  def remove_client(self, client):
    """Drop ``client`` from the list (if present) and shut it down."""
    if client in self.clients:
      self.clients.remove(client)
      client.fatal()

  def add_client(self, params):
    """Keep the existing client for ``params`` or create a new one."""
    for c in self.clients:
      if c.params == params:
        c.deleteme = False
        return
    self.clients.append(RTSP_Client(self, params))

  def parse_config_file(self, path=None):
    """Re-read the configuration and reconcile the client list with it.

    path -- file to read; defaults to the module-level CONFIG.

    Blank lines and lines starting with '#' or ';' are ignored; a line
    ending in a backslash is joined with the following line.  Clients
    whose entry is no longer present are removed.  An unreadable file is
    treated as empty.
    """
    # mark clients
    for c in self.clients:
      c.deleteme = True
    # parse config
    if path is None:
      path = CONFIG
    try:
      with open(path) as f:
        lines = f.readlines()
    except (OSError, UnicodeError) as e:
      log('Unable to read %s: %s' % (path, e))
      lines = []
    pending = ''
    for raw in lines:
      entry = pending + raw.strip()
      pending = ''
      if not entry or entry.startswith(('#', ';')):
        continue
      if entry.endswith('\\'):
        pending = entry[:-1].rstrip()
      else:
        self.add_client(entry)
    # remove marked clients; iterate over a copy because remove_client()
    # mutates self.clients and would otherwise make the loop skip entries
    for c in list(self.clients):
      if c.deleteme:
        self.remove_client(c)

#
# Main program
#

if SYSLOG:
  syslog.openlog('multicast-rtp', 0, syslog.LOG_LOCAL7)
log('Version %s' % VERSION)
# Create the loop explicitly: asyncio.get_event_loop() is deprecated (and on
# recent Python versions fails) when no loop is running or has been set.
# The name stays at module level because code in this file may look up
# the global `loop`.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
  clients = RTSP_Clients(loop)
  loop.run_forever()
finally:
  # Runs on any exit path, so the loop and the syslog handle are always released
  loop.close()
  if SYSLOG:
    syslog.closelog()