// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Build.Framework;
using Microsoft.Build.Utilities;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Security.Cryptography;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.DotNet.Build.CloudTestTasks
{
    /// <summary>
    /// Helpers for Azure Blob Storage: Shared Key request signing, container SAS token
    /// creation, and sending HTTP requests with retries.
    /// </summary>
    public static class AzureHelper
    {
        /// <summary>
        /// The storage api version.
        /// </summary>
        public static readonly string StorageApiVersion = "2015-04-05";

        public const string DateHeaderString = "x-ms-date";

        public const string VersionHeaderString = "x-ms-version";

        public const string AuthorizationHeaderString = "Authorization";

        /// <summary>
        /// Permission set encoded into a container SAS token.
        /// </summary>
        public enum SasAccessType
        {
            Read,
            Write,
        }

        /// <summary>
        /// Computes the Shared Key value for the <c>Authorization</c> header of <paramref name="request"/>.
        /// </summary>
        /// <param name="storageAccount">Storage account name; used in the canonicalized resource and the header value.</param>
        /// <param name="storageKey">Base64-encoded account key used as the HMAC-SHA256 key.</param>
        /// <param name="method">HTTP verb placed first in the string to sign.</param>
        /// <param name="now">Currently not used by this method; kept for backward compatibility.</param>
        /// <param name="request">Request whose <c>x-ms-*</c> headers and URI are canonicalized into the signature.</param>
        /// <param name="ifMatch">Value for the If-Match slot of the string to sign.</param>
        /// <param name="contentMD5">Value for the Content-MD5 slot of the string to sign.</param>
        /// <param name="size">Value for the Content-Length slot of the string to sign.</param>
        /// <param name="contentType">Value for the Content-Type slot of the string to sign.</param>
        /// <returns>A string of the form <c>SharedKey {account}:{base64 signature}</c>.</returns>
        public static string AuthorizationHeader(
            string storageAccount,
            string storageKey,
            string method,
            DateTime now,
            HttpRequestMessage request,
            string ifMatch = "",
            string contentMD5 = "",
            string size = "",
            string contentType = "")
        {
            // Slot order per the Shared Key spec: VERB, Content-Encoding, Content-Language,
            // Content-Length, Content-MD5, Content-Type, Date, If-Modified-Since, If-Match,
            // If-None-Match, If-Unmodified-Since, Range, CanonicalizedHeaders, CanonicalizedResource.
            // Slots not exposed as parameters are always signed as empty.
            string stringToSign = string.Format(
                "{0}\n\n\n{1}\n{5}\n{6}\n\n\n{2}\n\n\n\n{3}{4}",
                method,
                size,
                ifMatch,
                GetCanonicalizedHeaders(request),
                GetCanonicalizedResource(request.RequestUri, storageAccount),
                contentMD5,
                contentType);

            byte[] signatureBytes = Encoding.UTF8.GetBytes(stringToSign);
            using (HMACSHA256 hmacsha256 = new HMACSHA256(Convert.FromBase64String(storageKey)))
            {
                return "SharedKey " + storageAccount + ":" + Convert.ToBase64String(hmacsha256.ComputeHash(signatureBytes));
            }
        }

        /// <summary>
        /// Creates a container-level (<c>sr=c</c>) SAS query string signed with the account key.
        /// </summary>
        /// <param name="accountName">Storage account name.</param>
        /// <param name="containerName">Container the token grants access to.</param>
        /// <param name="key">Base64-encoded account key used as the HMAC-SHA256 key.</param>
        /// <param name="accessType"><see cref="SasAccessType.Read"/> maps to <c>r</c>; <see cref="SasAccessType.Write"/> maps to <c>wdl</c>.</param>
        /// <param name="validityTimeInDays">Days from now until the token expires.</param>
        /// <returns>A query string starting with <c>?sv=</c>, ready to append to a container URI.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="accessType"/> is not a known value.</exception>
        public static string CreateContainerSasToken(
            string accountName,
            string containerName,
            string key,
            SasAccessType accessType,
            int validityTimeInDays)
        {
            string signedPermissions;
            switch (accessType)
            {
                case SasAccessType.Read:
                    signedPermissions = "r";
                    break;
                case SasAccessType.Write:
                    signedPermissions = "wdl";
                    break;
                default:
                    throw new ArgumentOutOfRangeException(nameof(accessType), accessType, "Unrecognized value");
            }

            // Take one snapshot of the clock so start and expiry are derived from the same instant.
            DateTime utcNow = DateTime.UtcNow;
            string signedStart = utcNow.ToString("O");
            string signedExpiry = utcNow.AddDays(validityTimeInDays).ToString("O");
            string canonicalizedResource = "/blob/" + accountName + "/" + containerName;
            string signedIdentifier = string.Empty;
            string signedVersion = StorageApiVersion;

            string stringToSign = ConstructServiceStringToSign(
                signedPermissions,
                signedVersion,
                signedExpiry,
                canonicalizedResource,
                signedIdentifier,
                signedStart);

            byte[] signatureBytes = Encoding.UTF8.GetBytes(stringToSign);
            string signature;
            using (HMACSHA256 hmacSha256 = new HMACSHA256(Convert.FromBase64String(key)))
            {
                signature = Convert.ToBase64String(hmacSha256.ComputeHash(signatureBytes));
            }

            return string.Format(
                "?sv={0}&sr={1}&sig={2}&st={3}&se={4}&sp={5}",
                WebUtility.UrlEncode(signedVersion),
                WebUtility.UrlEncode("c"),
                WebUtility.UrlEncode(signature),
                WebUtility.UrlEncode(signedStart),
                WebUtility.UrlEncode(signedExpiry),
                WebUtility.UrlEncode(signedPermissions));
        }

        /// <summary>
        /// Builds the canonicalized-headers portion of the string to sign.
        /// Each request header whose lower-cased name starts with <c>x-ms-</c> is emitted as
        /// <c>name:value\n</c>, with names sorted and CRLF sequences removed from values.
        /// </summary>
        public static string GetCanonicalizedHeaders(HttpRequestMessage request)
        {
            List<string> headerNameList = request.Headers
                .Select(header => header.Key.ToLowerInvariant())
                .Where(name => name.StartsWith("x-ms-", StringComparison.Ordinal))
                .ToList();

            // NOTE(review): List<string>.Sort() uses the current culture's default comparer, not
            // an ordinal one. Confirm the ordering the service expects before changing this, since
            // any change alters signatures for names containing '-' in differing positions.
            headerNameList.Sort();

            StringBuilder sb = new StringBuilder();
            foreach (string headerName in headerNameList)
            {
                sb.Append(headerName);
                string separator = ":";
                foreach (string headerValue in GetHeaderValues(request.Headers, headerName))
                {
                    sb.Append(separator);
                    sb.Append(headerValue.Replace("\r\n", string.Empty));
                    separator = ",";
                }
                sb.Append("\n");
            }
            return sb.ToString();
        }

        /// <summary>
        /// Builds the canonicalized-resource portion of the string to sign:
        /// <c>/{account}{absolute path}</c> followed by one <c>\nname:values</c> line per query parameter.
        /// Parameter names are lower-cased and sorted; values are URL-decoded, sorted and comma-joined.
        /// </summary>
        public static string GetCanonicalizedResource(Uri address, string accountName)
        {
            StringBuilder str = new StringBuilder("/");
            str.Append(accountName);
            str.Append(address.AbsolutePath);

            Dictionary<string, HashSet<string>> dictionary = GetCommaSeparatedList(ExtractQueryKeyValues(address));
            foreach (KeyValuePair<string, HashSet<string>> pair in dictionary.OrderBy(p => p.Key))
            {
                str.Append("\n");
                str.Append(pair.Key);
                str.Append(":");
                str.Append(string.Join(",", pair.Value));
            }
            return str.ToString();
        }

        /// <summary>
        /// Returns at most one value for <paramref name="headerName"/>: the first value present,
        /// with leading whitespace removed. Returns an empty list when the header is absent.
        /// </summary>
        public static List<string> GetHeaderValues(HttpRequestHeaders headers, string headerName)
        {
            List<string> list = new List<string>();
            IEnumerable<string> values;
            headers.TryGetValues(headerName, out values);
            if (values != null)
            {
                list.Add((values.FirstOrDefault() ?? string.Empty).TrimStart());
            }
            return list;
        }

        /// <summary>
        /// True for 302 redirects and for any 4xx/5xx status, i.e. responses that trigger a retry
        /// when no validation callback is supplied.
        /// </summary>
        private static bool IsWithinRetryRange(HttpStatusCode statusCode)
        {
            int rawStatus = (int)statusCode;
            return rawStatus == 302 || (rawStatus >= 400 && rawStatus <= 599);
        }

        /// <summary>
        /// Sends a request, retrying with a randomized, growing delay until a response is accepted.
        /// </summary>
        /// <param name="loggingHelper">MSBuild logger for progress, warnings and exceptions.</param>
        /// <param name="client">Client used to send each attempt.</param>
        /// <param name="createRequest">Factory invoked once per attempt; each created request is disposed after sending.</param>
        /// <param name="validationCallback">
        /// Optional predicate deciding whether a response is acceptable. When null, a response is
        /// accepted unless its status is 302 or 4xx/5xx.
        /// </param>
        /// <param name="retryCount">Maximum number of attempts (at least 1).</param>
        /// <param name="retryDelaySeconds">Base delay unit in seconds (at least 1).</param>
        /// <returns>The first accepted response; the caller owns and must dispose it.</returns>
        /// <exception cref="ArgumentNullException">A required argument is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="retryCount"/> or <paramref name="retryDelaySeconds"/> is less than 1.</exception>
        /// <exception cref="HttpRequestException">Every attempt returned an unaccepted response.</exception>
        public static async Task<HttpResponseMessage> RequestWithRetry(
            TaskLoggingHelper loggingHelper,
            HttpClient client,
            Func<HttpRequestMessage> createRequest,
            Func<HttpResponseMessage, bool> validationCallback = null,
            int retryCount = 5,
            int retryDelaySeconds = 5)
        {
            if (loggingHelper == null)
                throw new ArgumentNullException(nameof(loggingHelper));
            if (client == null)
                throw new ArgumentNullException(nameof(client));
            if (createRequest == null)
                throw new ArgumentNullException(nameof(createRequest));
            if (retryCount < 1)
                throw new ArgumentOutOfRangeException(nameof(retryCount), retryCount, "Must be at least 1.");
            if (retryDelaySeconds < 1)
                throw new ArgumentOutOfRangeException(nameof(retryDelaySeconds), retryDelaySeconds, "Must be at least 1.");

            int retries = 0;
            HttpResponseMessage response = null;

            // add a bit of randomness to the retry delay
            var rng = new Random();

            while (retries < retryCount)
            {
                if (retries > 0)
                {
                    // Drop the rejected response from the previous attempt before trying again.
                    if (response != null)
                    {
                        response.Dispose();
                        response = null;
                    }

                    int delay = retryDelaySeconds * retries * rng.Next(1, 5);
                    loggingHelper.LogMessage(MessageImportance.Low, "Waiting {0} seconds before retry", delay);
                    // Fully qualified: Microsoft.Build.Utilities also declares a type named Task.
                    await System.Threading.Tasks.Task.Delay(delay * 1000).ConfigureAwait(false);
                }

                try
                {
                    using (var request = createRequest())
                    {
                        response = await client.SendAsync(request).ConfigureAwait(false);
                    }
                }
                catch (Exception e)
                {
                    loggingHelper.LogWarningFromException(e, true);

                    // if this is the final iteration let the exception bubble up
                    if (retries + 1 == retryCount)
                        throw;
                }

                // response can be null if we fail to send the request
                if (response != null)
                {
                    bool accepted;
                    try
                    {
                        accepted = IsAcceptedResponse(loggingHelper, response, validationCallback);
                    }
                    catch
                    {
                        // The caller never receives this response, so release it before propagating.
                        response.Dispose();
                        throw;
                    }

                    if (accepted)
                    {
                        return response;
                    }
                }

                ++retries;
            }

            // retry count exceeded
            loggingHelper.LogWarning("Retry count {0} exceeded", retryCount);

            // set some default values in case response is null
            var statusCode = "None";
            var contentStr = "Null";
            if (response != null)
            {
                using (response)
                {
                    statusCode = response.StatusCode.ToString();
                    if (response.Content != null)
                    {
                        contentStr = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
                    }
                }
            }

            throw new HttpRequestException(string.Format("Request failed with status {0} response {1}", statusCode, contentStr));
        }

        /// <summary>
        /// Decides whether <paramref name="response"/> ends the retry loop, logging the outcome.
        /// Uses <paramref name="validationCallback"/> when supplied, otherwise <see cref="IsWithinRetryRange"/>.
        /// </summary>
        private static bool IsAcceptedResponse(
            TaskLoggingHelper loggingHelper,
            HttpResponseMessage response,
            Func<HttpResponseMessage, bool> validationCallback)
        {
            if (validationCallback == null)
            {
                // check if the response code is within the range of failures
                if (IsWithinRetryRange(response.StatusCode))
                {
                    loggingHelper.LogWarning("Request failed with status code {0}", response.StatusCode);
                    return false;
                }

                loggingHelper.LogMessage(MessageImportance.Low, "Response completed with status code {0}", response.StatusCode);
                return true;
            }

            if (validationCallback(response))
            {
                loggingHelper.LogMessage("Validation callback returned success for status code {0}", response.StatusCode);
                return true;
            }

            loggingHelper.LogMessage("Validation callback returned retry for status code {0}", response.StatusCode);
            return false;
        }

        /// <summary>
        /// Joins the service-SAS fields with '\n' in the order used for the string to sign.
        /// </summary>
        private static string ConstructServiceStringToSign(
            string signedPermissions,
            string signedVersion,
            string signedExpiry,
            string canonicalizedResource,
            string signedIdentifier,
            string signedStart,
            string signedIP = "",
            string signedProtocol = "",
            string rscc = "",
            string rscd = "",
            string rsce = "",
            string rscl = "",
            string rsct = "")
        {
            // constructing string to sign based on spec in https://msdn.microsoft.com/en-us/library/azure/dn140255.aspx
            return string.Join(
                "\n",
                signedPermissions,
                signedStart,
                signedExpiry,
                canonicalizedResource,
                signedIdentifier,
                signedIP,
                signedProtocol,
                signedVersion,
                rscc,
                rscd,
                rsce,
                rscl,
                rsct);
        }

        /// <summary>
        /// Parses the query string of <paramref name="address"/> into name -> set-of-values,
        /// URL-decoding each name and value individually.
        /// </summary>
        private static Dictionary<string, HashSet<string>> ExtractQueryKeyValues(Uri address)
        {
            var values = new Dictionary<string, HashSet<string>>();

            string query = address.Query;
            if (string.IsNullOrEmpty(query))
            {
                return values;
            }

            // Split the still-encoded query first and decode each piece afterwards, so encoded
            // separators ('&', '=', '#') inside a value cannot break parsing, and values keep
            // characters such as '-', '/' or '.' intact.
            string[] segments = query.TrimStart('?').Split(new[] { '&' }, StringSplitOptions.RemoveEmptyEntries);
            foreach (string segment in segments)
            {
                int equalsIndex = segment.IndexOf('=');
                if (equalsIndex <= 0)
                {
                    // No name or no '=' at all: skipped, as the previous regex-based parser did.
                    continue;
                }

                string key = WebUtility.UrlDecode(segment.Substring(0, equalsIndex));
                string value = WebUtility.UrlDecode(segment.Substring(equalsIndex + 1));
                if (string.IsNullOrEmpty(key) || string.IsNullOrEmpty(value))
                {
                    // NOTE(review): empty values were skipped by the previous parser; behavior kept.
                    continue;
                }

                HashSet<string> setOfValues;
                if (values.TryGetValue(key, out setOfValues))
                {
                    setOfValues.Add(value);
                }
                else
                {
                    values.Add(key, new HashSet<string> { value });
                }
            }
            return values;
        }

        /// <summary>
        /// For each query parameter, sorts its values and joins them with ','; the resulting
        /// string is stored under the lower-cased parameter name.
        /// </summary>
        private static Dictionary<string, HashSet<string>> GetCommaSeparatedList(
            Dictionary<string, HashSet<string>> queryKeyValues)
        {
            var dictionary = new Dictionary<string, HashSet<string>>();
            foreach (KeyValuePair<string, HashSet<string>> pair in queryKeyValues)
            {
                List<string> list = new List<string>(pair.Value);
                list.Sort();
                string commaSeparatedValues = string.Join(",", list);
                string key = pair.Key.ToLowerInvariant();

                HashSet<string> existing;
                if (dictionary.TryGetValue(key, out existing))
                {
                    existing.Add(commaSeparatedValues);
                }
                else
                {
                    dictionary.Add(key, new HashSet<string> { commaSeparatedValues });
                }
            }
            return dictionary;
        }
    }
}