#!

"""Ethernet tunneled over GRE.

"""

import select
import socket

from .common import *
from . import logging
from . import datalink
from . import host

SvnFileRev = "$LastChangedRevision: 582 $"

# First two bytes of the GRE header (flags and version) as this module
# sends them and as it requires them on receive: all zero.
greflags = b"\x00\x00"

class GREPort (datalink.BcPort):
    """Port on a GRE-encapsulated DEC Ethernet datalink.

    Each port owns a reusable transmit buffer whose first four bytes
    are the GRE header: the flags/version word followed by the port's
    protocol type.
    """
    def __init__ (self, datalink, owner, proto, pad = True):
        """Set up the port and its preformatted transmit buffer.

        pad selects whether a two-byte little-endian payload length is
        placed between the GRE header and the payload on transmit.
        """
        super ().__init__ (datalink, owner, proto)
        self.pad = pad
        # 4 bytes of GRE header plus room for 1500 bytes after it.
        self.frame = bytearray (1504)
        self.frame[0:2] = greflags
        self.frame[2:4] = self.proto

    def set_promiscuous (self, promisc = True):
        """Always fails: promiscuous mode is not available on GRE."""
        raise RuntimeError ("GRE does not support promiscuous mode")

    def send (self, msg, dest):
        """Send an "Ethernet" frame to the specified address.  GRE is
        point to point, so dest is not used.

        Raises ValueError if msg does not fit in the transmit buffer
        (more than 1498 bytes with the length field, 1500 without).
        """
        size = len (msg)
        if logging.tracing:
            logging.tracepkt ("Sending packet on {}",
                              self.parent.name, pkt = msg)
        buf = self.frame
        padded = self.pad
        # Payload offset and size limit depend on whether the two-byte
        # length field sits between the GRE header and the payload.
        if padded:
            start, limit = 6, 1498
        else:
            start, limit = 4, 1500
        if size > limit:
            raise ValueError ("Ethernet packet too long")
        if padded:
            # Payload length, low byte first.
            buf[4:6] = size.to_bytes (2, "little")
        end = start + size
        buf[start:end] = msg
        self.counters.bytes_sent += end
        self.counters.pkts_sent += 1
        # No padding to a minimum frame size is done: GRE isn't real
        # Ethernet and doesn't have minimum frame lengths.
        self.parent.send_frame (memoryview (buf)[:end])

# GRE protocol id.  It is an IP protocol number rather than a port, but
# this module hands it to the host address helpers and the raw socket
# in the position where a UDP port would go.
GREPROTO = 47
class GRE (datalink.BcDatalink, StopThread):
    """DEC Ethernet datalink tunneled over GRE encapsulation.

    The --device parameter is required.  Its value is the remote host
    address or name.  The GRE protocol id (47) is assumed and hardwired.

    Frames are sent and received on a raw socket.  A receive thread
    (see run) validates incoming frames and queues their payload to
    the node as Received work items.
    """
    port_class = GREPort
    use_mop = False    # True if we want MOP to run on this type of datalink
    
    def __init__ (self, owner, name, config):
        """Initialize the datalink from its configuration.

        The peer is taken from config.device, or config.destination if
        device is not set.  The socket is not created here; see open.

        Raises ValueError if both IPv4 and IPv6 are requested, or if
        the configured source address cannot be listened on.
        """
        tname = "{}.{}".format (owner.node.nodename, name)
        StopThread.__init__ (self, name = tname)
        datalink.BcDatalink.__init__ (self, owner, name, config)
        if config.ipv4 and config.ipv6:
            raise ValueError ("GRE does not support dual IPv4/v6 mode")
        dest = config.device or config.destination
        # GREPROTO is an IP protocol, not a socket, but in the socket
        # calls it's passed along in the second element of an address
        # tuple just like a UDP address/port pair is, so we'll handle
        # it that way.
        self.source = host.SourceAddress (config, GREPROTO)
        if not self.source.can_listen:
            raise ValueError ("Source port must be specified")
        self.host = host.HostAddress (dest, GREPROTO, self.source)
        self.socket = None
        
    def open (self):
        """Create the raw socket and start the receive thread."""
        # Note that we do not set the HDRINCL option, so transmitted
        # packets have their IP header generated by the kernel.  (But
        # received packets appear with an IP header on the front, what
        # fun...)
        self.socket = self.host.create_raw (self.source, GREPROTO)
        # Only IPv4 delivers the IP header with received packets.
        self.skipIpHdr = self.host.listen_family == socket.AF_INET
        self.start ()
        
    def close (self):
        """Stop the receive thread and close the socket, if any."""
        self.stop ()
        if self.socket:
            self.socket.close ()
        self.socket = None
        
    def create_port (self, owner, proto, pad = True):
        """Create a port for protocol type proto, forwarding pad to the
        base class create_port.
        """
        return super ().create_port (owner, proto, pad)

    def send_frame (self, buf):
        """Send an GRE-encapsulated Ethernet frame.  Ignore any errors,
        because that's the DECnet way.
        """
        try:
            self.socket.sendto (buf, self.host.sockaddr)
        except (AttributeError, IOError, TypeError):
            # Deliberately best effort; this also covers the socket
            # being None because the datalink is closed.
            pass
        
    def run (self):
        """Receive thread: poll the socket and dispatch received frames.

        Each packet is checked for sender, GRE flags and protocol type;
        anything unexpected or malformed is dropped.  The payload of a
        good frame is queued to the node as a Received work item for
        the owner of the matching port.  The loop ends on a poll error,
        on a socket error/hangup event, or when stopnow is set.
        """
        logging.trace ("GRE datalink {} receive thread started", self.name)
        sock = self.socket
        if not sock:
            return
        # The largest GRE frame is 1504 bytes (4 byte header plus 1500).
        # With IPv4 the IP header, up to 60 bytes, is delivered in
        # front of it, and recvfrom drops whatever does not fit in the
        # buffer, so allow room for that header as well.
        bufsize = 1504
        if self.skipIpHdr:
            bufsize += 60
        p = select.poll ()
        p.register (sock, datalink.REGPOLLIN)
        while True:
            try:
                pl = p.poll (datalink.POLLTS)
            except select.error:
                logging.trace ("Poll error", exc_info = True)
                return
            if self.stopnow:
                logging.trace ("Exiting due to stopnow")
                return
            if not pl:
                continue
            fn, mask = pl[0]
            if mask & datalink.POLLERRHUP:
                return
            if mask & select.POLLIN:
                # Receive a packet
                try:
                    msg, addr = sock.recvfrom (bufsize)
                except (AttributeError, OSError, socket.error):
                    msg = None
                if not msg or len (msg) <= 4:
                    continue
                if not self.host.valid (addr):
                    # Not from peer, ignore
                    continue
                # Skip past the IP header, if we're using IPv4.
                # Strangely enough, we don't get the header if IPv6.
                if self.skipIpHdr:
                    ver, hlen = divmod (msg[0], 16)
                    if ver == 4:
                        # IPv4, use the header length to skip past header
                        # and any options.
                        pos = 4 * hlen
                    else:
                        # Unknown IP header version
                        logging.trace ("Unknown IP header version {}", ver)
                        continue
                else:
                    pos = 0
                if logging.tracing:
                    logging.tracepkt ("Received packet on {}",
                                      self.name, pkt = msg)
                if len (msg) <= pos + 4:
                    # Nothing beyond the IP and GRE headers, ignore.
                    # This also guarantees the header slices below
                    # are complete.
                    continue
                if msg[pos:pos + 2] != greflags:
                    # Unexpected flags or version in header, ignore
                    logging.debug ("On {}, unexpected header {}",
                                   self.name, msg[pos:pos + 2])
                    continue
                proto = msg[pos + 2:pos + 4]
                try:
                    port = self.ports[proto]
                except KeyError:
                    # No protocol type match, ignore msg
                    self.counters.unk_dest += 1
                    continue
                # Number of bytes following the GRE header
                plen = len (msg) - (pos + 4)
                port.counters.bytes_recv += plen
                port.counters.pkts_recv += 1
                if port.pad:
                    # The payload starts with a two-byte little-endian
                    # length; the data proper follows it.
                    avail = plen - 2
                    if avail < 0:
                        # Too short to hold the length field at all
                        logging.debug ("On {}, msg too short for "
                                       "length field", self.name)
                        continue
                    plen2 = msg[pos + 4] + (msg[pos + 5] << 8)
                    if avail < plen2:
                        # Length field claims more data than arrived
                        logging.debug ("On {}, msg length field {} " \
                                       "inconsistent with msg length {}",
                                       self.name, plen2, avail)
                        continue
                    msg = memoryview (msg)[pos + 6:pos + 6 + plen2]
                else:
                    msg = memoryview (msg)[pos + 4:]
                self.node.addwork (Received (port.owner,
                                             src = None, packet = msg))