#!/usr/bin/python3
import logging

import tornado.escape, tornado.ioloop, tornado.options, tornado.web, tornado.websocket
from tornado.options import define, options

# Command-line options; they are parsed by tornado.options.parse_command_line() in main().
define(name="port", type=int, default=1234, help="run on the given port")
define(name="secret", type=str, default="channel", help="upstream secret token")


class Application(tornado.web.Application):
    """Tornado application routing uploads to StreamHandler and viewers to SocketHandler."""

    def __init__(self):
        # 504857600 bytes (~481 MiB).
        # NOTE(review): Tornado reads max_body_size / max_buffer_size from
        # HTTPServer / listen() keyword arguments, not from Application
        # settings, so storing them here alone does not change the limits.
        size_limit = 504857600
        handlers = [
            (r"/upload/(.*)", StreamHandler),
            # Escape the dot so only the literal path /live.ts matches
            # (an unescaped "." would also accept e.g. /liveXts).
            (r"/live\.ts", SocketHandler),
        ]
        settings = dict(max_body_size=size_limit, max_buffer_size=size_limit)
        super().__init__(handlers, **settings)

@tornado.web.stream_request_body
class StreamHandler(tornado.web.RequestHandler):
    """Receive a streamed upload and relay each body chunk to WebSocket viewers.

    Mounted at ``/upload/(.*)``. Because of ``stream_request_body`` the body is
    not buffered: every chunk is handed to ``data_received`` as it arrives.
    """

    def prepare(self):
        """Reject the upload before any body data is read if the secret is wrong.

        When the ``secret`` option is empty, every upload path is accepted.

        Raises:
            tornado.web.HTTPError: 403 when the request path does not end in
                the configured secret.
        """
        if options.secret and self.request.path != '/upload/' + options.secret:
            logging.info('Failed Stream Connection: %s - wrong secret.', self.request.remote_ip)
            # Raising here sends a real 403 and stops Tornado from reading the body.
            raise tornado.web.HTTPError(403)

    def data_received(self, data):
        """Forward one chunk of the request body to all connected viewers."""
        SocketHandler.broadcast(data)

    def post(self):
        """Complete the request normally once the uploader finishes the body.

        Tornado calls the HTTP method only after the whole streamed body has
        been received; without it the finished upload would get a 405.
        """

    # Uploaders may stream with either POST or PUT.
    put = post


class SocketHandler(tornado.websocket.WebSocketHandler):
    """WebSocket endpoint that receives the relayed upload stream.

    Only one viewer is kept at a time: opening a new connection closes every
    previously registered viewer.
    """

    # Registered viewer connections, shared by all instances of this class.
    waiters = set()
    # Per-connection flag: set True in open(), False once the socket goes away.
    connected = False

    def check_origin(self, origin):
        # Accept connections from any origin (cross-origin players allowed).
        return True

    def open(self):
        """Register this connection after evicting all previously connected viewers."""
        self.connected = True

        # Iterate over a snapshot: evicting a waiter mutates the shared set,
        # and mutating a set while iterating it raises RuntimeError.
        for waiter in list(SocketHandler.waiters):
            waiter.connected = False
            SocketHandler.waiters.discard(waiter)
            try:
                # Tornado invokes waiter.on_close() itself once the close completes.
                waiter.close()
            except Exception:
                logging.exception('Error closing WebSocket %s', waiter.request.remote_ip)

        SocketHandler.waiters.add(self)
        logging.info('New WebSocket Connection: %d total', len(SocketHandler.waiters))

    def select_subprotocol(self, subprotocol):
        """Accept the first subprotocol offered by the client, if any."""
        if subprotocol:
            return subprotocol[0]
        return super().select_subprotocol(subprotocol)

    def on_message(self, message):
        # Messages sent by viewers are ignored; this endpoint is output-only.
        pass

    def on_close(self):
        """Unregister this connection; safe to call when already unregistered."""
        self.connected = False
        SocketHandler.waiters.discard(self)
        logging.info('Disconnected WebSocket (%d total)', len(SocketHandler.waiters))

    def on_connection_close(self):
        """Mark this connection dead, then let Tornado finish its teardown.

        The base implementation tears down the protocol object and calls
        on_close(), which unregisters this waiter; skipping it would leave
        that cleanup undone.
        """
        self.connected = False
        super().on_connection_close()

    @classmethod
    def broadcast(cls, data):
        """Send *data* as a binary frame to every connected waiter.

        A failure on one connection is logged and does not stop delivery to
        the remaining waiters.
        """
        # Snapshot so that any close triggered while writing cannot
        # invalidate the iteration.
        for waiter in list(cls.waiters):
            if not waiter.connected:
                continue
            try:
                waiter.write_message(data, binary=True)
            except Exception as e:
                logging.error('Error broadcasting to WebSocket %s: %s', waiter.request.remote_ip, e)
                # pass

def main():
    """Parse command-line options, start the HTTP server and run the IOLoop forever."""
    tornado.options.parse_command_line()

    app = Application()
    # The size limits are stored in the Application settings, but Tornado
    # applies them only when they are passed to the HTTPServer, so forward
    # them here. A missing key yields None, which keeps Tornado's default.
    app.listen(
        options.port,
        max_body_size=app.settings.get("max_body_size"),
        max_buffer_size=app.settings.get("max_buffer_size"),
    )

    tornado.ioloop.IOLoop.current().start()


# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
