###############################################################################
#                                                                             #
#  D2K: A Doom Source Port for the 21st Century                               #
#                                                                             #
#  Copyright (C) 2014: See COPYRIGHT file                                     #
#                                                                             #
#  This file is part of D2K.                                                  #
#                                                                             #
#  D2K is free software: you can redistribute it and/or modify it under the   #
#  terms of the GNU General Public License as published by the Free Software  #
#  Foundation, either version 2 of the License, or (at your option) any       #
#  later version.                                                             #
#                                                                             #
#  D2K is distributed in the hope that it will be useful, but WITHOUT ANY     #
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS  #
#  FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more     #
#  details.                                                                   #
#                                                                             #
#  You should have received a copy of the GNU General Public License along    #
#  with D2K.  If not, see <http://www.gnu.org/licenses/>.                     #
#                                                                             #
###############################################################################

from zone import *
from doomstat import *
from ticcmd import *
from game import *
from net.state import *
from physics.cmd import *
from physics.user import *

import net
import enet

# Acceleration / deceleration values passed to epeer.configure_throttle() by
# NetPeer.from_enet_peer, which only configures throttling when running as a
# server.
const NET_THROTTLE_ACCEL: 2
const NET_THROTTLE_DECEL: 1

# One table-of-contents record for a packet: the offset at which a message
# starts in the channel's message buffer (`index`) and the message's type
# (`type`).
type PacketTOCEntry {

  var index: uint(0)
  var type: uint(0)

  # Writes the entry as two unsigned integers, index first, then type.
  fn serialize(MPackBuffer mpdata!) {
    val entry_index: self.index
    val entry_type: self.type

    mpdata.write_uint(entry_index)
    mpdata.write_uint(entry_type)
  }

  # Reads two unsigned integers (index, then type) and stores them into this
  # entry only after both reads have been performed.
  fn deserialize(MPackBuffer mpdata!) {
    val entry_index: mpdata.read_uint()
    val entry_type: mpdata.read_uint()

    self.type = entry_type
    self.index = entry_index
  }

}

# Table of contents for a packet: one PacketTOCEntry (message offset, message
# type) per message, in the order the messages were written.
type PacketTOC (dynarray<PacketTOCEntry!>) {

  # Writes a map header holding the entry count, followed by each entry's
  # (index, type) pair.
  fn serialize(MPackBuffer mpdata!) {
    # The container type written here must mirror the read_map() call in
    # deserialize(), otherwise the TOC cannot be read back.
    mpdata.write_map(self.length)

    # PacketTOC *is* the dynarray, so iterate over self directly.
    for entry in self {
      entry.serialize(mpdata!)
    }
  }

  # Replaces the current contents with entries read from `mpdata`: a map
  # header giving the entry count, then that many (index, type) pairs.
  #
  # NOTE(review): toc_size comes straight from a received packet (see
  # NetChannel.load_incoming_packet) and is not bounded before
  # ensure_capacity(); consider capping it to avoid oversized allocations.
  fn deserialize(MPackBuffer mpdata!) {
    var toc_size: mpdata.read_map()

    self.clear()
    # Assumes ensure_capacity() grows `length` to toc_size, so the loop below
    # visits exactly toc_size freshly-created entries.
    self.ensure_capacity(toc_size)

    for entry! in self {
      entry.deserialize(mpdata!)
    }
  }

  # Appends an entry recording that a message of type `mtype` starts at
  # offset `index`.
  fn add_entry(uint index, uint mtype) {
    # This code is all provably correct, but it's pretty hard for a compiler
    # to know that.  I wonder if there should be better indexing syntax or
    # something.
    self.ensure_capacity(self.length + 1)
    self[self.length - 1].index = index
    self[self.length - 1].type = mtype
  }

}

# Outcome of NetChannel.load_next_message().
#
# When `success` is true, `mtype` is the type of the message the channel's
# message buffer is now positioned at. When `success` is false, `reason`
# carries a human-readable explanation.
type NetMessageLoadResult {

  val success: bool(false)

  @variants {
    switch success {
      case true {
        val success: bool(true)
        val mtype: uint(0)
      }
      case false {
        val success: bool(false)
        val reason: str("unknown error occurred")
      }
    }
  }

}

# Raised by NetChannel.load_incoming_packet when a received packet has no
# message bytes after its table of contents.
@error PacketEmpty {
  @message: "Packet is empty"

  # Prints the error's message and the source location it was raised from.
  @handler {
    echo("Got an error:")
    echo("    Message: {{@this:error.message}}")
    echo("    File: {{@this:error.file}}")
    echo("    Line: {{@this:error.line}}")
  }
}

# One direction of message traffic for a peer.
#
# Outgoing messages are written into `messages`. Each message is recorded in
# the table of contents (`toc`) as (offset into `messages`, message type). To
# send, the TOC is serialized into `packed_toc` and placed before the message
# bytes in a single packet. On receive, the TOC is deserialized from the front
# of the packet and the remaining bytes are copied into `messages`.
type NetChannel {

  var toc:         *PacketTOC()
  var messages:    *MPackBuffer()
  var packet_data: *MPackBuffer()
  var packed_toc:  *MPackBuffer()

  # Empties the TOC and every buffer so the channel can be reused.
  fn clear() {
    self.toc.clear()
    self.messages.clear()
    self.packet_data.clear()
    self.packed_toc.clear()
  }

  # True when no messages are recorded in the TOC.
  pred is_empty {
    return self.toc.length == 0
  }

  # Records a new message of type `message_type` starting at the current
  # cursor of `messages`, then returns `messages` so the caller can write the
  # message body.
  fn begin_message(uint message_type) (MPackBuffer!) {
    val index: self.messages.get_cursor()

    self.toc.add_entry(index, message_type)

    return self.messages!
  }

  # Builds a packet with the given ENet flag containing the serialized TOC
  # followed by all message bytes.
  fn build_outgoing_packet(enet.PacketFlag packet_flag) (*ENetPacket!) {
    self.packed_toc.clear()
    # Pre-size the buffer; 10 bytes per entry is a heuristic, presumably the
    # buffer still grows if more is needed — confirm.
    self.packed_toc.ensure_capacity(self.toc.length * 10)
    self.toc.serialize(self.packed_toc!)

    val toc_size:  self.packed_toc.get_size()
    val msg_size:  self.messages.get_size()
    var packet:   *enet.Packet(toc_size + msg_size, packet_flag)

    # NOTE(review): assumes successive read_from() calls append to the packet
    # (TOC first, then messages) rather than overwrite — confirm.
    packet.read_from(self.packed_toc, 0, toc_size)
    packet.read_from(self.messages, 0, msg_size)

    return packet!
  }

  # Replaces the channel's contents with the packet held in `bb`: the TOC is
  # deserialized from the start of the packet and every byte after it is
  # copied into `messages`.
  #
  # Fails with PacketEmpty when no bytes follow the TOC.
  fn load_incoming_packet(ByteBuffer bb) {
    var local_bb: ByteBuffer.new_with_alloc(bb.length)

    self.clear()

    local_bb.read_from(bb, 0, bb.length)

    local_bb.set_cursor(0)

    self.toc.deserialize(local_bb!)

    # The messages start where TOC deserialization stopped reading; the TOC
    # itself has no cursor, the read position lives in local_bb.
    val message_start_point: local_bb.get_cursor()

    if message_start_point >= bb.length {
      fail PacketEmpty()
    }

    self.messages.read_from(bb, message_start_point, bb.length - message_start_point)
  }

  # Positions `messages` at the start of the next message listed in the TOC
  # and removes that entry from the TOC.
  #
  # Returns a successful result carrying the message type, or a failed result
  # when the TOC is empty or the next entry's offset lies beyond `messages`.
  # An entry with an out-of-range offset is left in the TOC.
  fn load_next_message() (NetMessageLoadResult) {
    if self.toc.length == 0 {
      return *NetMessageLoadResult(false, "No messages left")
    }

    # Messages are consumed in TOC order, so only the first entry matters.
    val entry: self.toc[0]

    if entry.index >= self.messages.length {
      console.echo({
        "NetChannel.load_next_message: Invalid message index "
        "(${entry.index} >= ${self.messages.length})."
      })
      return *NetMessageLoadResult(false,
        "Invalid message index ${entry.index}")
    }

    var result: *NetMessageLoadResult(true, entry.type)

    self.messages.set_cursor(entry.index)

    # Cannot fail: the TOC was checked to be non-empty above.
    iferr self.toc.remove(0) {}

    return result
  }

}

# A peer's three message channels: `incoming` for received traffic, and
# `reliable` / `unreliable` for outgoing traffic.
type NetCom {

  var incoming:   *NetChannel()
  var reliable:   *NetChannel()
  var unreliable: *NetChannel()

  # Starts a new outgoing message of type `message_type` on the reliable
  # channel and returns the buffer to write its body into.
  fn begin_reliable_message(uint message_type) (MPackBuffer!) {
    return self.reliable.begin_message(message_type)
  }

  # Starts a new outgoing message of type `message_type` on the unreliable
  # channel and returns the buffer to write its body into.
  fn begin_unreliable_message(uint message_type) (MPackBuffer!) {
    return self.unreliable.begin_message(message_type)
  }

  # Packs the reliable channel into a packet flagged PACKET_FLAG_RELIABLE.
  fn build_reliable_packet() (*ENetPacket!) {
    return self.reliable.build_outgoing_packet(enet.PACKET_FLAG_RELIABLE)
  }

  # Packs the unreliable channel into a packet flagged
  # PACKET_FLAG_UNSEQUENCED.
  fn build_unreliable_packet() (*ENetPacket!) {
    return self.unreliable.build_outgoing_packet(enet.PACKET_FLAG_UNSEQUENCED)
  }

  # Loads `bb` into the incoming channel and returns true. Errors raised by
  # NetChannel.load_incoming_packet (e.g. PacketEmpty) propagate to the
  # caller instead of producing a false return.
  fn load_incoming_packet(ByteBuffer bb) (bool) {
    self.incoming.load_incoming_packet(bb)
    return true
  }

  # Advances the incoming channel to its next message.
  fn load_next_message() (NetMessageLoadResult) {
    return self.incoming.load_next_message()
  }

}

# Synchronization bookkeeping kept per peer (NetPeer.netsync).
type NetSync {

  var initialized:   bool(false)
  var outdated:      bool(false)
  var tic:           uint(0)
  var command_index: uint(0)
  var delta:         GameStateDelta()

  # Clears the `initialized` and `outdated` flags.
  # NOTE(review): tic, command_index and delta are left untouched — confirm
  # that a reset is not also expected to clear them.
  fn reset() {
    self.initialized = false
    self.outdated = false
  }

}

# State for one remote peer: identity (peer number, player number, address,
# ENet connection id), its message channels (`netcom`), sync bookkeeping
# (`netsync`), and the connect/disconnect timestamps read by check_timeout().
type NetPeer {

  val num:             uint(1)
  val playernum:       uint(0)
  # NOTE(review): 1017 does not encode 127.0.0.1, although an earlier comment
  # claimed it did. from_enet_peer() overwrites this default with the ENet
  # peer's host.
  val address:         u32(1017)
  val port:            u16(1)
  val connection_id:   uint(0)
  var netcom:          NetCom()
  var netsync:         NetSync()
  val connect_time:    time()
  # Mutable: set_disconnected() assigns it after construction.
  var disconnect_time: time()

  # Creates peer number `num` from an ENet peer. It records the ENet
  # address, port and connectID, and stamps connect_time with the current
  # time. When running as a server it also configures the ENet peer's
  # throttle.
  static from_enet_peer(ENetPeer epeer, uint num) (*NetPeer!) {
    var np: {
      num: num
      address: epeer.address.host
      port: epeer.address.port
      connection_id: epeer.connectID
      connect_time: time.now()
    }

    if mode.is_server {
      epeer.configure_throttle(300, NET_THROTTLE_ACCEL, NET_THROTTLE_DECEL)
    }

    return np!
  }

  # The ENet peer found by enet.find_peer(connection_id).
  # NOTE(review): nothing is returned when no peer is found; callers use
  # `with`, presumably to handle that case — confirm.
  get enet_peer (ENetPeer) {
    with enet.find_peer(self.connection_id) as epeer {
      return epeer
    }
  }

  # Logs the removal, marks this peer's player as disconnected, stamps
  # disconnect_time, and asks ENet to disconnect the peer if it is still
  # known to ENet.
  fn set_disconnected() {
    console.echo({
      "Removing peer ${self.num} "
      "${net.ip_to_const_string(self.address)}:${self.port}"
    })

    with game.get_player(self.playernum) as player! {
      player.playerstate = PST_DISCONNECTED
    }

    self.disconnect_time = time.now()

    with self.enet_peer as epeer {
      epeer.disconnect(0)
    }
  }

  # Returns true when either the connect timeout or the disconnect timeout
  # has elapsed.
  #
  # NOTE(review): the guard below means the timeouts are only evaluated once
  # BOTH connect_time and disconnect_time are set, so a peer that never
  # disconnects never times out here — confirm this is intended. The
  # `* 1000` scaling assumes timedelta() reports milliseconds and the
  # timeout constants are in seconds — verify.
  fn check_timeout() (bool) {
    val now: time.now()

    if self.connect_time == 0 or self.disconnect_time == 0 {
      return false
    }

    if timedelta(now, self.connect_time) > (NET_CONNECT_TIMEOUT * 1000) ||
       timedelta(now, self.disconnect_time) > (NET_DISCONNECT_TIMEOUT * 1000) {
      return true
    }

    return false
  }

  # Sends any pending reliable and unreliable messages on their ENet channels
  # and clears each channel after it is sent. Skipped entirely when the ENet
  # peer cannot be found.
  fn flush_buffers() {
    with self.enet_peer as epeer {
      if !self.netcom.reliable.is_empty {
        epeer.send(NET_CHANNEL_RELIABLE, self.build_reliable_packet())
        self.netcom.reliable.clear()
      }
      if !self.netcom.unreliable.is_empty {
        epeer.send(NET_CHANNEL_UNRELIABLE, self.build_unreliable_packet())
        self.netcom.unreliable.clear()
      }
    }
  }

  # Starts a new reliable outgoing message of type `message_type`.
  fn begin_reliable_message(uint message_type) (MPackBuffer!) {
    return self.netcom.begin_reliable_message(message_type)
  }

  # Starts a new unreliable outgoing message of type `message_type`.
  fn begin_unreliable_message(uint message_type) (MPackBuffer!) {
    return self.netcom.begin_unreliable_message(message_type)
  }

  # Packs pending reliable messages into a packet. NetCom chooses the ENet
  # packet flag itself, so no flag is passed here.
  fn build_reliable_packet() (*ENetPacket!) {
    return self.netcom.build_reliable_packet()
  }

  # Packs pending unreliable messages into a packet. NetCom chooses the ENet
  # packet flag itself, so no flag is passed here.
  fn build_unreliable_packet() (*ENetPacket!) {
    return self.netcom.build_unreliable_packet()
  }

  # Loads a received packet into this peer's incoming channel.
  fn load_incoming_packet(ByteBuffer bb) (bool) {
    return self.netcom.load_incoming_packet(bb)
  }

  # Advances this peer's incoming channel to its next message.
  fn load_next_message() (NetMessageLoadResult) {
    return self.netcom.load_next_message()
  }
}

# Result of a NetPeers lookup. `peer` exists only in the `found == true`
# variant; a failed lookup carries just `found` (false).
type NetPeerSearchResult {
  val found: bool(false)

  @variants {
    switch found {
      case true {
        val found:  bool(true)
        val peer:  *NetPeer
      }
    }
  }
}

# Result of a lookup that yields a peer number. `num` exists only in the
# `found == true` variant.
type NetPeerNumSearchResult {
  val found: bool(false)

  @variants {
    switch found {
      case true {
        val found: bool(true)
        val num:   uint(0)
      }
    }
  }
}

# All known peers, keyed by peer number.
type NetPeers (table(<uint!>, <*NetPeer!>)) {

  # Creates a NetPeer for `epeer` under the lowest unused peer number
  # (starting at 1), stores it, and returns that number.
  fn add_from_enet_peer(ENetPeer epeer) (uint) {
    var num: uint(1)

    while self.contains(num) {
      num++
    }

    var peer: *NetPeer.from_enet_peer(epeer, num)

    self.set(peer.num!, *peer!)

    return num
  }

  # The peer stored under peer number 1, if any.
  # NOTE(review): presumably this is the server when running as a client —
  # confirm.
  get server_peer (NetPeerSearchResult) {
    with self.get(1) as server_peer {
      return *NetPeerSearchResult(true, server_peer)
    }

    return *NetPeerSearchResult(false)
  }

  # Finds the peer whose connection_id equals the ENet peer's connectID
  # (linear scan).
  fn lookup_by_enet_peer(ENetPeer epeer) (NetPeerSearchResult) {
    for num, peer in self {
      if peer.connection_id == epeer.connectID {
        return *NetPeerSearchResult(true, peer)
      }
    }

    return *NetPeerSearchResult(false)
  }

  # Finds the peer controlling player `playernum` (linear scan).
  fn lookup_by_playernum(uint playernum) (NetPeerSearchResult) {
    for num, peer in self {
      if peer.playernum == playernum {
        return *NetPeerSearchResult(true, peer)
      }
    }

    return *NetPeerSearchResult(false)
  }

  # Finds the peer number of the peer controlling player `playernum`.
  fn lookup_peernum_by_playernum(uint playernum) (NetPeerNumSearchResult) {
    # The lookup yields a search result, not a peer: check `found` before
    # touching the `peer` variant field.
    val search: self.lookup_by_playernum(playernum)

    if search.found {
      return *NetPeerNumSearchResult(true, search.peer.num)
    }

    return *NetPeerNumSearchResult(false)
  }

}

# vi: set et ts=2 sw=2: