// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Newtonsoft.Json;
using Xunit;

namespace Microsoft.DotNet.ProjectModel.Server.Tests
{
    public class DthTestClient : IDisposable
    {
        private readonly string _hostId;
        private readonly BinaryReader _reader;
        private readonly BinaryWriter _writer;
        private readonly NetworkStream _networkStream;
        private readonly BlockingCollection<DthMessage> _messageQueue;
        private readonly CancellationTokenSource _readCancellationToken;

        // Keeps track of initialized project contexts
        // REVIEW: This needs to be exposed if we ever create 2 clients in order to simulate how build
        // works in visual studio
        private readonly Dictionary<string, int> _projectContexts = new Dictionary<string, int>();
        private int _nextContextId;
        private readonly Socket _socket;

        public DthTestClient(DthTestServer server)
        {
            // Avoid Socket exception 10006 on Linux
            Thread.Sleep(100);
            
            _socket = new Socket(AddressFamily.InterNetwork,
                                 SocketType.Stream,
                                 ProtocolType.Tcp);

            _socket.Connect(new IPEndPoint(IPAddress.Loopback, server.Port));

            _hostId = server.HostId;

            _networkStream = new NetworkStream(_socket);
            _reader = new BinaryReader(_networkStream);
            _writer = new BinaryWriter(_networkStream);

            _messageQueue = new BlockingCollection<DthMessage>();

            _readCancellationToken = new CancellationTokenSource();
            Task.Run(() => ReadMessage(_readCancellationToken.Token), _readCancellationToken.Token);
        }

        public void SendPayLoad(Project project, string messageType)
        {
            SendPayLoad(project.ProjectDirectory, messageType);
        }

        public void SendPayLoad(string projectPath, string messageType)
        {
            int contextId;
            if (!_projectContexts.TryGetValue(projectPath, out contextId))
            {
                Assert.True(false, $"Unable to resolve context for {projectPath}");
            }

            SendPayLoad(contextId, messageType);
        }

        public void SendPayLoad(int contextId, string messageType)
        {
            SendPayLoad(contextId, messageType, new { });
        }

        public void SendPayLoad(int contextId, string messageType, object payload)
        {
            lock (_writer)
            {
                var message = new
                {
                    ContextId = contextId,
                    HostId = _hostId,
                    MessageType = messageType,
                    Payload = payload
                };
                _writer.Write(JsonConvert.SerializeObject(message));
            }
        }

        public int Initialize(string projectPath)
        {
            var contextId = _nextContextId++;

            _projectContexts[projectPath] = contextId;
            SendPayLoad(contextId, MessageTypes.Initialize, new { ProjectFolder = projectPath });

            return contextId;
        }

        public int Initialize(string projectPath, int protocolVersion)
        {
            var contextId = _nextContextId++;

            _projectContexts[projectPath] = contextId;
            SendPayLoad(contextId, MessageTypes.Initialize, new { ProjectFolder = projectPath, Version = protocolVersion });

            return contextId;
        }

        public int Initialize(string projectPath, int protocolVersion, string configuration)
        {
            var contextId = _nextContextId++;

            _projectContexts[projectPath] = contextId;
            SendPayLoad(contextId, MessageTypes.Initialize, new { ProjectFolder = projectPath, Version = protocolVersion, Configuration = configuration });

            return contextId;
        }

        public void SetProtocolVersion(int version)
        {
            SendPayLoad(0, MessageTypes.ProtocolVersion, new { Version = version });
        }

        public List<DthMessage> DrainMessage(int count)
        {
            var result = new List<DthMessage>();
            while (count > 0)
            {
                result.Add(GetResponse(timeout: TimeSpan.FromSeconds(10)));
                count--;
            }

            return result;
        }

        public List<DthMessage> DrainAllMessages()
        {
            return DrainAllMessages(TimeSpan.FromSeconds(10));
        }

        /// <summary>
        /// Read all messages from pipeline till timeout
        /// </summary>
        /// <param name="timeout">The timeout</param>
        /// <returns>All the messages in a list</returns>
        public List<DthMessage> DrainAllMessages(TimeSpan timeout)
        {
            var result = new List<DthMessage>();
            while (true)
            {
                try
                {
                    result.Add(GetResponse(timeout));
                }
                catch (TimeoutException)
                {
                    return result;
                }
                catch (Exception)
                {
                    throw;
                }
            }
        }

        /// <summary>
        /// Read messages from pipeline until the first match
        /// </summary>]
        /// <param name="type">A message type</param>
        /// <returns>The first match message</returns>
        public DthMessage DrainTillFirst(string type)
        {
            return DrainTillFirst(type, TimeSpan.FromSeconds(10));
        }

        /// <summary>
        /// Read messages from pipeline until the first match
        /// </summary>
        /// <param name="type">A message type</param>
        /// <param name="timeout">Timeout for each read</param>
        /// <returns>The first match message</returns>
        public DthMessage DrainTillFirst(string type, TimeSpan timeout)
        {
            while (true)
            {
                var next = GetResponse(timeout);
                if (next.MessageType == type)
                {
                    return next;
                }
            }
        }

        /// <summary>
        /// Read messages from pipeline until the first match
        /// </summary>
        /// <param name="type">A message type</param>
        /// <param name="timeout">Timeout</param>
        /// <param name="leadingMessages">All the messages read before the first match</param>
        /// <returns>The first match</returns>
        public DthMessage DrainTillFirst(string type, TimeSpan timeout, out List<DthMessage> leadingMessages)
        {
            leadingMessages = new List<DthMessage>();
            while (true)
            {
                var next = GetResponse(timeout);
                if (next.MessageType == type)
                {
                    return next;
                }
                else
                {
                    leadingMessages.Add(next);
                }
            }
        }

        public void Dispose()
        {
            _reader.Dispose();
            _writer.Dispose();
            _networkStream.Dispose();
            _readCancellationToken.Cancel();

            try
            {
                _socket.Shutdown(SocketShutdown.Both);
            }
            catch (SocketException ex)
            {
                // Swallow this error for now.
                // This is a temporary fix for a random failure on CI. The issue happens on Windowx x86
                // only.
                Console.Error.WriteLine($"Exception thrown durning socket shutting down: {ex.SocketErrorCode}.");
            }
        }

        private void ReadMessage(CancellationToken cancellationToken)
        {
            while (true)
            {
                try
                {
                    if (cancellationToken.IsCancellationRequested)
                    {
                        return;
                    }

                    var content = _reader.ReadString();
                    var message = JsonConvert.DeserializeObject<DthMessage>(content);

                    _messageQueue.Add(message);
                }
                catch (IOException)
                {
                    // swallow
                }
                catch (JsonSerializationException deserializException)
                {
                    throw new InvalidOperationException(
                        $"Fail to deserailze data into {nameof(DthMessage)}.",
                        deserializException);
                }
                catch (Exception ex)
                {
                    throw ex;
                }
            }
        }

        private DthMessage GetResponse(TimeSpan timeout)
        {
            DthMessage message;

            if (_messageQueue.TryTake(out message, timeout))
            {
                return message;
            }
            else
            {
                throw new TimeoutException($"Response time out after {timeout.TotalSeconds} seconds.");
            }
        }
    }
}