fix: use BlockedRequest struct to handle webRequest data (#42750)

* refactor: use BlockedRequest model to handle webRequest

Co-authored-by: Shelley Vohr <shelley.vohr@gmail.com>

* refactor: finish de-templating

Co-authored-by: Shelley Vohr <shelley.vohr@gmail.com>

* chore: address some feedback from review

Co-authored-by: Shelley Vohr <shelley.vohr@gmail.com>

---------

Co-authored-by: trop[bot] <37223003+trop[bot]@users.noreply.github.com>
Co-authored-by: Shelley Vohr <shelley.vohr@gmail.com>
This commit is contained in:
trop[bot] 2024-07-03 13:03:09 +02:00 committed by GitHub
parent de6e6b60bc
commit 57e859d0af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 304 additions and 111 deletions

View file

@ -185,37 +185,41 @@ void FillDetails(gin_helper::Dictionary* details, Arg arg, Args... args) {
FillDetails(details, args...); FillDetails(details, args...);
} }
// Fill the native types with the result from the response object. // Modified from extensions/browser/api/web_request/web_request_api_helpers.cc.
void ReadFromResponse(v8::Isolate* isolate, std::pair<std::set<std::string>, std::set<std::string>>
gin::Dictionary* response, CalculateOnBeforeSendHeadersDelta(const net::HttpRequestHeaders* old_headers,
GURL* new_location) { const net::HttpRequestHeaders* new_headers) {
response->Get("redirectURL", new_location); // Newly introduced or overridden request headers.
} std::set<std::string> modified_request_headers;
// Keys of request headers to be deleted.
std::set<std::string> deleted_request_headers;
void ReadFromResponse(v8::Isolate* isolate, // The event listener might not have passed any new headers if it
gin::Dictionary* response, // just wanted to cancel the request.
net::HttpRequestHeaders* headers) { if (new_headers) {
v8::Local<v8::Value> value; // Find deleted headers.
if (response->Get("requestHeaders", &value) && value->IsObject()) { {
headers->Clear(); net::HttpRequestHeaders::Iterator i(*old_headers);
gin::Converter<net::HttpRequestHeaders>::FromV8(isolate, value, headers); while (i.GetNext()) {
} if (!new_headers->HasHeader(i.name())) {
} deleted_request_headers.insert(i.name());
}
}
}
void ReadFromResponse(v8::Isolate* isolate, // Find modified headers.
gin::Dictionary* response, {
const std::pair<scoped_refptr<net::HttpResponseHeaders>*, net::HttpRequestHeaders::Iterator i(*new_headers);
const std::string>& headers) { while (i.GetNext()) {
std::string status_line; std::string value;
if (!response->Get("statusLine", &status_line)) if (!old_headers->GetHeader(i.name(), &value) || i.value() != value) {
status_line = headers.second; modified_request_headers.insert(i.name());
v8::Local<v8::Value> value; }
if (response->Get("responseHeaders", &value) && value->IsObject()) { }
*headers.first = new net::HttpResponseHeaders(""); }
(*headers.first)->ReplaceStatusLine(status_line);
gin::Converter<net::HttpResponseHeaders*>::FromV8(isolate, value,
(*headers.first).get());
} }
return std::make_pair(modified_request_headers, deleted_request_headers);
} }
} // namespace } // namespace
@ -260,6 +264,24 @@ bool WebRequest::RequestFilter::MatchesRequest(
return MatchesURL(info->url) && MatchesType(info->web_request_type); return MatchesURL(info->url) && MatchesType(info->web_request_type);
} }
struct WebRequest::BlockedRequest {
BlockedRequest() = default;
raw_ptr<const extensions::WebRequestInfo> request = nullptr;
net::CompletionOnceCallback callback;
// Only used for onBeforeSendHeaders.
BeforeSendHeadersCallback before_send_headers_callback;
// Only used for onBeforeSendHeaders.
raw_ptr<net::HttpRequestHeaders> request_headers = nullptr;
// Only used for onHeadersReceived.
scoped_refptr<const net::HttpResponseHeaders> original_response_headers;
// Only used for onHeadersReceived.
raw_ptr<scoped_refptr<net::HttpResponseHeaders>> override_response_headers =
nullptr;
std::string status_line;
// Only used for onBeforeRequest.
raw_ptr<GURL> new_url = nullptr;
};
WebRequest::SimpleListenerInfo::SimpleListenerInfo(RequestFilter filter_, WebRequest::SimpleListenerInfo::SimpleListenerInfo(RequestFilter filter_,
SimpleListener listener_) SimpleListener listener_)
: filter(std::move(filter_)), listener(listener_) {} : filter(std::move(filter_)), listener(listener_) {}
@ -320,19 +342,152 @@ int WebRequest::OnBeforeRequest(extensions::WebRequestInfo* info,
const network::ResourceRequest& request, const network::ResourceRequest& request,
net::CompletionOnceCallback callback, net::CompletionOnceCallback callback,
GURL* new_url) { GURL* new_url) {
return HandleResponseEvent(ResponseEvent::kOnBeforeRequest, info, return HandleOnBeforeRequestResponseEvent(info, request, std::move(callback),
std::move(callback), new_url, request); new_url);
}
int WebRequest::HandleOnBeforeRequestResponseEvent(
extensions::WebRequestInfo* request_info,
const network::ResourceRequest& request,
net::CompletionOnceCallback callback,
GURL* new_url) {
const auto iter = response_listeners_.find(ResponseEvent::kOnBeforeRequest);
if (iter == std::end(response_listeners_))
return net::OK;
const auto& info = iter->second;
if (!info.filter.MatchesRequest(request_info))
return net::OK;
BlockedRequest blocked_request;
blocked_request.callback = std::move(callback);
blocked_request.new_url = new_url;
blocked_requests_[request_info->id] = std::move(blocked_request);
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope handle_scope(isolate);
gin_helper::Dictionary details(isolate, v8::Object::New(isolate));
FillDetails(&details, request_info, request, *new_url);
ResponseCallback response =
base::BindOnce(&WebRequest::OnBeforeRequestListenerResult,
base::Unretained(this), request_info->id);
info.listener.Run(gin::ConvertToV8(isolate, details), std::move(response));
return net::ERR_IO_PENDING;
}
void WebRequest::OnBeforeRequestListenerResult(uint64_t id,
v8::Local<v8::Value> response) {
const auto iter = blocked_requests_.find(id);
if (iter == std::end(blocked_requests_))
return;
auto& request = iter->second;
int result = net::OK;
if (response->IsObject()) {
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
gin::Dictionary dict(isolate, response.As<v8::Object>());
bool cancel = false;
dict.Get("cancel", &cancel);
if (cancel) {
result = net::ERR_BLOCKED_BY_CLIENT;
} else {
dict.Get("redirectURL", request.new_url.get());
}
}
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(std::move(request.callback), result));
blocked_requests_.erase(iter);
} }
int WebRequest::OnBeforeSendHeaders(extensions::WebRequestInfo* info, int WebRequest::OnBeforeSendHeaders(extensions::WebRequestInfo* info,
const network::ResourceRequest& request, const network::ResourceRequest& request,
BeforeSendHeadersCallback callback, BeforeSendHeadersCallback callback,
net::HttpRequestHeaders* headers) { net::HttpRequestHeaders* headers) {
return HandleResponseEvent( return HandleOnBeforeSendHeadersResponseEvent(info, request,
ResponseEvent::kOnBeforeSendHeaders, info, std::move(callback), headers);
base::BindOnce(std::move(callback), std::set<std::string>(), }
std::set<std::string>()),
headers, request, *headers); int WebRequest::HandleOnBeforeSendHeadersResponseEvent(
extensions::WebRequestInfo* request_info,
const network::ResourceRequest& request,
BeforeSendHeadersCallback callback,
net::HttpRequestHeaders* headers) {
const auto iter =
response_listeners_.find(ResponseEvent::kOnBeforeSendHeaders);
if (iter == std::end(response_listeners_))
return net::OK;
const auto& info = iter->second;
if (!info.filter.MatchesRequest(request_info))
return net::OK;
BlockedRequest blocked_request;
blocked_request.before_send_headers_callback = std::move(callback);
blocked_request.request_headers = headers;
blocked_requests_[request_info->id] = std::move(blocked_request);
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope handle_scope(isolate);
gin_helper::Dictionary details(isolate, v8::Object::New(isolate));
FillDetails(&details, request_info, request, *headers);
ResponseCallback response =
base::BindOnce(&WebRequest::OnBeforeSendHeadersListenerResult,
base::Unretained(this), request_info->id);
info.listener.Run(gin::ConvertToV8(isolate, details), std::move(response));
return net::ERR_IO_PENDING;
}
void WebRequest::OnBeforeSendHeadersListenerResult(
uint64_t id,
v8::Local<v8::Value> response) {
const auto iter = blocked_requests_.find(id);
if (iter == std::end(blocked_requests_))
return;
auto& request = iter->second;
net::HttpRequestHeaders* old_headers = request.request_headers;
net::HttpRequestHeaders new_headers;
int result = net::OK;
bool user_modified_headers = false;
if (response->IsObject()) {
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
gin::Dictionary dict(isolate, response.As<v8::Object>());
bool cancel = false;
dict.Get("cancel", &cancel);
if (cancel) {
result = net::ERR_BLOCKED_BY_CLIENT;
} else {
v8::Local<v8::Value> value;
if (dict.Get("requestHeaders", &value) && value->IsObject()) {
user_modified_headers = true;
gin::Converter<net::HttpRequestHeaders>::FromV8(isolate, value,
&new_headers);
}
}
}
// If the user passes |cancel|, |new_headers| should be nullptr.
const auto updated_headers = CalculateOnBeforeSendHeadersDelta(
old_headers,
result == net::ERR_BLOCKED_BY_CLIENT ? nullptr : &new_headers);
// Leave |request.request_headers| unchanged if the user didn't modify it.
if (user_modified_headers)
request.request_headers->Swap(&new_headers);
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE,
base::BindOnce(std::move(request.before_send_headers_callback),
updated_headers.first, updated_headers.second, result));
blocked_requests_.erase(iter);
} }
int WebRequest::OnHeadersReceived( int WebRequest::OnHeadersReceived(
@ -342,12 +497,86 @@ int WebRequest::OnHeadersReceived(
const net::HttpResponseHeaders* original_response_headers, const net::HttpResponseHeaders* original_response_headers,
scoped_refptr<net::HttpResponseHeaders>* override_response_headers, scoped_refptr<net::HttpResponseHeaders>* override_response_headers,
GURL* allowed_unsafe_redirect_url) { GURL* allowed_unsafe_redirect_url) {
const std::string& status_line = return HandleOnHeadersReceivedResponseEvent(
original_response_headers ? original_response_headers->GetStatusLine() info, request, std::move(callback), original_response_headers,
: std::string(); override_response_headers);
return HandleResponseEvent( }
ResponseEvent::kOnHeadersReceived, info, std::move(callback),
std::make_pair(override_response_headers, status_line), request); int WebRequest::HandleOnHeadersReceivedResponseEvent(
extensions::WebRequestInfo* request_info,
const network::ResourceRequest& request,
net::CompletionOnceCallback callback,
const net::HttpResponseHeaders* original_response_headers,
scoped_refptr<net::HttpResponseHeaders>* override_response_headers) {
const auto iter = response_listeners_.find(ResponseEvent::kOnHeadersReceived);
if (iter == std::end(response_listeners_))
return net::OK;
const auto& info = iter->second;
if (!info.filter.MatchesRequest(request_info))
return net::OK;
BlockedRequest blocked_request;
blocked_request.callback = std::move(callback);
blocked_request.override_response_headers = override_response_headers;
blocked_request.status_line = original_response_headers
? original_response_headers->GetStatusLine()
: std::string();
blocked_requests_[request_info->id] = std::move(blocked_request);
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope handle_scope(isolate);
gin_helper::Dictionary details(isolate, v8::Object::New(isolate));
FillDetails(&details, request_info, request);
ResponseCallback response =
base::BindOnce(&WebRequest::OnHeadersReceivedListenerResult,
base::Unretained(this), request_info->id);
info.listener.Run(gin::ConvertToV8(isolate, details), std::move(response));
return net::ERR_IO_PENDING;
}
void WebRequest::OnHeadersReceivedListenerResult(
uint64_t id,
v8::Local<v8::Value> response) {
const auto iter = blocked_requests_.find(id);
if (iter == std::end(blocked_requests_))
return;
auto& request = iter->second;
int result = net::OK;
bool user_modified_headers = false;
scoped_refptr<net::HttpResponseHeaders> override_headers(
new net::HttpResponseHeaders(""));
if (response->IsObject()) {
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
gin::Dictionary dict(isolate, response.As<v8::Object>());
bool cancel = false;
dict.Get("cancel", &cancel);
if (cancel) {
result = net::ERR_BLOCKED_BY_CLIENT;
} else {
std::string status_line;
if (!dict.Get("statusLine", &status_line))
status_line = request.status_line;
v8::Local<v8::Value> value;
if (dict.Get("responseHeaders", &value) && value->IsObject()) {
user_modified_headers = true;
override_headers->ReplaceStatusLine(status_line);
gin::Converter<net::HttpResponseHeaders*>::FromV8(
isolate, value, override_headers.get());
}
}
}
if (user_modified_headers)
request.override_response_headers->swap(override_headers);
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(std::move(request.callback), result));
blocked_requests_.erase(iter);
} }
void WebRequest::OnSendHeaders(extensions::WebRequestInfo* info, void WebRequest::OnSendHeaders(extensions::WebRequestInfo* info,
@ -371,7 +600,7 @@ void WebRequest::OnResponseStarted(extensions::WebRequestInfo* info,
void WebRequest::OnErrorOccurred(extensions::WebRequestInfo* info, void WebRequest::OnErrorOccurred(extensions::WebRequestInfo* info,
const network::ResourceRequest& request, const network::ResourceRequest& request,
int net_error) { int net_error) {
callbacks_.erase(info->id); blocked_requests_.erase(info->id);
HandleSimpleEvent(SimpleEvent::kOnErrorOccurred, info, request, net_error); HandleSimpleEvent(SimpleEvent::kOnErrorOccurred, info, request, net_error);
} }
@ -379,13 +608,13 @@ void WebRequest::OnErrorOccurred(extensions::WebRequestInfo* info,
void WebRequest::OnCompleted(extensions::WebRequestInfo* info, void WebRequest::OnCompleted(extensions::WebRequestInfo* info,
const network::ResourceRequest& request, const network::ResourceRequest& request,
int net_error) { int net_error) {
callbacks_.erase(info->id); blocked_requests_.erase(info->id);
HandleSimpleEvent(SimpleEvent::kOnCompleted, info, request, net_error); HandleSimpleEvent(SimpleEvent::kOnCompleted, info, request, net_error);
} }
void WebRequest::OnRequestWillBeDestroyed(extensions::WebRequestInfo* info) { void WebRequest::OnRequestWillBeDestroyed(extensions::WebRequestInfo* info) {
callbacks_.erase(info->id); blocked_requests_.erase(info->id);
} }
template <WebRequest::SimpleEvent event> template <WebRequest::SimpleEvent event>
@ -479,62 +708,6 @@ void WebRequest::HandleSimpleEvent(SimpleEvent event,
info.listener.Run(gin::ConvertToV8(isolate, details)); info.listener.Run(gin::ConvertToV8(isolate, details));
} }
template <typename Out, typename... Args>
int WebRequest::HandleResponseEvent(ResponseEvent event,
extensions::WebRequestInfo* request_info,
net::CompletionOnceCallback callback,
Out out,
Args... args) {
const auto iter = response_listeners_.find(event);
if (iter == std::end(response_listeners_))
return net::OK;
const auto& info = iter->second;
if (!info.filter.MatchesRequest(request_info))
return net::OK;
callbacks_[request_info->id] = std::move(callback);
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope handle_scope(isolate);
gin_helper::Dictionary details(isolate, v8::Object::New(isolate));
FillDetails(&details, request_info, args...);
ResponseCallback response =
base::BindOnce(&WebRequest::OnListenerResult<Out>, base::Unretained(this),
request_info->id, out);
info.listener.Run(gin::ConvertToV8(isolate, details), std::move(response));
return net::ERR_IO_PENDING;
}
template <typename T>
void WebRequest::OnListenerResult(uint64_t id,
T out,
v8::Local<v8::Value> response) {
const auto iter = callbacks_.find(id);
if (iter == std::end(callbacks_))
return;
int result = net::OK;
if (response->IsObject()) {
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
gin::Dictionary dict(isolate, response.As<v8::Object>());
bool cancel = false;
dict.Get("cancel", &cancel);
if (cancel)
result = net::ERR_BLOCKED_BY_CLIENT;
else
ReadFromResponse(isolate, &dict, out);
}
// The ProxyingURLLoaderFactory expects the callback to be executed
// asynchronously, because it used to work on IO thread before NetworkService.
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(std::move(callbacks_[id]), result));
callbacks_.erase(iter);
}
// static // static
gin::Handle<WebRequest> WebRequest::FromOrCreate( gin::Handle<WebRequest> WebRequest::FromOrCreate(
v8::Isolate* isolate, v8::Isolate* isolate,

View file

@ -84,6 +84,10 @@ class WebRequest : public gin::Wrappable<WebRequest>, public WebRequestAPI {
WebRequest(v8::Isolate* isolate, content::BrowserContext* browser_context); WebRequest(v8::Isolate* isolate, content::BrowserContext* browser_context);
~WebRequest() override; ~WebRequest() override;
// Contains info about requests that are blocked waiting for a response from
// the user.
struct BlockedRequest;
enum class SimpleEvent { enum class SimpleEvent {
kOnSendHeaders, kOnSendHeaders,
kOnBeforeRedirect, kOnBeforeRedirect,
@ -91,6 +95,7 @@ class WebRequest : public gin::Wrappable<WebRequest>, public WebRequestAPI {
kOnCompleted, kOnCompleted,
kOnErrorOccurred, kOnErrorOccurred,
}; };
enum class ResponseEvent { enum class ResponseEvent {
kOnBeforeRequest, kOnBeforeRequest,
kOnBeforeSendHeaders, kOnBeforeSendHeaders,
@ -113,15 +118,30 @@ class WebRequest : public gin::Wrappable<WebRequest>, public WebRequestAPI {
void HandleSimpleEvent(SimpleEvent event, void HandleSimpleEvent(SimpleEvent event,
extensions::WebRequestInfo* info, extensions::WebRequestInfo* info,
Args... args); Args... args);
template <typename Out, typename... Args>
int HandleResponseEvent(ResponseEvent event,
extensions::WebRequestInfo* info,
net::CompletionOnceCallback callback,
Out out,
Args... args);
template <typename T> int HandleOnBeforeRequestResponseEvent(
void OnListenerResult(uint64_t id, T out, v8::Local<v8::Value> response); extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
net::CompletionOnceCallback callback,
GURL* redirect_url);
int HandleOnBeforeSendHeadersResponseEvent(
extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
BeforeSendHeadersCallback callback,
net::HttpRequestHeaders* headers);
int HandleOnHeadersReceivedResponseEvent(
extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
net::CompletionOnceCallback callback,
const net::HttpResponseHeaders* original_response_headers,
scoped_refptr<net::HttpResponseHeaders>* override_response_headers);
void OnBeforeRequestListenerResult(uint64_t id,
v8::Local<v8::Value> response);
void OnBeforeSendHeadersListenerResult(uint64_t id,
v8::Local<v8::Value> response);
void OnHeadersReceivedListenerResult(uint64_t id,
v8::Local<v8::Value> response);
class RequestFilter { class RequestFilter {
public: public:
@ -164,7 +184,7 @@ class WebRequest : public gin::Wrappable<WebRequest>, public WebRequestAPI {
std::map<SimpleEvent, SimpleListenerInfo> simple_listeners_; std::map<SimpleEvent, SimpleListenerInfo> simple_listeners_;
std::map<ResponseEvent, ResponseListenerInfo> response_listeners_; std::map<ResponseEvent, ResponseListenerInfo> response_listeners_;
std::map<uint64_t, net::CompletionOnceCallback> callbacks_; std::map<uint64_t, BlockedRequest> blocked_requests_;
// Weak-ref, it manages us. // Weak-ref, it manages us.
raw_ptr<content::BrowserContext> browser_context_; raw_ptr<content::BrowserContext> browser_context_;

View file

@ -328,7 +328,7 @@ describe('webRequest module', () => {
ses.webRequest.onBeforeSendHeaders((details, callback) => { ses.webRequest.onBeforeSendHeaders((details, callback) => {
const requestHeaders = details.requestHeaders; const requestHeaders = details.requestHeaders;
requestHeaders.Accept = '*/*;test/header'; requestHeaders.Accept = '*/*;test/header';
callback({ requestHeaders: requestHeaders }); callback({ requestHeaders });
}); });
const { data } = await ajax('no-cors://fake-host/redirect'); const { data } = await ajax('no-cors://fake-host/redirect');
expect(data).to.equal('header-received'); expect(data).to.equal('header-received');
@ -341,7 +341,7 @@ describe('webRequest module', () => {
ses.webRequest.onBeforeSendHeaders((details, callback) => { ses.webRequest.onBeforeSendHeaders((details, callback) => {
const requestHeaders = details.requestHeaders; const requestHeaders = details.requestHeaders;
requestHeaders.Origin = 'http://new-origin'; requestHeaders.Origin = 'http://new-origin';
callback({ requestHeaders: requestHeaders }); callback({ requestHeaders });
}); });
const { data } = await ajax(defaultURL); const { data } = await ajax(defaultURL);
expect(data).to.equal('/new/origin'); expect(data).to.equal('/new/origin');