diff --git a/src/Microsoft.DotNet.Cli.Utils/BlockingMemoryStream.cs b/src/Microsoft.DotNet.Cli.Utils/BlockingMemoryStream.cs new file mode 100644 index 000000000..e80061b01 --- /dev/null +++ b/src/Microsoft.DotNet.Cli.Utils/BlockingMemoryStream.cs @@ -0,0 +1,81 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Concurrent; +using System.IO; +using System.Threading; + +namespace Microsoft.DotNet.Cli.Utils +{ + /// + /// An in-memory stream that will block any read calls until something was written to it. + /// + public sealed class BlockingMemoryStream : Stream + { + private readonly BlockingCollection _buffers = new BlockingCollection(); + private ArraySegment _remaining; + + public override void Write(byte[] buffer, int offset, int count) + { + byte[] tmp = new byte[count]; + Buffer.BlockCopy(buffer, offset, tmp, 0, count); + _buffers.Add(tmp); + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (count == 0) + { + return 0; + } + + if (_remaining.Count == 0) + { + byte[] tmp; + if (!_buffers.TryTake(out tmp, Timeout.Infinite) || tmp.Length == 0) + { + return 0; + } + _remaining = new ArraySegment(tmp, 0, tmp.Length); + } + + if (_remaining.Count <= count) + { + count = _remaining.Count; + Buffer.BlockCopy(_remaining.Array, _remaining.Offset, buffer, offset, count); + _remaining = default(ArraySegment); + } + else + { + Buffer.BlockCopy(_remaining.Array, _remaining.Offset, buffer, offset, count); + _remaining = new ArraySegment(_remaining.Array, _remaining.Offset + count, _remaining.Count - count); + } + return count; + } + + public void DoneWriting() + { + _buffers.CompleteAdding(); + } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + _buffers.Dispose(); + } + + base.Dispose(disposing); + } + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Length { get { throw new NotImplementedException(); } } + public override long Position { get { throw new NotImplementedException(); } set { throw new NotImplementedException(); } } + public override void Flush() { } + public override long Seek(long offset, SeekOrigin origin) { throw new NotImplementedException(); } + public override void SetLength(long value) { throw new NotImplementedException(); } + } +} diff --git a/src/Microsoft.DotNet.Cli.Utils/BuiltInCommand.cs b/src/Microsoft.DotNet.Cli.Utils/BuiltInCommand.cs new file mode 100644 index 000000000..abef5ceba --- /dev/null +++ b/src/Microsoft.DotNet.Cli.Utils/BuiltInCommand.cs @@ -0,0 +1,141 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Threading; + +namespace Microsoft.DotNet.Cli.Utils +{ + /// + /// A Command that is capable of running in the current process. + /// + public class BuiltInCommand : ICommand + { + private readonly IEnumerable _commandArgs; + private readonly Func _builtInCommand; + private readonly StreamForwarder _stdOut; + private readonly StreamForwarder _stdErr; + + public string CommandName { get; } + public string CommandArgs => string.Join(" ", _commandArgs); + + public BuiltInCommand(string commandName, IEnumerable commandArgs, Func builtInCommand) + { + CommandName = commandName; + _commandArgs = commandArgs; + _builtInCommand = builtInCommand; + + _stdOut = new StreamForwarder(); + _stdErr = new StreamForwarder(); + } + + public CommandResult Execute() + { + TextWriter originalConsoleOut = Console.Out; + TextWriter originalConsoleError = Console.Error; + + try + { + // redirecting the standard out and error so we can forward + // the output to the caller + using (BlockingMemoryStream outStream = new BlockingMemoryStream()) + using (BlockingMemoryStream errorStream = new BlockingMemoryStream()) + { + Console.SetOut(new StreamWriter(outStream) { AutoFlush = true }); + Console.SetError(new StreamWriter(errorStream) { AutoFlush = true }); + + // Reset the Reporters to the new Console Out and Error. + Reporter.Reset(); + + Thread threadOut = _stdOut.BeginRead(new StreamReader(outStream)); + Thread threadErr = _stdErr.BeginRead(new StreamReader(errorStream)); + + int exitCode = _builtInCommand(_commandArgs.ToArray()); + + outStream.DoneWriting(); + errorStream.DoneWriting(); + + threadOut.Join(); + threadErr.Join(); + + // fake out a ProcessStartInfo using the Muxer command name, since this is a built-in command + ProcessStartInfo startInfo = new ProcessStartInfo(new Muxer().MuxerPath, $"{CommandName} {CommandArgs}"); + return new CommandResult(startInfo, exitCode, null, null); + } + } + finally + { + Console.SetOut(originalConsoleOut); + Console.SetError(originalConsoleError); + + Reporter.Reset(); + } + } + + public ICommand OnOutputLine(Action handler) + { + if (handler == null) + { + throw new ArgumentNullException(nameof(handler)); + } + + _stdOut.ForwardTo(writeLine: handler); + + return this; + } + + public ICommand OnErrorLine(Action handler) + { + if (handler == null) + { + throw new ArgumentNullException(nameof(handler)); + } + + _stdErr.ForwardTo(writeLine: handler); + + return this; + } + + public CommandResolutionStrategy ResolutionStrategy + { + get + { + throw new NotImplementedException(); + } + } + + public ICommand CaptureStdErr() + { + throw new NotImplementedException(); + } + + public ICommand CaptureStdOut() + { + throw new NotImplementedException(); + } + + public ICommand EnvironmentVariable(string name, string value) + { + throw new NotImplementedException(); + } + + public ICommand ForwardStdErr(TextWriter to = null, bool onlyIfVerbose = false, bool ansiPassThrough = true) + { + throw new NotImplementedException(); + } + + public ICommand ForwardStdOut(TextWriter to = null, bool onlyIfVerbose = false, bool ansiPassThrough = true) + { + throw new NotImplementedException(); + } + + public ICommand WorkingDirectory(string projectDirectory) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/Microsoft.DotNet.Cli.Utils/Reporter.cs b/src/Microsoft.DotNet.Cli.Utils/Reporter.cs index 8dcbd55fa..9b5984773 100644 --- a/src/Microsoft.DotNet.Cli.Utils/Reporter.cs +++ b/src/Microsoft.DotNet.Cli.Utils/Reporter.cs @@ -14,16 +14,34 @@ namespace Microsoft.DotNet.Cli.Utils private readonly AnsiConsole _console; + static Reporter() + { + Reset(); + } + private Reporter(AnsiConsole console) { _console = console; } - public static Reporter Output { get; } = new Reporter(AnsiConsole.GetOutput()); - public static Reporter Error { get; } = new Reporter(AnsiConsole.GetError()); - public static Reporter Verbose { get; } = CommandContext.IsVerbose() ? - new Reporter(AnsiConsole.GetOutput()) : - NullReporter; + public static Reporter Output { get; private set; } + public static Reporter Error { get; private set; } + public static Reporter Verbose { get; private set; } + + /// + /// Resets the Reporters to write to the current Console Out/Error. + /// + public static void Reset() + { + lock (_lock) + { + Output = new Reporter(AnsiConsole.GetOutput()); + Error = new Reporter(AnsiConsole.GetError()); + Verbose = CommandContext.IsVerbose() ? + new Reporter(AnsiConsole.GetOutput()) : + NullReporter; + } + } public void WriteLine(string message) { diff --git a/src/Microsoft.DotNet.Cli.Utils/DotNetCommandFactory.cs b/src/dotnet/DotNetCommandFactory.cs similarity index 51% rename from src/Microsoft.DotNet.Cli.Utils/DotNetCommandFactory.cs rename to src/dotnet/DotNetCommandFactory.cs index b22edae3f..d7f8eba62 100644 --- a/src/Microsoft.DotNet.Cli.Utils/DotNetCommandFactory.cs +++ b/src/dotnet/DotNetCommandFactory.cs @@ -1,10 +1,13 @@ // Copyright (c) .NET Foundation and contributors. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.DotNet.Cli.Utils; using NuGet.Frameworks; -namespace Microsoft.DotNet.Cli.Utils +namespace Microsoft.DotNet.Cli { public class DotNetCommandFactory : ICommandFactory { @@ -14,6 +17,15 @@ namespace Microsoft.DotNet.Cli.Utils NuGetFramework framework = null, string configuration = Constants.DefaultConfiguration) { + Func builtInCommand; + if (Program.TryGetBuiltInCommand(commandName, out builtInCommand)) + { + Debug.Assert(framework == null, "BuiltInCommand doesn't support the 'framework' argument."); + Debug.Assert(configuration == Constants.DefaultConfiguration, "BuiltInCommand doesn't support the 'configuration' argument."); + + return new BuiltInCommand(commandName, args, builtInCommand); + } + return Command.CreateDotNet(commandName, args, framework, configuration); } } diff --git a/src/dotnet/Program.cs b/src/dotnet/Program.cs index 9d212a2f1..0b360d313 100644 --- a/src/dotnet/Program.cs +++ b/src/dotnet/Program.cs @@ -6,26 +6,37 @@ using System.Collections.Generic; using System.IO; using System.Linq; using Microsoft.DotNet.Cli.Utils; -using Microsoft.DotNet.Tools.Test; -using Microsoft.Extensions.PlatformAbstractions; -using NuGet.Frameworks; using Microsoft.DotNet.ProjectModel.Server; using Microsoft.DotNet.Tools.Build; using Microsoft.DotNet.Tools.Compiler; using Microsoft.DotNet.Tools.Compiler.Csc; -using Microsoft.DotNet.Tools.Compiler.Native; using Microsoft.DotNet.Tools.Help; using Microsoft.DotNet.Tools.New; using Microsoft.DotNet.Tools.Publish; -using Microsoft.DotNet.Tools.Repl; -using Microsoft.DotNet.Tools.Resgen; using Microsoft.DotNet.Tools.Restore; using Microsoft.DotNet.Tools.Run; +using Microsoft.DotNet.Tools.Test; +using Microsoft.Extensions.PlatformAbstractions; +using NuGet.Frameworks; namespace Microsoft.DotNet.Cli { public class Program { + private static Dictionary> s_builtIns = new Dictionary> + { + ["build"] = BuildCommand.Run, + ["compile-csc"] = CompileCscCommand.Run, + ["help"] = HelpCommand.Run, + ["new"] = NewCommand.Run, + ["pack"] = PackCommand.Run, + ["projectmodel-server"] = ProjectModelServerCommand.Run, + ["publish"] = PublishCommand.Run, + ["restore"] = RestoreCommand.Run, + ["run"] = RunCommand.Run, + ["test"] = TestCommand.Run + }; + public static int Main(string[] args) { DebugHelper.HandleDebugSwitch(ref args); @@ -100,23 +111,9 @@ namespace Microsoft.DotNet.Cli command = "help"; } - var builtIns = new Dictionary> - { - ["build"] = BuildCommand.Run, - ["compile-csc"] = CompileCscCommand.Run, - ["help"] = HelpCommand.Run, - ["new"] = NewCommand.Run, - ["pack"] = PackCommand.Run, - ["projectmodel-server"] = ProjectModelServerCommand.Run, - ["publish"] = PublishCommand.Run, - ["restore"] = RestoreCommand.Run, - ["run"] = RunCommand.Run, - ["test"] = TestCommand.Run - }; - int exitCode; Func builtIn; - if (builtIns.TryGetValue(command, out builtIn)) + if (s_builtIns.TryGetValue(command, out builtIn)) { exitCode = builtIn(appArgs.ToArray()); } @@ -141,6 +138,11 @@ namespace Microsoft.DotNet.Cli } + internal static bool TryGetBuiltInCommand(string commandName, out Func builtInCommand) + { + return s_builtIns.TryGetValue(commandName, out builtInCommand); + } + private static void PrintVersion() { Reporter.Output.WriteLine(Product.Version); diff --git a/src/dotnet/commands/dotnet-compile/ManagedCompiler.cs b/src/dotnet/commands/dotnet-compile/ManagedCompiler.cs index a21b1be51..1a9007b1e 100644 --- a/src/dotnet/commands/dotnet-compile/ManagedCompiler.cs +++ b/src/dotnet/commands/dotnet-compile/ManagedCompiler.cs @@ -1,11 +1,14 @@ // Copyright (c) .NET Foundation and contributors. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Linq; using System.Runtime.InteropServices; +using System.Text; +using Microsoft.DotNet.Cli; using Microsoft.DotNet.Cli.Compiler.Common; using Microsoft.DotNet.Cli.Utils; using Microsoft.DotNet.ProjectModel; @@ -81,7 +84,7 @@ namespace Microsoft.DotNet.Tools.Compiler var compilationOptions = context.ResolveCompilationOptions(args.ConfigValue); // Set default platform if it isn't already set and we're on desktop - if(compilationOptions.EmitEntryPoint == true && string.IsNullOrEmpty(compilationOptions.Platform) && context.TargetFramework.IsDesktop()) + if (compilationOptions.EmitEntryPoint == true && string.IsNullOrEmpty(compilationOptions.Platform) && context.TargetFramework.IsDesktop()) { // See https://github.com/dotnet/cli/issues/2428 for more details. compilationOptions.Platform = RuntimeInformation.ProcessArchitecture == Architecture.X64 ? @@ -181,31 +184,15 @@ namespace Microsoft.DotNet.Tools.Compiler _scriptRunner.RunScripts(context, ScriptNames.PreCompile, contextVariables); - var result = _commandFactory.Create($"compile-{compilerName}", new[] { "@" + $"{rsp}" }) - .OnErrorLine(line => - { - var diagnostic = ParseDiagnostic(context.ProjectDirectory, line); - if (diagnostic != null) - { - diagnostics.Add(diagnostic); - } - else - { - Reporter.Error.WriteLine(line); - } - }) - .OnOutputLine(line => - { - var diagnostic = ParseDiagnostic(context.ProjectDirectory, line); - if (diagnostic != null) - { - diagnostics.Add(diagnostic); - } - else - { - Reporter.Output.WriteLine(line); - } - }).Execute(); + // Cache the reporters before invoking the command in case it is a built-in command, which replaces + // the static Reporter instances. + Reporter errorReporter = Reporter.Error; + Reporter outputReporter = Reporter.Output; + + CommandResult result = _commandFactory.Create($"compile-{compilerName}", new[] { $"@{rsp}" }) + .OnErrorLine(line => HandleCompilerOutputLine(line, context, diagnostics, errorReporter)) + .OnOutputLine(line => HandleCompilerOutputLine(line, context, diagnostics, outputReporter)) + .Execute(); // Run post-compile event contextVariables["compile:CompilerExitCode"] = result.ExitCode.ToString(); @@ -225,5 +212,18 @@ namespace Microsoft.DotNet.Tools.Compiler return PrintSummary(diagnostics, sw, success); } + + private static void HandleCompilerOutputLine(string line, ProjectContext context, List diagnostics, Reporter reporter) + { + var diagnostic = ParseDiagnostic(context.ProjectDirectory, line); + if (diagnostic != null) + { + diagnostics.Add(diagnostic); + } + else + { + reporter.WriteLine(line); + } + } } } diff --git a/src/dotnet/commands/dotnet-compile/Program.cs b/src/dotnet/commands/dotnet-compile/Program.cs index 5e10eac2e..c7fe08c65 100644 --- a/src/dotnet/commands/dotnet-compile/Program.cs +++ b/src/dotnet/commands/dotnet-compile/Program.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; +using Microsoft.DotNet.Cli; using Microsoft.DotNet.Cli.Utils; namespace Microsoft.DotNet.Tools.Compiler diff --git a/test/Microsoft.DotNet.Cli.Utils.Tests/BlockingMemoryStreamTests.cs b/test/Microsoft.DotNet.Cli.Utils.Tests/BlockingMemoryStreamTests.cs new file mode 100644 index 000000000..bb2b2c1d4 --- /dev/null +++ b/test/Microsoft.DotNet.Cli.Utils.Tests/BlockingMemoryStreamTests.cs @@ -0,0 +1,111 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Threading; +using Microsoft.DotNet.Tools.Test.Utilities; +using Xunit; + +namespace Microsoft.DotNet.Cli.Utils +{ + public class BlockingMemoryStreamTests : TestBase + { + /// + /// Tests reading a bigger buffer than what is available. + /// + [Fact] + public void ReadBiggerBuffer() + { + using (var stream = new BlockingMemoryStream()) + { + stream.Write(new byte[] { 1, 2, 3 }, 0, 3); + + byte[] buffer = new byte[10]; + int count = stream.Read(buffer, 0, buffer.Length); + Assert.Equal(3, count); + Assert.Equal(1, buffer[0]); + Assert.Equal(2, buffer[1]); + Assert.Equal(3, buffer[2]); + } + } + + /// + /// Tests reading smaller buffers than what is available. + /// + [Fact] + public void ReadSmallerBuffers() + { + using (var stream = new BlockingMemoryStream()) + { + stream.Write(new byte[] { 1, 2, 3, 4 }, 0, 4); + stream.Write(new byte[] { 5, 6, 7, 8, 9 }, 0, 5); + + byte[] buffer = new byte[3]; + + int count = stream.Read(buffer, 0, buffer.Length); + Assert.Equal(3, count); + Assert.Equal(1, buffer[0]); + Assert.Equal(2, buffer[1]); + Assert.Equal(3, buffer[2]); + + count = stream.Read(buffer, 0, buffer.Length); + Assert.Equal(1, count); + Assert.Equal(4, buffer[0]); + + count = stream.Read(buffer, 0, buffer.Length); + Assert.Equal(3, count); + Assert.Equal(5, buffer[0]); + Assert.Equal(6, buffer[1]); + Assert.Equal(7, buffer[2]); + + count = stream.Read(buffer, 0, buffer.Length); + Assert.Equal(2, count); + Assert.Equal(8, buffer[0]); + Assert.Equal(9, buffer[1]); + } + } + + /// + /// Tests reading will block until the stream is written to. + /// + [Fact] + public void TestReadBlocksUntilWrite() + { + using (var stream = new BlockingMemoryStream()) + { + ManualResetEvent readerThreadExecuting = new ManualResetEvent(false); + bool readerThreadSuccessful = false; + + Thread readerThread = new Thread(() => + { + byte[] buffer = new byte[10]; + readerThreadExecuting.Set(); + int count = stream.Read(buffer, 0, buffer.Length); + + Assert.Equal(3, count); + Assert.Equal(1, buffer[0]); + Assert.Equal(2, buffer[1]); + Assert.Equal(3, buffer[2]); + + readerThreadSuccessful = true; + }); + + readerThread.IsBackground = true; + readerThread.Start(); + + // ensure the thread is executing + readerThreadExecuting.WaitOne(); + + Assert.True(readerThread.IsAlive); + + // give it a little while to ensure it is blocking + Thread.Sleep(10); + Assert.True(readerThread.IsAlive); + + stream.Write(new byte[] { 1, 2, 3 }, 0, 3); + + Assert.True(readerThread.Join(1000)); + Assert.True(readerThreadSuccessful); + } + } + } +} diff --git a/test/Microsoft.DotNet.Cli.Utils.Tests/BuiltInCommandTests.cs b/test/Microsoft.DotNet.Cli.Utils.Tests/BuiltInCommandTests.cs new file mode 100644 index 000000000..78184781f --- /dev/null +++ b/test/Microsoft.DotNet.Cli.Utils.Tests/BuiltInCommandTests.cs @@ -0,0 +1,89 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using Microsoft.DotNet.Tools.Test.Utilities; +using Xunit; + +namespace Microsoft.DotNet.Cli.Utils +{ + public class BuiltInCommandTests : TestBase + { + /// + /// Tests that BuiltInCommand.Execute returns the correct exit code and a + /// valid StartInfo FileName and Arguments. + /// + [Fact] + public void TestExecute() + { + Func testCommand = args => args.Length; + string[] testCommandArgs = new[] { "1", "2" }; + + var builtInCommand = new BuiltInCommand("fakeCommand", testCommandArgs, testCommand); + CommandResult result = builtInCommand.Execute(); + + Assert.Equal(testCommandArgs.Length, result.ExitCode); + Assert.Equal(new Muxer().MuxerPath, result.StartInfo.FileName); + Assert.Equal("fakeCommand 1 2", result.StartInfo.Arguments); + } + + /// + /// Tests that BuiltInCommand.Execute raises the OnOutputLine and OnErrorLine + /// the correct number of times and with the correct content. + /// + [Fact] + public void TestOnOutputLines() + { + int exitCode = 29; + + Func testCommand = args => + { + Console.Out.Write("first"); + Console.Out.WriteLine("second"); + Console.Out.WriteLine("third"); + + Console.Error.WriteLine("fourth"); + Console.Error.WriteLine("fifth"); + + return exitCode; + }; + + int onOutputLineCallCount = 0; + int onErrorLineCallCount = 0; + + CommandResult result = new BuiltInCommand("fakeCommand", Enumerable.Empty(), testCommand) + .OnOutputLine(line => + { + onOutputLineCallCount++; + + if (onOutputLineCallCount == 1) + { + Assert.Equal($"firstsecond", line); + } + else + { + Assert.Equal($"third", line); + } + }) + .OnErrorLine(line => + { + onErrorLineCallCount++; + + if (onErrorLineCallCount == 1) + { + Assert.Equal($"fourth", line); + } + else + { + Assert.Equal($"fifth", line); + } + }) + .Execute(); + + Assert.Equal(exitCode, result.ExitCode); + Assert.Equal(2, onOutputLineCallCount); + Assert.Equal(2, onErrorLineCallCount); + } + } +}