Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auth azure: support access token #3580

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions lisa/sut_orchestrator/azure/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import base64
import hashlib
import json
import os
Expand All @@ -26,6 +27,7 @@

import requests
from assertpy import assert_that
from azure.core.credentials import AccessToken, TokenCredential
from azure.core.exceptions import ResourceExistsError
from azure.keyvault.certificates import (
CertificateClient,
Expand Down Expand Up @@ -1696,6 +1698,7 @@ def get_or_create_storage_container(
"""
Create a Azure Storage container if it does not exist.
"""
credential = get_static_access_token("AZURE_STORAGE_ACCESS_TOKEN") or credential
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the env name variable is different? Is it defined by the staroge class or us?

Copy link
Collaborator Author

@LiliDeng LiliDeng Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is defined by us, the tokens need to be generated based on different scopes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Is it possible to have one token with all scopes like *?
  2. Define schema to accept tokens, instead of env vars. All LISA vars should be defined in runbook.

Let me know, if it's hard to fulfill above.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I have confirmed with the maintainer of python sdk, he said we have to use different tokens for different scopes.
  2. Use different variables to accept the tokens?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Use different variables to accept the tokens?

As I said, define in schema for them, not directly use env vars. After you defined schema, you can set values by LISA vars, and then assign LISA vars by env vars.

Like you defined access_token in PlatformSchema. Please define similar tokens for other Azure scopes.

blob_service_client = get_blob_service_client(
cloud=cloud,
credential=credential,
Expand Down Expand Up @@ -2696,14 +2699,60 @@ def get_size(disk_type: schema.DiskType, data_disk_iops: int = 1) -> int:
raise LisaException(f"Data disk type {disk_type} is unsupported.")


class StaticAccessTokenCredential(TokenCredential):
def __init__(self, token: str) -> None:
squirrelsc marked this conversation as resolved.
Show resolved Hide resolved
"""
Initialize StaticAccessTokenCredential with the provided token.

:param token: The Azure access token as a string.
"""
self._token = token
self._expires_on = self._get_exp()

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"""
Get the access token for the specified scopes.

:param scopes: The OAuth 2.0 scopes the token applies to.
:param kwargs: Additional keyword arguments that may be required by the SDK.
:return: An AccessToken instance containing the token and its expiry time.
"""
# You can choose to print or log the scopes and kwargs for debugging if needed
return AccessToken(self._token, self._expires_on)

def _get_exp(self) -> Any:
# The second part of the JWT is the payload
payload = self._token.split(".")[1]
# Add padding to ensure Base64 decoding works properly
padded_payload = payload + "=" * (4 - len(payload) % 4)
# Decode the Base64 URL-safe encoded payload
decoded_payload = base64.urlsafe_b64decode(padded_payload)
# Convert the payload into a dictionary and get the expiration time
# 'exp' is the UNIX timestamp for expiration
return json.loads(decoded_payload).get("exp")


def get_static_access_token(token_type: str) -> Any:
credential = None
if token_type in os.environ:
credential = StaticAccessTokenCredential(os.environ[token_type])
return credential


def get_certificate_client(
vault_url: str, platform: "AzurePlatform"
) -> CertificateClient:
return CertificateClient(vault_url, platform.credential)
credential = (
get_static_access_token("AZURE_KEYVAULT_ACCESS_TOKEN") or platform.credential
)
return CertificateClient(vault_url, credential)


def get_secret_client(vault_url: str, platform: "AzurePlatform") -> SecretClient:
return SecretClient(vault_url, platform.credential)
credential = (
get_static_access_token("AZURE_KEYVAULT_ACCESS_TOKEN") or platform.credential
)
return SecretClient(vault_url, credential)


def get_key_vault_management_client(
Expand Down Expand Up @@ -2799,7 +2848,13 @@ def get_identity_id(
else:
endpoint = "me"
graph_api_url = f"{base_url}{api_version}/{endpoint}"
token = platform.credential.get_token("https://graph.microsoft.com/.default").token
credential = (
get_static_access_token("AZURE_GRAPH_ACCESS_TOKEN") or platform.credential
)
if isinstance(credential, StaticAccessTokenCredential):
token = credential._token
else:
token = credential.get_token("https://graph.microsoft.com/.default").token
# Set up the API call headers
headers = {
"Authorization": f"Bearer {token}",
Expand Down Expand Up @@ -3002,9 +3057,10 @@ def create_certificate(
def check_certificate_existence(
vault_url: str, cert_name: str, log: Logger, platform: "AzurePlatform"
) -> bool:
certificate_client = CertificateClient(
vault_url=vault_url, credential=platform.credential
credential = (
get_static_access_token("AZURE_KEYVAULT_ACCESS_TOKEN") or platform.credential
)
certificate_client = CertificateClient(vault_url=vault_url, credential=credential)

try:
certificate = certificate_client.get_certificate(cert_name)
Expand Down
16 changes: 13 additions & 3 deletions lisa/sut_orchestrator/azure/platform_.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast

import requests
from azure.core.credentials import TokenCredential
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.identity import DefaultAzureCredential
from azure.mgmt.compute.models import (
Expand Down Expand Up @@ -116,6 +117,7 @@
get_or_create_storage_container,
get_primary_ip_addresses,
get_resource_management_client,
get_static_access_token,
get_storage_account_name,
get_vhd_details,
get_vm,
Expand Down Expand Up @@ -246,6 +248,7 @@ class AzurePlatformSchema:
),
)
service_principal_key: str = field(default="")
access_token: str = field(default="")
subscription_id: str = field(
default="",
metadata=field_metadata(
Expand Down Expand Up @@ -320,6 +323,7 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
"service_principal_tenant_id",
"service_principal_client_id",
"service_principal_key",
"access_token",
"subscription_id",
"shared_resource_group_name",
"resource_group_name",
Expand All @@ -338,6 +342,8 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
add_secret(self.subscription_id, mask=PATTERN_GUID)
if self.service_principal_key:
add_secret(self.service_principal_key)
if self.access_token:
add_secret(self.access_token)
if self.service_principal_client_id:
add_secret(self.service_principal_client_id, mask=PATTERN_GUID)

Expand Down Expand Up @@ -407,14 +413,14 @@ class AzurePlatform(Platform):
)
_arm_template: Any = None

_credentials: Dict[str, DefaultAzureCredential] = {}
_credentials: Dict[str, Union[DefaultAzureCredential, TokenCredential]] = {}
_locations_data_cache: Dict[str, AzureLocation] = {}

def __init__(self, runbook: schema.Platform) -> None:
super().__init__(runbook=runbook)

# for type detection
self.credential: DefaultAzureCredential
self.credential: Union[DefaultAzureCredential, TokenCredential]
self.cloud: Cloud

# It has to be defined after the class definition is loaded. So it
Expand Down Expand Up @@ -936,8 +942,12 @@ def _initialize_credential(self) -> None:
] = azure_runbook.service_principal_client_id
if azure_runbook.service_principal_key:
os.environ["AZURE_CLIENT_SECRET"] = azure_runbook.service_principal_key
if azure_runbook.access_token:
os.environ["AZURE_ACCESS_TOKEN"] = azure_runbook.access_token

credential = DefaultAzureCredential(
credential = get_static_access_token(
"AZURE_ACCESS_TOKEN"
) or DefaultAzureCredential(
authority=self.cloud.endpoints.active_directory,
)

Expand Down
Loading