// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using Xunit;

namespace Microsoft.DotNet.ProjectModel.Server.Tests
{
    /// <summary>
    /// Test client that connects to a <see cref="DthTestServer"/> over a loopback TCP socket.
    /// Outgoing messages are serialized to JSON and written with <see cref="BinaryWriter.Write(string)"/>;
    /// incoming messages are read on a background task and queued until one of the Drain* methods takes them.
    /// </summary>
    public class DthTestClient : IDisposable
    {
        // Time to wait for each individual message when the caller does not supply a timeout.
        private static readonly TimeSpan DefaultTimeout = TimeSpan.FromSeconds(10);

        private readonly string _hostId;
        private readonly BinaryReader _reader;
        private readonly BinaryWriter _writer;
        private readonly NetworkStream _networkStream;
        private readonly ILogger _logger;
        private readonly BlockingCollection<DthMessage> _messageQueue;
        private readonly CancellationTokenSource _readCancellationToken;

        // Keeps track of initialized project contexts
        // REVIEW: This needs to be exposed if we ever create 2 clients in order to simulate how build
        // works in visual studio
        private readonly Dictionary<string, int> _projectContexts = new Dictionary<string, int>();
        private int _nextContextId;

        private readonly Socket _socket;
        private bool _disposed;

        /// <summary>
        /// Connects to <paramref name="server"/> on the loopback interface and starts the background reader.
        /// </summary>
        /// <param name="server">Server supplying the port to connect to and the host id stamped on every message.</param>
        /// <param name="loggerFactory">Factory used to create this client's logger.</param>
        public DthTestClient(DthTestServer server, ILoggerFactory loggerFactory)
        {
            // Avoid Socket exception 10006 on Linux
            Thread.Sleep(100);

            _logger = loggerFactory.CreateLogger<DthTestClient>();

            _socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
            _socket.Connect(new IPEndPoint(IPAddress.Loopback, server.Port));

            _hostId = server.HostId;

            // NOTE: this NetworkStream does not own the socket, so Dispose must release the socket itself.
            _networkStream = new NetworkStream(_socket);
            _reader = new BinaryReader(_networkStream);
            _writer = new BinaryWriter(_networkStream);

            _messageQueue = new BlockingCollection<DthMessage>();

            _readCancellationToken = new CancellationTokenSource();
            Task.Run(() => ReadMessage(_readCancellationToken.Token), _readCancellationToken.Token);
        }

        /// <summary>
        /// Sends an empty-payload message for the context previously initialized for <paramref name="project"/>'s directory.
        /// </summary>
        public void SendPayLoad(Project project, string messageType)
        {
            SendPayLoad(project.ProjectDirectory, messageType);
        }

        /// <summary>
        /// Sends an empty-payload message for the context previously initialized for <paramref name="projectPath"/>.
        /// Fails the test if no context was initialized for that path.
        /// </summary>
        public void SendPayLoad(string projectPath, string messageType)
        {
            int contextId;
            if (!_projectContexts.TryGetValue(projectPath, out contextId))
            {
                Assert.True(false, $"Unable to resolve context for {projectPath}");
            }

            SendPayLoad(contextId, messageType);
        }

        /// <summary>
        /// Sends a message with an empty payload object to the given context.
        /// </summary>
        public void SendPayLoad(int contextId, string messageType)
        {
            SendPayLoad(contextId, messageType, new { });
        }

        /// <summary>
        /// Serializes a message envelope (context id, host id, message type, payload) to JSON and writes it to the server.
        /// </summary>
        public void SendPayLoad(int contextId, string messageType, object payload)
        {
            // Serialize writes so that concurrent senders cannot interleave on the stream.
            lock (_writer)
            {
                var message = new
                {
                    ContextId = contextId,
                    HostId = _hostId,
                    MessageType = messageType,
                    Payload = payload
                };

                _writer.Write(JsonConvert.SerializeObject(message));
            }
        }

        /// <summary>
        /// Allocates a new context for <paramref name="projectPath"/> and sends an Initialize message for it.
        /// </summary>
        /// <returns>The new context id.</returns>
        public int Initialize(string projectPath)
        {
            var contextId = RegisterContext(projectPath);
            SendPayLoad(contextId, MessageTypes.Initialize, new { ProjectFolder = projectPath });
            return contextId;
        }

        /// <summary>
        /// Allocates a new context for <paramref name="projectPath"/> and sends an Initialize message
        /// carrying the requested protocol version.
        /// </summary>
        /// <returns>The new context id.</returns>
        public int Initialize(string projectPath, int protocolVersion)
        {
            var contextId = RegisterContext(projectPath);
            SendPayLoad(contextId, MessageTypes.Initialize, new { ProjectFolder = projectPath, Version = protocolVersion });
            return contextId;
        }

        /// <summary>
        /// Allocates a new context for <paramref name="projectPath"/> and sends an Initialize message
        /// carrying the requested protocol version and configuration.
        /// </summary>
        /// <returns>The new context id.</returns>
        public int Initialize(string projectPath, int protocolVersion, string configuration)
        {
            var contextId = RegisterContext(projectPath);
            SendPayLoad(contextId, MessageTypes.Initialize, new
            {
                ProjectFolder = projectPath,
                Version = protocolVersion,
                Configuration = configuration
            });
            return contextId;
        }

        /// <summary>
        /// Sends a ProtocolVersion message on context 0.
        /// </summary>
        public void SetProtocolVersion(int version)
        {
            SendPayLoad(0, MessageTypes.ProtocolVersion, new { Version = version });
        }

        /// <summary>
        /// Reads exactly <paramref name="count"/> messages, waiting up to the default timeout for each.
        /// </summary>
        /// <exception cref="TimeoutException">A message did not arrive in time.</exception>
        public List<DthMessage> DrainMessage(int count)
        {
            var result = new List<DthMessage>();
            for (var i = 0; i < count; i++)
            {
                result.Add(GetResponse(DefaultTimeout));
            }

            return result;
        }

        /// <summary>
        /// Reads all messages until no message arrives within the default timeout.
        /// </summary>
        /// <returns>All the messages in a list</returns>
        public List<DthMessage> DrainAllMessages()
        {
            return DrainAllMessages(DefaultTimeout);
        }

        /// <summary>
        /// Read all messages from pipeline till timeout
        /// </summary>
        /// <param name="timeout">How long to wait for each next message before stopping</param>
        /// <returns>All the messages in a list</returns>
        public List<DthMessage> DrainAllMessages(TimeSpan timeout)
        {
            var result = new List<DthMessage>();
            while (true)
            {
                try
                {
                    result.Add(GetResponse(timeout));
                }
                catch (TimeoutException)
                {
                    // No further message within the timeout: the pipeline is considered drained.
                    return result;
                }
            }
        }

        /// <summary>
        /// Read messages from pipeline until the first match
        /// </summary>
        /// <param name="type">A message type</param>
        /// <returns>The first match message</returns>
        /// <exception cref="TimeoutException">No message arrived within the default timeout for some read.</exception>
        public DthMessage DrainTillFirst(string type)
        {
            return DrainTillFirst(type, DefaultTimeout);
        }

        /// <summary>
        /// Read messages from pipeline until the first match
        /// </summary>
        /// <param name="type">A message type</param>
        /// <param name="timeout">Timeout for each read (the total wait is not bounded)</param>
        /// <returns>The first match message</returns>
        /// <exception cref="TimeoutException">No message arrived within <paramref name="timeout"/> for some read.</exception>
        public DthMessage DrainTillFirst(string type, TimeSpan timeout)
        {
            while (true)
            {
                var next = GetResponse(timeout);
                if (next.MessageType == type)
                {
                    return next;
                }
            }
        }

        /// <summary>
        /// Read messages from pipeline until the first match
        /// </summary>
        /// <param name="type">A message type</param>
        /// <param name="timeout">Timeout for each read</param>
        /// <param name="leadingMessages">All the messages read before the first match</param>
        /// <returns>The first match</returns>
        /// <exception cref="TimeoutException">No message arrived within <paramref name="timeout"/> for some read.</exception>
        public DthMessage DrainTillFirst(string type, TimeSpan timeout, out List<DthMessage> leadingMessages)
        {
            leadingMessages = new List<DthMessage>();
            while (true)
            {
                var next = GetResponse(timeout);
                if (next.MessageType == type)
                {
                    return next;
                }

                leadingMessages.Add(next);
            }
        }

        /// <summary>
        /// Stops the background reader and releases the stream and socket. Safe to call more than once.
        /// </summary>
        public void Dispose()
        {
            if (_disposed)
            {
                return;
            }

            _disposed = true;

            // Signal cancellation before tearing down the stream, so the reader loop treats the
            // exception it gets from the closed stream as a normal shutdown.
            _readCancellationToken.Cancel();

            _reader.Dispose();
            _writer.Dispose();
            _networkStream.Dispose();

            try
            {
                _socket.Shutdown(SocketShutdown.Both);
            }
            catch (SocketException ex)
            {
                // Swallow this error for now.
                // This is a temporary fix for a random failure on CI. The issue happens on Windows x86
                // only.
                _logger.LogError($"Exception thrown during socket shutting down: {ex.SocketErrorCode}.");
            }

            // The NetworkStream does not own the socket, so it must be released explicitly.
            _socket.Dispose();
        }

        // Allocates the next context id and records it for projectPath so that
        // SendPayLoad(string, string) can resolve it later.
        private int RegisterContext(string projectPath)
        {
            var contextId = _nextContextId++;
            _projectContexts[projectPath] = contextId;
            return contextId;
        }

        // Background loop: reads length-prefixed strings from the socket, deserializes each one into a
        // DthMessage and queues it, until cancellation is requested or the connection ends.
        private void ReadMessage(CancellationToken cancellationToken)
        {
            while (!cancellationToken.IsCancellationRequested)
            {
                string content;
                try
                {
                    content = _reader.ReadString();
                }
                catch (IOException)
                {
                    // End of stream (EndOfStreamException derives from IOException) or a socket failure.
                    // Every later read would fail the same way, and a partial read has already broken the
                    // message framing, so stop instead of spinning on the dead connection.
                    return;
                }
                catch (ObjectDisposedException)
                {
                    // The client was disposed while a read was pending.
                    return;
                }

                DthMessage message;
                try
                {
                    message = JsonConvert.DeserializeObject<DthMessage>(content);
                }
                catch (JsonSerializationException deserializException)
                {
                    throw new InvalidOperationException(
                        $"Fail to deserialize data into {nameof(DthMessage)}.",
                        deserializException);
                }

                _messageQueue.Add(message);
            }
        }

        // Takes the next queued message, waiting up to timeout for one to arrive.
        private DthMessage GetResponse(TimeSpan timeout)
        {
            DthMessage message;
            if (_messageQueue.TryTake(out message, timeout))
            {
                return message;
            }

            throw new TimeoutException($"Response time out after {timeout.TotalSeconds} seconds.");
        }
    }
}