#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Copyright(C) 2007 INL
Written by Romain Bignon <romain AT inl.fr>

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.

$Id: table.py 12148 2008-01-11 17:12:26Z romain $
"""

from inl import TableBase
from cStringIO import StringIO
import struct

from twisted.internet.defer import gatherResults


class PacketTable(TableBase):
    """ Table listing logged packets, fed by the "select_packets" query.

        Default arguments: sort on 'timestamp', 'DESC' order, 30 rows
        starting at offset 0, 'tiny' layout disabled.
    """

    def __init__(self, table):
        TableBase.__init__(self, table,
                           ['id', 'username', 'ip_saddr', 'ip_daddr', 'proto',
                            'SPort', 'DPort', 'timestamp', 'oob_prefix'])

        self.args = {'sortby': 'timestamp',
                     'sort':   'DESC',
                     'limit':  30,
                     'start':  0,
                     'tiny':   False}

    def _has_user_filter(self):
        """ Tell whether a 'userlike' or a 'user_id' filter is currently set.
            @return [bool]
        """
        return 'userlike' in self.filters or 'user_id' in self.filters

    def entry_form(self, entry):
        """ Build the displayed row from a raw one; IPs are converted with ip2str().
            @param entry [tuple] (id, username, user_id, ip_saddr, ip_daddr, protocol,
                                  SPort, DPort, timestamp, oob_prefix, state)
            @return [tuple]
        """

        # The first cell is tagged 'user' when the row carries a username,
        # 'host' otherwise.
        if entry[1]:
            _id = 'user'
        else:
            _id = 'host'
        result = ((_id, entry[0]),)

        # The username cell is only left out in tiny mode with a user filter,
        # which is exactly when make_tiny() removes the 'username' column.
        if not self.args['tiny'] or not self._has_user_filter():
            result += ((entry[1], entry[2]),)

        result += (self.ip2str(entry[3]),)
        result += (self.ip2str(entry[4]),)

        # Tiny mode stops after DPort; the full layout also carries
        # timestamp and oob_prefix.
        if self.args['tiny']:
            result += entry[5:8]
        else:
            result += entry[5:10]

        # The state is not part of the returned row: it is remembered
        # per packet id in self.states, created on first use.
        if not hasattr(self, 'states'):
            self.states = dict()
        self.states[entry[0]] = entry[10]
        return result

    def make_tiny(self):
        """ Remove the columns that are not shown in the 'tiny' layout. """

        if self._has_user_filter():
            self._remove_column('username')
        self._remove_column('timestamp')
        self._remove_column('oob_prefix')

    def _arg_where_userlike(self, args, key, value):
        """ Handle the 'userlike' filter as a LIKE filter on the 'username' field.
            @param args  passed through to _arg_where_like()
            @param key   filter name (unused here)
            @param value passed through to _arg_where_like()
            @return whatever _arg_where_like() returns
        """
        return self._arg_where_like(args, 'username', value)

    def check_args(self, args, where):
        """ Run the arguments through the TableBase helpers and register the filters.

            @param args [dict] arguments received by __call__()
            @param where [file-like] buffer handed to _arg_where() for the WHERE clause

            Arguments looked up in args:
            @param start [integer] first entry
            @param limit [integer] number of results
            @param sortby [string] field to sort by in ('id', 'ip_saddr', 'ip_daddr', 'protocol',
                                                        'SPort', 'DPort', 'timestamp', 'oob_prefix')
            @param sort [string] ASC or DESC
            @param tiny [boolean] use the reduced layout (see make_tiny())
            @param ip_saddr use a filter on this field
            @param ip_daddr use a filter on this field
            @param ip_addr filter handled by _arg_where_ip_both()
            @param user_id use a filter on this field
            @param userlike filter on 'username' (see _arg_where_userlike())
            @param sport filter handled by _arg_where_port()
            @param dport filter handled by _arg_where_port()
            @param state filter handled by _arg_where_state()
            @param proto filter handled by _arg_where_proto()
            @param client_app registered without a specific handler
            @param begin filter handled by _arg_where_begin_time()
            @param end filter handled by _arg_where_end_time()
        """

        self._arg_int(args, 'limit')
        self._arg_int(args, 'start')
        self._arg_in (args, 'sortby', ('id', 'ip_saddr', 'ip_daddr', 'protocol', 'SPort', 'DPort', 'timestamp', 'oob_prefix'))
        self._arg_in (args, 'sort',   ('DESC', 'ASC'))
        self._arg_where(args, where, {'ip_saddr':   self._arg_where_ip,
                                      'ip_daddr':   self._arg_where_ip,
                                      'ip_addr':    self._arg_where_ip_both,
                                      'user_id':    self._arg_where_int,
                                      'userlike':   self._arg_where_userlike,
                                      'sport':      self._arg_where_port,
                                      'dport':      self._arg_where_port,
                                      'state':      self._arg_where_state,
                                      'proto':      self._arg_where_proto,
                                      'client_app': None,
                                      'begin':      self._arg_where_begin_time,
                                      'end':        self._arg_where_end_time
                                      })

        self._arg_bool(args, 'tiny')

        # Must come after _arg_where(): make_tiny() looks at self.filters.
        if self.args['tiny']:
            self.make_tiny()

    def __call__(self, **args):
        """ Query the packets matching the given arguments (see check_args()).
            @return the object returned by _sql_query(), with _print_result
                    attached as callback
        """
        where = StringIO()

        self.check_args(args, where)

        # This is an optimization to only find on ids >= MAX(id)-limit-start
        # We can think this is slower to do a second request, but apparently no...
        # NOTE(review): this keeps only the highest ids, which matches the
        # default timestamp/DESC ordering; with another sortby/sort and no
        # filter the rows returned look restricted to the newest ids -- confirm
        # this is intended.
        if not where.getvalue():
            where.write('WHERE id >= (SELECT MAX(id)-%d-%d FROM ulog)' % (self.args['limit'], self.args['start']))

        result = self._sql_query("select_packets", where.getvalue())

        result.addCallback(self._print_result)
        return result

class ConnTrackTable(PacketTable):
    """ Packet table whose query is run with conntrack=True.

        Unless the caller supplies a 'state' argument, state 4 is used.
    """

    def __init__(self, table):
        PacketTable.__init__(self, table)

    def __call__(self, **args):
        """ Query the tracked connections matching the given arguments.
            Accepts the arguments documented on PacketTable.check_args().
            @return the object returned by _sql_query(), with _print_result
                    attached as callback
        """

        where = StringIO()

        # Default state filter when the caller did not choose one.
        args.setdefault('state', 4)

        self.check_args(args, where)

        # 'begin' and 'end' are dropped from the stored arguments once
        # check_args() has run.
        for key in ('begin', 'end'):
            self.args.pop(key, None)

        result = self._sql_query("select_packets", where.getvalue(), conntrack=True)

        result.addCallback(self._print_result)
        return result


class ConUserTable(TableBase):
    """ Table of user connections, fed by the "select_conusers" query.

        Default arguments: sort on 'start_time', 'DESC' order, 30 rows
        starting at offset 0, 'currents' disabled.
    """

    def __init__(self, table):
        columns = ['username', 'ip_saddr', 'os_sysname', 'start_time', 'end_time']
        TableBase.__init__(self, table, columns)

        self.args = {
            'sortby':   'start_time',
            'sort':     'DESC',
            'limit':    30,
            'start':    0,
            'currents': False,
        }

    def entry_form(self, entry):
        """ Build the displayed row from a raw one; the IP is converted with ip2str().
            @param entry [tuple] (username, user_id, ip_saddr, os_sysname,
                                  start_time[, end_time])
            @return [tuple]
        """

        ip = entry[2]

        # On *usersstats* table, the IP byte order is not the same as in the
        # other tables, so a numeric IPv4 value gets its four bytes swapped.
        if self.database.ip_type == 4:
            numeric = (isinstance(ip, int) or isinstance(ip, long)
                       or (isinstance(ip, str) and ip.isdigit()))
            if numeric:
                ip = struct.unpack("<I", struct.pack(">I", int(ip)))[0]

        row = ((entry[0], entry[1]), self.ip2str(ip)) + entry[3:5]

        # end_time is only shown when the table is not limited to
        # current connections.
        if not self.args['currents']:
            row += (entry[5],)

        return row

    def __call__(self, **args):
        """
            @param start [integer] first entry
            @param limit [integer] number of results
            @param sortby [string] field to sort by in ('username', 'ip_saddr', 'os_sysname',
                                                        'start_time', 'end_time')
            @param sort [string] ASC or DESC
            @param currents [boolean] only keep rows whose end_time is NULL
            @param ip_saddr [string] use a filter on this field
            @param user_id [integer] use a filter on this field
            @param os_sysname [string] filter on os name.
            @return the object returned by _sql_query(), with _print_result
                    attached as callback
        """
        clause = StringIO()

        self._arg_int(args, 'limit')
        self._arg_int(args, 'start')
        self._arg_in (args, 'sortby', ('username', 'ip_saddr', 'os_sysname', 'start_time', 'end_time'))
        self._arg_in (args, 'sort',   ('DESC', 'ASC'))
        self._arg_bool(args, 'currents')

        # Current connections have no end time, so that column is removed too.
        if self.args['currents']:
            clause.write('WHERE end_time IS NULL')
            self._remove_column('end_time')

        filters = {'ip_saddr':   self._arg_where_REVERSEip,
                   'user_id':    self._arg_where_int,
                   'os_sysname': None}
        self._arg_where(args, clause, filters)

        query = self._sql_query("select_conusers", clause.getvalue())
        query.addCallback(self._print_result)
        return query

class PortTable(TableBase):
    """ Per destination port table for one protocol.

        Runs the "select_ports" query plus the "count_drop_port" one.
        Default arguments: sort on 'end', 'DESC' order, 10 rows from offset 0.
    """

    def __init__(self, table, protocol):
        # self.proto is set first: it names the port column given to TableBase.
        self.proto = protocol
        port_column = self.proto + '_dport'
        TableBase.__init__(self, table, [port_column, 'packets', 'begin', 'end'])

        self.args = {
            'start':  0,
            'limit':  10,
            'sortby': 'end',
            'sort':   'DESC',
            'proto':  protocol,
        }

    def __call__(self, **args):
        """
            @param limit [integer] Number of entry returned
            @param start [integer] First entry number
            @param sortby [string] Sort by this field ('<proto>_dport', 'packets', 'begin', 'end')
            @param sort [string] Sort order ('DESC', 'ASC')
            @param ip_saddr [string] filter on source ip
            @param ip_daddr [string] filter on destination ip
            @param ip_addr [string] filter handled by _arg_where_ip_both()
            @param user_id [integer] filter on user_id
            @return the result of gatherResults() over the rows query and the count query
        """

        clause = StringIO()
        sortable = (self.proto + '_dport', 'packets', 'begin', 'end')

        self._arg_int(args, 'limit')
        self._arg_int(args, 'start')
        self._arg_in (args, 'sortby', sortable)
        self._arg_in (args, 'sort',   ('DESC', 'ASC'))

        filters = {
            'ip_saddr': self._arg_where_ip,
            'ip_daddr': self._arg_where_ip,
            'ip_addr':  self._arg_where_ip_both,
            'user_id':  self._arg_where_int,
        }
        self._arg_where(args, clause, filters)

        # Rows first, then the total, both with the same WHERE clause.
        rows = self._sql_query("select_ports", self.proto, clause.getvalue())
        rows.addCallback(self._print_result)

        total = self._sql_query("count_drop_port", self.proto, clause.getvalue(), display=False)
        total.addCallback(self._save_count)

        return gatherResults([rows, total])

class TCPTable(PortTable):
    """ PortTable bound to the 'tcp' protocol. """

    def __init__(self, table):
        PortTable.__init__(self, table, protocol='tcp')

class UDPTable(PortTable):
    """ PortTable bound to the 'udp' protocol. """

    def __init__(self, table):
        PortTable.__init__(self, table, protocol='udp')

class IpTable(TableBase):
    """ Per IP address table, for one direction ('s' or 'd' in the subclasses).

        Runs the "select_ip" query plus the "count_drop_ip" one.
        Default arguments: sort on 'end', 'DESC' order, 10 rows from offset 0.
    """

    def __init__(self, table, direction):
        ip_column = 'ip_%saddr' % direction
        TableBase.__init__(self, table, [ip_column, 'packets', 'begin', 'end'])
        self.direction = direction

        self.args = {
            'start':  0,
            'limit':  10,
            'sortby': 'end',
            'sort':   'DESC',
        }

    def entry_form(self, entry):
        """ Convert the leading IP of a row with ip2str().
            @param entry [tuple] (ip, packets, begin, end)
            @return [tuple]
        """

        return (self.ip2str(entry[0]),) + entry[1:]

    def __call__(self, **args):
        """
            @param start [integer] First entry number
            @param limit [integer] Number of entry returned
            @param sortby [string] Field used to order table, in
                                   ['ip_<direction>addr', 'packets', 'begin', 'end']
            @param sort [string] Sort in a ascendant or a descendant order ['ASC', 'DESC']
            @param ip_saddr [string] filter on source ip
            @param ip_daddr [string] filter on destination ip
            @param user_id [integer] filter on user_id
            @param dport [integer] filter on destination port
            @param sport [integer] filter on source port
            @param proto [string] filter on protocol.
            @return the result of gatherResults() over the rows query and the count query
        """

        clause = StringIO()
        sortable = ('ip_%saddr' % self.direction, 'packets', 'begin', 'end')

        self._arg_int(args, 'limit')
        self._arg_int(args, 'start')
        self._arg_in (args, 'sortby', sortable)
        self._arg_in (args, 'sort',   ('DESC', 'ASC'))

        filters = {
            'ip_saddr': self._arg_where_ip,
            'ip_daddr': self._arg_where_ip,
            'user_id':  self._arg_where_int,
            'dport':    self._arg_where_port,
            'sport':    self._arg_where_port,
            'proto':    self._arg_where_proto,
        }
        self._arg_where(args, clause, filters)

        # Rows first, then the total, both with the same WHERE clause.
        rows = self._sql_query("select_ip", self.direction, clause.getvalue())
        rows.addCallback(self._print_result)

        total = self._sql_query("count_drop_ip", self.direction, clause.getvalue(), display=False)
        total.addCallback(self._save_count)

        return gatherResults([rows, total])

class IPsrcTable(IpTable):
    """ IpTable on the 'ip_saddr' column (direction 's'). """

    def __init__(self, table):
        IpTable.__init__(self, table, direction='s')

class IPdstTable(IpTable):
    """ IpTable on the 'ip_daddr' column (direction 'd'). """

    def __init__(self, table):
        IpTable.__init__(self, table, direction='d')

class UserTable(TableBase):
    """ Per user table.

        Runs the "select_user" query plus the "count_drop_user" one.
        Default arguments: sort on 'end', 'DESC' order, 10 rows from offset 0.
    """

    def __init__(self, table):
        TableBase.__init__(self, table, ['username', 'packets', 'begin', 'end'])

        self.args = {
            'start':  0,
            'limit':  10,
            'sortby': 'end',
            'sort':   'DESC',
        }

    def entry_form(self, entry):
        """ Group username and user_id into a single leading cell.
            @param entry [tuple] (username, user_id, packets, begin, end)
            @return [tuple]
        """

        return ((entry[0], entry[1]),) + entry[2:]

    def __call__(self, **args):
        """
            @param start [integer] First entry number
            @param limit [integer] Number of entry returned
            @param sortby [string] Field used to order table, in ['username', 'packets', 'begin', 'end']
            @param sort [string] Sort in a ascendant or a descendant order ['ASC', 'DESC']
            @param ip_saddr [string] filter on source ip
            @param ip_daddr [string] filter on destination ip
            @param dport [integer] filter on destination port
            @param sport [integer] filter on source port
            @param proto [string] filter on protocol.
            @return the result of gatherResults() over the rows query and the count query
        """

        self._arg_int(args, 'limit')
        self._arg_int(args, 'start')
        self._arg_in (args, 'sortby', ('username', 'packets', 'begin', 'end'))
        self._arg_in (args, 'sort',   ('DESC', 'ASC'))

        clause = StringIO()
        filters = {
            'ip_saddr': self._arg_where_ip,
            'ip_daddr': self._arg_where_ip,
            'dport':    self._arg_where_port,
            'sport':    self._arg_where_port,
            'proto':    self._arg_where_proto,
        }
        self._arg_where(args, clause, filters)

        # Both queries are started before any callback is attached.
        rows = self._sql_query("select_user", clause.getvalue())
        total = self._sql_query("count_drop_user", clause.getvalue(), display=False)
        total.addCallback(self._save_count)
        rows.addCallback(self._print_result)

        return gatherResults([rows, total])

class AppTable(TableBase):
    """ Per client application table.

        Runs the "select_apps" query plus the "count_drop_apps" one.
        Default arguments: sort on 'end', 'DESC' order, 10 rows from offset 0.
    """

    def __init__(self, table):
        TableBase.__init__(self, table, ['client_app', 'packets', 'begin', 'end'])

        self.args = {
            'start':  0,
            'limit':  10,
            'sortby': 'end',
            'sort':   'DESC',
        }

    def __call__(self, **args):
        """
            @param start [integer] First entry number
            @param limit [integer] Number of entry returned
            @param sortby [string] Field used to order table, in ['client_app', 'packets', 'begin', 'end']
            @param sort [string] Sort in a ascendant or a descendant order ['ASC', 'DESC']
            @param ip_saddr [string] filter on source ip
            @param ip_daddr [string] filter on destination ip
            @param user_id [integer] filter on user_id
            @param dport [integer] filter on destination port
            @param sport [integer] filter on source port
            @param proto [string] filter on protocol.
            @return the result of gatherResults() over the rows query and the count query
        """

        clause = StringIO()

        self._arg_int(args, 'limit')
        self._arg_int(args, 'start')
        self._arg_in (args, 'sortby', ('client_app', 'packets', 'begin', 'end'))
        self._arg_in (args, 'sort',   ('DESC', 'ASC'))

        filters = {
            'ip_saddr': self._arg_where_ip,
            'ip_daddr': self._arg_where_ip,
            'user_id':  self._arg_where_int,
            'dport':    self._arg_where_port,
            'sport':    self._arg_where_port,
            'proto':    self._arg_where_proto,
        }
        self._arg_where(args, clause, filters)

        # Rows first, then the total, both with the same WHERE clause.
        rows = self._sql_query("select_apps", clause.getvalue())
        rows.addCallback(self._print_result)

        total = self._sql_query("count_drop_apps", clause.getvalue(), display=False)
        total.addCallback(self._save_count)

        return gatherResults([rows, total])

class BadHosts(TableBase):
    """ Table of source IPs with their rate, fed by the "select_badhosts" query.

        Default arguments: sort on 'rate', 'DESC' order, 5 rows from offset 0.
    """

    def __init__(self, table):
        TableBase.__init__(self, table, ['ip_saddr', 'rate'])

        self.args = {
            'start':  0,
            'limit':  5,
            'sortby': 'rate',
            'sort':   'DESC',
        }

    def entry_form(self, entry):
        """ Convert the leading IP of a row with ip2str().
            @param entry [tuple] IP first, remaining cells kept as they are
            @return [tuple]
        """

        return (self.ip2str(entry[0]),) + entry[1:]

    def __call__(self, **args):
        """
            @param start [integer] First entry number
            @param limit [integer] Number of entry returned
            @param sortby [string] Field used to order table, in ['ip_saddr', 'rate']
            @param sort [string] Sort order ['ASC', 'DESC']
            @return what addCallback() returns on the "select_badhosts" query object
        """

        self._arg_int(args, 'limit')
        self._arg_int(args, 'start')
        self._arg_in (args, 'sortby', ('ip_saddr', 'rate'))
        self._arg_in (args, 'sort',   ('DESC', 'ASC'))

        query = self._sql_query("select_badhosts")
        return query.addCallback(self._print_result)

class BadUsers(TableBase):
    """ Table of usernames with their rate, fed by the "select_badusers" query.

        Default arguments: sort on 'rate', 'DESC' order, 5 rows from offset 0.
    """

    def __init__(self, table):
        TableBase.__init__(self, table, ['username', 'rate'])

        self.args = {
            'start':  0,
            'limit':  5,
            'sortby': 'rate',
            'sort':   'DESC',
        }

    def entry_form(self, entry):
        """ Group the two leading cells, swapped, into a single cell.
            @param entry [tuple] (user_id, username, ...) remaining cells kept as they are
            @return [tuple] ((username, user_id), ...)
        """

        return ((entry[1], entry[0]),) + entry[2:]

    def __call__(self, **args):
        """
            @param start [integer] First entry number
            @param limit [integer] Number of entry returned
            @param sortby [string] Field used to order table, in ['username', 'rate']
            @param sort [string] Sort order ['ASC', 'DESC']
            @return what addCallback() returns on the "select_badusers" query object
        """

        self._arg_int(args, 'limit')
        self._arg_int(args, 'start')
        self._arg_in (args, 'sortby', ('username', 'rate'))
        self._arg_in (args, 'sort',   ('DESC', 'ASC'))

        query = self._sql_query("select_badusers")
        return query.addCallback(self._print_result)
