###############################################################################
#                                                                             #
#  D2K: A Doom Source Port for the 21st Century                               #
#                                                                             #
#  Copyright (C) 2014: See COPYRIGHT file                                     #
#                                                                             #
#  This file is part of D2K.                                                  #
#                                                                             #
#  D2K is free software: you can redistribute it and/or modify it under the   #
#  terms of the GNU General Public License as published by the Free Software  #
#  Foundation, either version 2 of the License, or (at your option) any       #
#  later version.                                                             #
#                                                                             #
#  D2K is distributed in the hope that it will be useful, but WITHOUT ANY     #
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS  #
#  FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more     #
#  details.                                                                   #
#                                                                             #
#  You should have received a copy of the GNU General Public License along    #
#  with D2K.  If not, see <http://www.gnu.org/licenses/>.                     #
#                                                                             #
###############################################################################

from zone import *
from doomstat import *
from ticcmd import *
from game import *
from net.state import *
from physics.cmd import *
from physics.user import *

import net
import enet

# Throttle acceleration/deceleration values handed to epeer.configure_throttle
# (together with an interval argument of 300) when running as a server; see
# NetPeer.from_enet_peer.  Presumably these map onto ENet's
# enet_peer_throttle_configure acceleration/deceleration -- TODO confirm.
const NET_THROTTLE_ACCEL: 2
const NET_THROTTLE_DECEL: 1

# One entry in a packet's table of contents: the offset in the message buffer
# at which a message starts (`index`) and that message's type (`type`).
type PacketTOCEntry {

  :fields {
    # Message start offset; NetChannel.begin_message records the message
    # buffer's cursor here and load_next_message seeks back to it.
    val index: uint(0)
    # Message type identifier supplied by the caller of begin_message.
    val type:  uint(0)
  }

  # Writes `index` then `type` as unsigned integers.  On a write failure,
  # strerror is called with a description of the field that failed.
  fn serialize(var mpdata: MPackBuffer) {
    iferr mpdata.write_uint(self.index) {
      strerror("Error serializing TOC entry index")
    }
    iferr mpdata.write_uint(self.type) {
      strerror("Error serializing TOC entry type")
    }
  }

  # Reads `index` then `type` (the order serialize writes them).  Both values
  # are read into locals first and only then assigned to the fields.
  fn deserialize(var mpdata: MPackBuffer) {
    var index: uint(0)
    var mtype: uint(0)

    iferr index = mpdata.read_uint() {
      strerror("Error reading message index from TOC")
    }

    iferr mtype = mpdata.read_uint() {
      strerror("Error reading message type from TOC")
    }

    self.index = index
    self.type = mtype
  }

}

# A packet's table of contents: one PacketTOCEntry per message, in the order
# the messages were written.  The type *is* the dynarray, so its elements and
# length are reached through `self` directly.
type PacketTOC (dynarray<var PacketTOCEntry>) {

  # Writes the entry count followed by every entry's (index, type) pair.
  # NOTE(review): the count is written with write_table but deserialize reads
  # it back with read_map -- confirm both denote the same container kind.
  fn serialize(var mpdata: MPackBuffer) {
    iferr mpdata.write_table(self.length) {
      strerror("Error serializing TOC entry count")
    }

    for entry in self {
      entry.serialize(mpdata!)
    }
  }

  # Replaces the current contents with the entries read from `mpdata`: first
  # the entry count, then that many serialized entries.
  fn deserialize(var mpdata: MPackBuffer) {
    var toc_size: uint(0)

    iferr toc_size = mpdata.read_map() {
      strerror("Error reading TOC size")
    }

    self.clear()
    self.ensure_capacity(toc_size)

    for var entry in self {
      entry.deserialize(mpdata)
    }
  }

  # Appends an entry for a message starting at `index` with type `mtype`.
  # Assumes ensure_capacity grows `length` to the requested size, as
  # deserialize also relies on -- TODO confirm.
  fn add_entry(val index: uint, val mtype: uint) {
    self.ensure_capacity(self.length + 1)
    self[self.length - 1].index = index
    self[self.length - 1].type = mtype
  }

}

# Outcome of NetChannel.load_next_message, a variant keyed on `success`:
# on success it carries the loaded message's type (`mtype`); on failure a
# human-readable `reason`.
type NetMessageLoadResult {

  :fields {
    success: bool(false)
  }

  :variants {
    match success {
      true {
        success: bool
        mtype: uint
      }
      false {
        success: bool
        reason: string
      }
    }
  }

}

# A buffer of messages plus the table of contents (TOC) that indexes them.
# Outgoing: begin_message() records a TOC entry and returns the message
# buffer to write into; build_outgoing_packet() packs TOC + messages into a
# single ENet packet.  Incoming: load_incoming_packet() splits a received
# packet back into TOC and messages; load_next_message() consumes TOC
# entries front to back.
type NetChannel {

  :fields {
    toc!:         *PacketTOC()
    messages!:    *MPackBuffer()
    packet_data!: *MPackBuffer()
    packed_toc!:  *MPackBuffer()
  }

  # Resets the TOC and every buffer.
  fn clear() {
    self.toc.clear()
    self.messages.clear()
    self.packet_data.clear()
    self.packed_toc.clear()
  }

  # True when the TOC holds no messages.
  pred is_empty {
    return self.toc.length == 0
  }

  # Records a TOC entry for a message of `message_type` starting at the
  # message buffer's current cursor and returns that buffer so the caller can
  # write the message body.
  fn begin_message(uint message_type) (MPackBuffer!) {
    var index: self.messages.get_cursor()

    self.toc.add_entry(index, message_type)

    return self.messages!
  }

  # Serializes the TOC into packed_toc and builds a packet (flagged with
  # `packet_flag`) holding the packed TOC followed by the message data.
  fn build_outgoing_packet(enet.PacketFlag packet_flag) (*ENetPacket!) {
    # Pre-size at 10 bytes per TOC entry (presumably a generous estimate).
    self.packed_toc.ensure_capacity(self.toc.length * 10)
    self.toc.serialize(&self.packed_toc!)

    var toc_size:  self.packed_toc.get_size()
    var msg_size:  self.messages.get_size()
    var pkt_size:  toc_size + msg_size
    var packet!:  *enet.Packet(pkt_size, packet_flag)

    packet.read_from(self.packed_toc, 0, toc_size)
    packet.read_from(self.messages, 0, msg_size)

    return packet!
  }

  # Replaces this channel's contents with the TOC and messages found in `bb`.
  # Returns false (after logging) when no message data follows the TOC.
  fn load_incoming_packet(ByteBuffer bb) (bool) {
    var local_bb!: ByteBuffer(alloc: bb.length)

    self.clear()

    local_bb.read_from(bb, 0, bb.length)

    local_bb.set_cursor(0)

    self.toc.deserialize(local_bb!)

    # The TOC was parsed out of local_bb, so local_bb's cursor (the TOC has
    # none of its own) now marks where the message data begins.
    var message_start_point: local_bb.get_cursor()

    if message_start_point >= bb.length {
      console.echo("NetChannel.load_incoming_packet: Received empty packet.")
      return false
    }

    self.messages.read_from(bb, message_start_point,
      bb.length - message_start_point)

    return true
  }

  # Pops the first TOC entry, seeks the message buffer to that message's
  # start, and returns a successful result carrying the message type.
  # Returns a failed result when the TOC is empty, or when the first entry's
  # index lies outside the message buffer (that entry stays in the TOC).
  fn load_next_message() (NetMessageLoadResult) {
    if self.toc.length == 0 {
      return *NetMessageLoadResult(false, "No messages left")
    }

    # Copy the entry's fields out first: the entry is removed from the TOC
    # below, so nothing may refer to it afterwards.
    var index: self.toc[0].index
    var mtype: self.toc[0].type

    if index >= self.messages.length {
      console.echo({
        "NetChannel.load_next_message: Invalid message index "
        "(${index} >= ${self.messages.length})."
      })
      return *NetMessageLoadResult(false,
        "Invalid message index ${index}")
    }

    self.messages.set_cursor(index)

    iferr self.toc.remove(0) {}

    return *NetMessageLoadResult(true, mtype)
  }

}

# A peer's message channels: `incoming` for received packets, and separate
# outgoing channels for reliable and unreliable messages.
type NetCom {

  :fields {
    incoming!:   *NetChannel()
    reliable!:   *NetChannel()
    unreliable!: *NetChannel()
  }

  # Starts a message on the reliable channel; returns the buffer to write it to.
  fn begin_reliable_message(uint message_type) (MPackBuffer!) {
    return self.reliable.begin_message(message_type)
  }

  # Starts a message on the unreliable channel; returns the buffer to write it to.
  fn begin_unreliable_message(uint message_type) (MPackBuffer!) {
    return self.unreliable.begin_message(message_type)
  }

  # Packs the reliable channel into a packet flagged PACKET_FLAG_RELIABLE.
  fn build_reliable_packet() (*ENetPacket!) {
    return self.reliable.build_outgoing_packet(enet.PACKET_FLAG_RELIABLE)
  }

  # Packs the unreliable channel into a packet flagged PACKET_FLAG_UNSEQUENCED.
  fn build_unreliable_packet() (*ENetPacket!) {
    return self.unreliable.build_outgoing_packet(enet.PACKET_FLAG_UNSEQUENCED)
  }

  # Loads a received packet into the incoming channel; false if it was empty.
  fn load_incoming_packet(ByteBuffer bb) (bool) {
    return self.incoming.load_incoming_packet(bb)
  }

  # Loads the next message from the incoming channel.
  fn load_next_message() (NetMessageLoadResult) {
    return self.incoming.load_next_message()
  }

}

# Per-peer synchronization state.  Field meanings are not established in
# this file: presumably `tic`/`command_index` track the last synced tic and
# command and `delta` the pending game-state delta -- TODO confirm.
type NetSync {

  :fields {
    initialized!:   bool(false)
    outdated!:      bool(false)
    tic!:           uint(0)
    command_index!: uint(0)
    delta!:         GameStateDelta()
  }

  # Clears the `initialized` and `outdated` flags only; `tic`,
  # `command_index` and `delta` keep their values.
  fn reset() {
    self.initialized = false
    self.outdated = false
  }

}

# A remote connection: its peer number, the player it controls, its address,
# connect/disconnect timestamps, message channels (`netcom`) and sync state
# (`netsync`).
type NetPeer {

  :fields {
    num:             uint(1)
    playernum:       uint(0)
    # NOTE(review): 1017 is not 127.0.0.1 in either byte order (that is
    # 2130706433 in host order, or 16777343 when network-order bytes are read
    # on a little-endian host) -- confirm the intended default.
    address:         u32(1017) # 127.0.0.1
    port:            u16(1)
    connection_id:   uint(0)
    netcom!:         NetCom()
    netsync!:        NetSync()
    connect_time:    time()
    disconnect_time: time()
  }

  # Builds peer `num` from an ENet peer: copies its host, port and connectID,
  # stamps connect_time, and (when running as a server) configures the ENet
  # peer's throttle.
  init from_enet_peer(ENetPeer epeer, uint num) {
    self.num = num
    self.address = epeer.address.host
    self.port = epeer.address.port
    self.connection_id = epeer.connectID
    self.connect_time = time.now()

    if mode.is_server {
      epeer.configure_throttle(300, NET_THROTTLE_ACCEL, NET_THROTTLE_DECEL)
    }
  }

  # The live ENet peer with this peer's connection_id.
  # NOTE(review): no value is returned when find_peer has no match -- confirm
  # how `with` handles that case.
  get enet_peer (ENetPeer) {
    with enet.find_peer(self.connection_id) as epeer {
      return epeer
    }
  }

  # Logs the removal, marks this peer's player PST_DISCONNECTED, stamps
  # disconnect_time and asks ENet to disconnect the peer.
  fn set_disconnected() {
    console.echo({
      "Removing peer ${self.num} "
      "${net.ip_to_const_string(self.address)}:${self.port}"
    })

    with game.get_player(self.playernum) as player! {
      player.playerstate = PST_DISCONNECTED
    }

    self.disconnect_time = time.now()

    with self.enet_peer as epeer {
      epeer.disconnect(0)
    }
  }

  # Returns true when the peer has timed out.  Only considered once both
  # connect_time and disconnect_time are set; then true if more than
  # NET_CONNECT_TIMEOUT * 1000 has elapsed since connect_time or more than
  # NET_DISCONNECT_TIMEOUT * 1000 since disconnect_time.
  # NOTE(review): with the early return, a peer that was never disconnected
  # cannot hit the connect timeout, and a disconnected peer trips the connect
  # arm as soon as its connection is older than NET_CONNECT_TIMEOUT, which
  # bypasses the disconnect grace period.  Checking each timestamp on its own
  # looks like the intent -- confirm whether connect_time is cleared once the
  # connection handshake completes before changing this.
  fn check_timeout() (bool) {
    var now: time.now()

    if self.connect_time == 0 or self.disconnect_time == 0 {
      return false
    }

    if timedelta(now, self.connect_time) > (NET_CONNECT_TIMEOUT * 1000) ||
       timedelta(now, self.disconnect_time) > (NET_DISCONNECT_TIMEOUT * 1000) {
      return true
    }

    return false
  }

  # Sends whatever is queued on the reliable and unreliable channels, then
  # clears each channel that was sent.
  fn flush_buffers() {
    with self.enet_peer as epeer {
      if !self.netcom.reliable.is_empty {
        epeer.send(NET_CHANNEL_RELIABLE, self.build_reliable_packet())
        self.netcom.reliable.clear()
      }
      if !self.netcom.unreliable.is_empty {
        epeer.send(NET_CHANNEL_UNRELIABLE, self.build_unreliable_packet())
        self.netcom.unreliable.clear()
      }
    }
  }

  # Starts a reliable message; returns the buffer to write it into.
  fn begin_reliable_message(uint message_type) (MPackBuffer!) {
    return self.netcom.begin_reliable_message(message_type)
  }

  # Starts an unreliable message; returns the buffer to write it into.
  fn begin_unreliable_message(uint message_type) (MPackBuffer!) {
    return self.netcom.begin_unreliable_message(message_type)
  }

  # Packs the reliable channel into a packet; NetCom chooses the packet flag.
  fn build_reliable_packet() (*ENetPacket!) {
    return self.netcom.build_reliable_packet()
  }

  # Packs the unreliable channel into a packet; NetCom chooses the packet flag.
  fn build_unreliable_packet() (*ENetPacket!) {
    return self.netcom.build_unreliable_packet()
  }

  # Loads a received packet into the incoming channel; false if it was empty.
  fn load_incoming_packet(ByteBuffer bb) (bool) {
    return self.netcom.load_incoming_packet(bb)
  }

  # Loads the next received message.
  fn load_next_message() (NetMessageLoadResult) {
    return self.netcom.load_next_message()
  }
}

# Result of a NetPeers lookup, a variant keyed on `found`: when found it
# carries a reference to the matching peer (`peer`).
type NetPeerSearchResult {
  :fields {
    found: bool(false)
  }

  :variants {
    match found {
      true {
        found:  bool
        peer:  &NetPeer
      }
      false {
        found: bool
      }
    }
  }
}

# Result of NetPeers.lookup_peernum_by_playernum, a variant keyed on
# `found`: when found it carries the matching peer's number (`num`).
type NetPeerNumSearchResult {
  :fields {
    found: bool(false)
  }

  :variants {
    match found {
      true {
        found: bool
        num:   uint
      }
      false {
        found: bool
      }
    }
  }
}

# All known peers, keyed by peer number.
type NetPeers (table(<uint>, <*NetPeer>)) {

  # Creates a NetPeer for `epeer` under the lowest unused peer number
  # (starting at 1), stores it, and returns that number.
  fn add_from_enet_peer(ENetPeer epeer) (uint) {
    var num: uint(1)

    while self.contains(num) {
      num++
    }

    var peer!: *NetPeer.from_enet_peer(epeer, num)

    self.set(peer.num, *peer!)

    return num
  }

  # The peer stored under number 1, if present.
  # NOTE(review): presumably the server when running as a client -- confirm.
  get server_peer (NetPeerSearchResult) {
    with self.get(1) as server_peer {
      return *NetPeerSearchResult(true, server_peer)
    }

    return *NetPeerSearchResult(false)
  }

  # Finds the peer whose connection_id equals `epeer.connectID`.
  fn lookup_by_enet_peer(ENetPeer epeer) (NetPeerSearchResult) {
    for num, peer in self {
      if peer.connection_id == epeer.connectID {
        return *NetPeerSearchResult(true, peer)
      }
    }

    return *NetPeerSearchResult(false)
  }

  # Finds the peer controlling player `playernum`.
  fn lookup_by_playernum(uint playernum) (NetPeerSearchResult) {
    for num, peer in self {
      if peer.playernum == playernum {
        return *NetPeerSearchResult(true, peer)
      }
    }

    return *NetPeerSearchResult(false)
  }

  # Returns the number of the peer controlling player `playernum`, if any.
  fn lookup_peernum_by_playernum(uint playernum) (NetPeerNumSearchResult) {
    # lookup_by_playernum yields a search result, not a peer: test `found`
    # and read the peer out of the result explicitly.
    var search: self.lookup_by_playernum(playernum)

    if search.found {
      return *NetPeerNumSearchResult(true, search.peer.num)
    }

    return *NetPeerNumSearchResult(false)
  }

}

# vi: set et ts=2 sw=2: