Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import {
Common as LocCommon,
Azure as LocAzure,
Fabric as LocFabric,
refreshTokenLabel,
} from "../constants/locConstants";
import * as LocAll from "../constants/locConstants";
import {
Expand All @@ -50,9 +49,8 @@ import {
import { sendActionEvent, sendErrorEvent, startActivity } from "../telemetry/telemetry";

import { ApiStatus } from "../sharedInterfaces/webview";
import { AzureController } from "../azure/azureController";
import { AzureSubscription } from "@microsoft/vscode-azext-azureauth";
import { ConnectionDetails, IConnectionInfo, IToken } from "vscode-mssql";
import { ConnectionDetails, IConnectionInfo } from "vscode-mssql";
import MainController from "../controllers/mainController";
import { ObjectExplorerProvider } from "../objectExplorer/objectExplorerProvider";
import { UserSurvey } from "../nps/userSurvey";
Expand All @@ -62,7 +60,7 @@ import {
getServerTypes,
getDefaultConnection,
} from "../models/connectionInfo";
import { formatEpochSecondsForDisplay, getErrorMessage, uuid } from "../utils/utils";
import { getErrorMessage, uuid } from "../utils/utils";
import { l10n } from "vscode";
import {
CredentialsQuickPickItemType,
Expand Down Expand Up @@ -1967,7 +1965,6 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
}

private async getAzureActionButtons(): Promise<FormItemActionButton[]> {
const self = this;
const actionButtons: FormItemActionButton[] = [];

actionButtons.push({
Expand Down Expand Up @@ -2032,112 +2029,6 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
},
});

if (previewService.isFeatureEnabled(PreviewFeature.UseVscodeAccountsForEntraMFA)) {
return actionButtons;
}

if (
this.state.connectionProfile.authenticationType === AuthenticationType.AzureMFA &&
this.state.connectionProfile.accountId
) {
const account = (await this._mainController.azureAccountService.getAccounts()).find(
(account) => account.displayInfo.userId === this.state.connectionProfile.accountId,
);

if (account) {
let isTokenExpired = false;

async function refreshToken(): Promise<IToken | undefined> {
const account = (
await self._mainController.azureAccountService.getAccounts()
).find(
(account) =>
account.displayInfo.userId === self.state.connectionProfile.accountId,
);

if (account) {
try {
const token =
await self._mainController.azureAccountService.getAccountSecurityToken(
account,
undefined,
);

if (AzureController.isTokenValid(token.token, token.expiresOn)) {
self.vscodeWrapper.showInformationMessage(
Loc.tokenRefreshedSuccessfully,
);

self.logger.log(
`Token refreshed. Next expiration: ${formatEpochSecondsForDisplay(token.expiresOn)}`,
);

return token;
} else {
throw new Error(
Loc.unableToAcquireValidToken(
formatEpochSecondsForDisplay(token.expiresOn),
formatEpochSecondsForDisplay(Date.now() / 1000),
),
);
}
} catch (err) {
self.logger.error(`Error refreshing token: ${getErrorMessage(err)}`);
self.vscodeWrapper.showErrorMessage(
Loc.errorRefreshingToken(getErrorMessage(err)),
);
}
} else {
self.logger.error(
`Account not found when attempting token refresh: ${self.state.connectionProfile.email} (${self.state.connectionProfile.accountId})`,
);
}

return undefined;
}

try {
// Check if token is expired or expiring soon...
const session =
await this._mainController.azureAccountService.getAccountSecurityToken(
account,
undefined,
);

isTokenExpired = !AzureController.isTokenValid(
session.token,
session.expiresOn,
);
} catch (err) {
this.logger.verbose(
`Error getting token or checking validity; prompting for refresh. Error: ${getErrorMessage(err)}`,
);

void this.vscodeWrapper
.showErrorMessage(
Loc.errorValidatingEntraToken(getErrorMessage(err)),
refreshTokenLabel,
)
.then((result) => {
if (result === refreshTokenLabel) {
void refreshToken();
}
});

isTokenExpired = true;
}

if (isTokenExpired) {
actionButtons.push({
label: refreshTokenLabel,
id: "refreshToken",
callback: async () => {
await refreshToken();
},
});
}
}
}
return actionButtons;
}

Expand Down
194 changes: 33 additions & 161 deletions extensions/mssql/src/controllers/connectionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ import {
TelemetryActions,
TelemetryViews,
} from "../sharedInterfaces/telemetry";
import { ApiStatus, isStatus, Status } from "../sharedInterfaces/webview";
import { ObjectExplorerUtils } from "../objectExplorer/objectExplorerUtils";
import { changeLanguageServiceForFile } from "../languageservice/utils";
import { AddFirewallRuleWebviewController } from "./addFirewallRuleWebviewController";
Expand Down Expand Up @@ -126,11 +125,6 @@ export default class ConnectionManager {
Deferred<ConnectionContracts.ConnectionCompleteParams>
>;
private _keyVaultTokenCache: Map<string, IToken> = new Map<string, IToken>();
private _entraSqlTokenCache: Map<string, IToken> = new Map<string, IToken>();
private _entraSqlTokenRefreshInFlight: Map<string, Promise<IToken>> = new Map<
string,
Promise<IToken>
>();
private _accountService: AccountService;
private _firewallService: FirewallService;
public azureController: AzureController;
Expand Down Expand Up @@ -1047,19 +1041,21 @@ export default class ConnectionManager {
return;
}

// 2. Validate that the token needs refreshing (isn't expired)
if (
AzureController.isTokenValid(connectionInfo.azureAccountToken, connectionInfo.expiresOn)
) {
this._logger?.verbose(
`Entra token for account ${connectionInfo.user} (${connectionInfo.email}) is still valid until ${connectionInfo.expiresOn}. No refresh needed.`,
);
return;
}

// 3. Refresh the token
// A3. If the user is using VS Code accounts for Entra MFA, use that flow to refresh the token
// 2. If the user is using VS Code accounts for Entra MFA, use that flow to refresh the token.
// STS cannot read VS Code auth sessions, so this path still needs to pass a token.
if (previewService.isFeatureEnabled(PreviewFeature.UseVscodeAccountsForEntraMFA)) {
if (
AzureController.isTokenValid(
connectionInfo.azureAccountToken,
connectionInfo.expiresOn,
)
) {
this._logger?.verbose(
`Entra token for account ${connectionInfo.user} (${connectionInfo.email}) is still valid until ${connectionInfo.expiresOn}. No refresh needed.`,
);
return;
}

const tokenInfo = await acquireSqlAccessTokenFromVscodeAccount(
connectionInfo.accountId,
connectionInfo.tenantId,
Expand All @@ -1076,14 +1072,12 @@ export default class ConnectionManager {
return;
}

// B3. Otherwise, use the MSAL flow to refresh the token
// B3.1 Collect Entra account information
// 3. Otherwise, use the MSAL flow. STS registers a SqlAuthenticationProvider
// and reads the shared MSAL cache, so do not pass a pre-acquired SQL token.
let account: IAccount | undefined;
let profile: ConnectionProfile;

if (connectionInfo.accountId) {
account = await this.accountStore.getAccount(connectionInfo.accountId);
profile = new ConnectionProfile(connectionInfo);
} else {
// Send telemetry to identify code paths where accountId is missing
sendErrorEvent(
Expand All @@ -1109,150 +1103,29 @@ export default class ConnectionManager {
//LocalizedConstants.msgAccountNotFound
}

connectionInfo.user = account.displayInfo.displayName;
connectionInfo.user =
account.displayInfo.email ??
account.displayInfo.userId ??
account.displayInfo.displayName;
connectionInfo.email = account.displayInfo.email;
profile.user = account.displayInfo.displayName;
profile.email = account.displayInfo.email;
connectionInfo.tenantId ??= account.properties?.owningTenant?.id;

// B4. Use cached token if present and valid/unexpired
const cacheKey = this.getEntraSqlTokenCacheKey(
connectionInfo,
account.properties?.owningTenant?.id,
// Keep the MSAL account cache fresh before handing SQL token acquisition to STS.
// This may prompt for sign-in if the account can no longer refresh silently,
// but the returned non-SQL token is intentionally not sent to SqlClient.
await this.azureController.refreshAccessToken(
account,
this.accountStore,
connectionInfo.tenantId,
getCloudProviderSettings(account.key.providerId).settings.armResource,
Comment on lines +1116 to +1120
);
const cachedToken = this._entraSqlTokenCache.get(cacheKey);

connectionInfo.azureAccountToken = undefined;
connectionInfo.expiresOn = undefined;
Comment thread
aasimkhan30 marked this conversation as resolved.

this._logger?.verbose(
`Cached token ${cachedToken ? "found" : "not found"} for cache key ${cacheKey}.`,
`Using SQL Authentication Provider for MSAL Entra account ${connectionInfo.user} and tenant ${connectionInfo.tenantId}.`,
);

if (cachedToken) {
// If there's a cached token, use it if still valid, or remove it from cache if expired
if (AzureController.isTokenValid(cachedToken.token, cachedToken.expiresOn)) {
this.applyEntraToken(connectionInfo, cachedToken);
this._logger?.verbose(
`Using cached Entra token for account ${account.displayInfo.displayName} (${account.displayInfo.email}) and tenant ${profile.tenantId}. Cached token expires on ${cachedToken.expiresOn}. (currently ${Date.now() / 1000})`,
);

return;
} else {
this._logger?.verbose(
`Cached token for cache key ${cacheKey} is expired. Removing from cache. (currently ${Date.now() / 1000})`,
);
this._entraSqlTokenCache.delete(cacheKey);
}
}

// B5. Lastly, refresh the token, cache the new token, and update the connection info with it
const refreshTask = async () => {
return await this.azureController.refreshAccessToken(
account,
this.accountStore,
profile.tenantId,
getCloudProviderSettings(account.key.providerId).settings.sqlResource!,
);
};

// Dedupe concurrent token refresh requests for the same account into a single request, and share the result
let refreshPromise = this._entraSqlTokenRefreshInFlight.get(cacheKey);
if (!refreshPromise) {
// Token refresh code cannot figure out if the user closed the browser window,
// so we wrap it in a cancellable progress dialog to allow the user to cancel
// the operation.
refreshPromise = new Promise<IToken>((resolve, reject) => {
void vscode.window.withProgress(
{
location: vscode.ProgressLocation.Notification,
title: LocalizedConstants.ObjectExplorer.AzureSignInMessage(
account.displayInfo.displayName || account.displayInfo.email,
),
cancellable: true,
},
async (_progress, token) => {
token.onCancellationRequested(() => {
reject({
status: ApiStatus.Cancelled,
message: "Azure sign in cancelled by user.",
} as Status);
});
try {
const refreshedToken = await refreshTask();
if (!refreshedToken) {
reject({
status: ApiStatus.Error,
message: LocalizedConstants.msgAccountRefreshFailed(),
} as Status);
return;
}
this._logger?.verbose(
`Successfully refreshed Entra token for account ${account.displayInfo.displayName} (${account.displayInfo.email}) and tenant ${profile.tenantId}; now expires on ${refreshedToken.expiresOn} (currently ${Date.now() / 1000}).`,
);
resolve(refreshedToken);
} catch (error) {
const refreshErrorStatus: Status = {
status: ApiStatus.Error,
message: getErrorMessage(error),
};
this._logger?.error(
`Error refreshing Entra token for account ${account.displayInfo.displayName} (${account.displayInfo.email}) and tenant ${profile.tenantId}: ${refreshErrorStatus.message}`,
);
reject(refreshErrorStatus);
}
},
);
}).finally(() => {
this._entraSqlTokenRefreshInFlight.delete(cacheKey);
});
this._entraSqlTokenRefreshInFlight.set(cacheKey, refreshPromise);
}

try {
const azureAccountToken = await refreshPromise;
this.applyEntraToken(connectionInfo, azureAccountToken);
// Save refreshed token so other connections for the same account+tenant can reuse it.
this._entraSqlTokenCache.set(cacheKey, azureAccountToken);
this._logger?.verbose(
`Successfully refreshed Entra token for account ${account.displayInfo.displayName} (${account.displayInfo.email}) and tenant ${profile.tenantId}. Cached token for future use with cache key ${cacheKey}.`,
);
} catch (error) {
this._logger?.verbose(
`Failed to refresh Entra token for account ${account.displayInfo.displayName} (${account.displayInfo.email}) and tenant ${profile.tenantId}. Error: ${getErrorMessage(error)}`,
);
if (isStatus(error)) {
if (error.status === ApiStatus.Cancelled) {
this._logger.verbose("Refresh cancelled: " + error.message);
throw new Error(LocalizedConstants.cannotConnect);
}

if (error.status === ApiStatus.Error) {
const message = LocalizedConstants.msgAccountRefreshFailed(error.message);
this._logger.error("Error refreshing account: " + message);
await this.vscodeWrapper.showErrorMessage(message);
throw new Error(message);
}
}

throw error;
}
}

private getEntraSqlTokenCacheKey(
connectionInfo: IConnectionInfo,
defaultTenantId?: string,
): string {
return `${connectionInfo.accountId ?? ""}|${connectionInfo.tenantId ?? defaultTenantId ?? ""}`;
}

private applyEntraToken(connectionInfo: IConnectionInfo, token: IToken): void {
connectionInfo.azureAccountToken = token.token;
connectionInfo.expiresOn = token.expiresOn;
}

/**
* Clears both token entries and any in-flight refresh promises.
*/
private clearEntraSqlTokenCache(): void {
this._entraSqlTokenCache.clear();
this._entraSqlTokenRefreshInFlight.clear();
}

/**
Expand Down Expand Up @@ -2033,7 +1906,6 @@ export default class ConnectionManager {

public onClearAzureTokenCache(): void {
this.azureController.clearTokenCache();
this.clearEntraSqlTokenCache();
this.vscodeWrapper.showInformationMessage(
LocalizedConstants.Accounts.clearedEntraTokenCache,
);
Expand Down
Loading
Loading