cirro.auth

 1from io import StringIO
 2from typing import Optional
 3
 4from cirro.auth.access_token import AccessTokenAuth
 5from cirro.auth.device_code import DeviceCodeAuth
 6
 7__all__ = [
 8    'get_auth_info_from_config',
 9    "DeviceCodeAuth",
10    "AccessTokenAuth",
11]
12
13from cirro.config import AppConfig
14
15
16def get_auth_info_from_config(app_config: AppConfig, auth_io: Optional[StringIO] = None):
17    user_config = app_config.user_config
18    if not user_config or not user_config.auth_method:
19        return DeviceCodeAuth(region=app_config.region,
20                              client_id=app_config.client_id,
21                              auth_endpoint=app_config.auth_endpoint,
22                              auth_io=auth_io)
23
24    auth_methods = [
25        DeviceCodeAuth
26    ]
27    matched_auth_method = next((m for m in auth_methods if m.__name__ == user_config.auth_method), None)
28    if not matched_auth_method:
29        # Backwards compatibility
30        if user_config.auth_method == 'ClientAuth':
31            matched_auth_method = DeviceCodeAuth
32        else:
33            raise RuntimeError(f'{user_config.auth_method} not found, please re-run configuration')
34
35    auth_config = user_config.auth_method_config
36
37    if matched_auth_method == DeviceCodeAuth:
38        return DeviceCodeAuth(region=app_config.region,
39                              client_id=app_config.client_id,
40                              auth_endpoint=app_config.auth_endpoint,
41                              enable_cache=auth_config.get('enable_cache') == 'True',
42                              auth_io=auth_io)
def get_auth_info_from_config( app_config: cirro.config.AppConfig, auth_io: Optional[_io.StringIO] = None):
def get_auth_info_from_config(app_config: AppConfig, auth_io: Optional[StringIO] = None):
    """
    Build the AuthInfo object described by the user's saved configuration.

    When no user config (or no auth method) has been saved, falls back to
    DeviceCodeAuth with token caching left at its default (disabled).

    :param app_config: Application config providing the region, client ID,
        auth endpoint and the saved user config
    :param auth_io: Optionally provide a StringIO object for the authentication link
    :raises RuntimeError: if the saved auth method is not recognised
    """
    user_config = app_config.user_config
    if not user_config or not user_config.auth_method:
        return DeviceCodeAuth(region=app_config.region,
                              client_id=app_config.client_id,
                              auth_endpoint=app_config.auth_endpoint,
                              auth_io=auth_io)

    auth_methods = [
        DeviceCodeAuth
    ]
    matched_auth_method = next((m for m in auth_methods if m.__name__ == user_config.auth_method), None)
    if not matched_auth_method:
        # Backwards compatibility: 'ClientAuth' is accepted as an alias for DeviceCodeAuth
        if user_config.auth_method == 'ClientAuth':
            matched_auth_method = DeviceCodeAuth
        else:
            raise RuntimeError(f'{user_config.auth_method} not found, please re-run configuration')

    # The method-specific config may be missing; treat that as "no options set"
    # instead of failing with AttributeError on None.get(...)
    auth_config = user_config.auth_method_config or {}

    if matched_auth_method is DeviceCodeAuth:
        # Accept the string 'True' (any case) as well as a real boolean True
        enable_cache = str(auth_config.get('enable_cache')).strip().lower() == 'true'
        return DeviceCodeAuth(region=app_config.region,
                              client_id=app_config.client_id,
                              auth_endpoint=app_config.auth_endpoint,
                              enable_cache=enable_cache,
                              auth_io=auth_io)

    # Only reachable if auth_methods grows without a matching branch above;
    # fail loudly rather than silently returning None
    raise RuntimeError(f'{user_config.auth_method} not supported, please re-run configuration')
class DeviceCodeAuth(cirro.auth.base.AuthInfo):
 92class DeviceCodeAuth(AuthInfo):
 93    """
 94    Authenticates to Cirro by asking
 95    the user to enter a verification code on the portal website
 96
 97    :param client_id: The client ID for the OAuth application
 98    :param region: The AWS region where the Cognito user pool is located
 99    :param auth_endpoint: The endpoint for the OAuth authorization server
100    :param enable_cache: Optionally enable cache to avoid re-authentication
101    :param auth_io: Optionally provide a StringIO object for the authentication link
102    :param await_completion:
103        If True, block until the user completes the authorization.
104            If auth_io is provided, the authorization message will be written to that buffer.
105            If auth_io is not provided, the authorization message will be printed.
106        If False, the object will be instantiated without fully completing the authorization.
107            The authorization message can be accessed using the .auth_message property.
108            Then, the await_completion() method must be run to complete the process.
109
110    Implements the OAuth device code flow
111    This is the preferred way to authenticate
112    """
113    def __init__(
114        self,
115        client_id: str,
116        region: str,
117        auth_endpoint: str,
118        enable_cache=False,
119        auth_io: Optional[StringIO] = None,
120        await_completion=True
121    ):
122        self.client_id = client_id
123        self.auth_endpoint = auth_endpoint
124        self.region = region
125        self._token_info: Optional[OAuthTokenResponse] = None
126        self._persistence: Optional[BasePersistence] = None
127        self._flow: Optional[DeviceTokenResponse] = None
128        self._token_path = Path(Constants.home, f'{client_id}.token.dat').expanduser()
129
130        if enable_cache:
131            self._persistence = _build_token_persistence(str(self._token_path), fallback_to_plaintext=True)
132            self._token_info = self._load_token_info()
133
134        # Check saved token for change in endpoint
135        if self._token_info and self._token_info.get('client_id') != client_id:
136            logger.debug('Different client ID found, clearing saved token info')
137            self._clear_token_info()
138
139        # Check saved token for refresh token expiry
140        if self._token_info and self._token_info.get('refresh_expires_in'):
141            refresh_expiry_threshold = datetime.fromtimestamp(self._token_info.get('refresh_expires_in'))\
142                                       - timedelta(hours=12)
143            if refresh_expiry_threshold < datetime.now():
144                logger.debug('Refresh token expiry is too soon, re-authenticating')
145                self._clear_token_info()
146
147        if not self._token_info:
148            if await_completion:
149                self._token_info = _authenticate(client_id=client_id, auth_endpoint=auth_endpoint, auth_io=auth_io)
150            else:
151                self._flow = _initialize_auth_flow(client_id=client_id, auth_endpoint=auth_endpoint)
152
153        if self._token_info:
154            self._save_token_info()
155            self._update_token_metadata()
156            self._get_token_lock = threading.Lock()
157
158    @property
159    def auth_message(self):
160        """
161        If the DeviceCodeAuth was instantiated with await_completion=False,
162        then the authorization message will be populated by this property.
163        """
164        if self._flow is None:
165            raise ValueError("The DeviceTokenResponse is not available")
166        else:
167            return self._flow["message"]
168
169    @property
170    def auth_message_markdown(self):
171        """
172        Markdown syntax for the authorization message, so that links are rendered appropriately.
173        """
174        return " ".join([
175            (
176                f"[{part}]({part})"
177                if part.startswith("http")
178                else part
179            )
180            for part in self.auth_message.split(" ")
181        ])
182
183    def await_completion(self):
184        """Block until the user completes the authorization process."""
185        self._token_info = _await_completion(
186            client_id=self.client_id,
187            auth_endpoint=self.auth_endpoint,
188            flow=self._flow
189        )
190        self._save_token_info()
191        self._update_token_metadata()
192        self._get_token_lock = threading.Lock()
193
194    def get_auth_method(self) -> AuthMethod:
195        return RefreshableTokenAuth(token_getter=lambda: self._get_token()['access_token'])
196
197    def get_current_user(self) -> str:
198        return self._username
199
200    def _get_token(self):
201        with self._get_token_lock:
202            # Refresh access token using refresh token
203            if datetime.now() > self.access_token_expiry:
204                self._refresh_access_token()
205
206        return self._token_info
207
208    def _refresh_access_token(self):
209        try:
210            cognito = boto3.client('cognito-idp', region_name=self.region)
211            resp = cognito.initiate_auth(
212                ClientId=self.client_id,
213                AuthFlow='REFRESH_TOKEN_AUTH',
214                AuthParameters={
215                    'REFRESH_TOKEN': self._token_info['refresh_token']
216                }
217            )
218            logger.debug('Successfully refreshed token')
219        except ClientError as err:
220            logger.warning(err)
221            self._clear_token_info()
222            raise RuntimeError('Failed to refresh token, please reauthenticate')
223
224        auth_result = resp['AuthenticationResult']
225        self._token_info['access_token'] = auth_result['AccessToken']
226        self._token_info['id_token'] = auth_result['IdToken']
227        self._save_token_info()
228        self._update_token_metadata()
229
230    def _update_token_metadata(self):
231        decoded_access_token = jwt.decode(self._token_info['access_token'],
232                                          options={"verify_signature": False})
233        self.access_token_expiry = datetime.fromtimestamp(decoded_access_token['exp'])
234        self._username = decoded_access_token['username']
235
236    def _load_token_info(self) -> Optional[OAuthTokenResponse]:
237        if not self._persistence or not self._token_path.exists():
238            return None
239
240        token_info = json.loads(self._persistence.load())
241        if 'access_token' not in token_info:
242            return None
243
244        return token_info
245
246    def _save_token_info(self):
247        if not self._persistence:
248            return
249
250        self._persistence.save(json.dumps(self._token_info))
251
252    def _clear_token_info(self):
253        if not self._persistence:
254            return
255
256        Path(self._persistence.get_location()).unlink(missing_ok=True)
257        self._token_info = None

Authenticates to Cirro by asking the user to enter a verification code on the portal website.

Parameters
  • client_id: The client ID for the OAuth application
  • region: The AWS region where the Cognito user pool is located
  • auth_endpoint: The endpoint for the OAuth authorization server
  • enable_cache: Optionally enable cache to avoid re-authentication
  • auth_io: Optionally provide a StringIO object for the authentication link
  • await_completion: If True, block until the user completes the authorization; the authorization message is written to auth_io if it is provided, and printed otherwise. If False, the object is instantiated without completing the authorization: read the authorization message from the .auth_message property, then call the await_completion() method to complete the process.

Implements the OAuth device code flow. This is the preferred way to authenticate.

DeviceCodeAuth( client_id: str, region: str, auth_endpoint: str, enable_cache=False, auth_io: Optional[_io.StringIO] = None, await_completion=True)
def __init__(
    self,
    client_id: str,
    region: str,
    auth_endpoint: str,
    enable_cache=False,
    auth_io: Optional[StringIO] = None,
    await_completion=True
):
    """
    Set up device-code authentication, reusing a saved token when caching is enabled.

    :param client_id: The client ID for the OAuth application
    :param region: The AWS region where the Cognito user pool is located
    :param auth_endpoint: The endpoint for the OAuth authorization server
    :param enable_cache: Optionally enable cache to avoid re-authentication
    :param auth_io: Optionally provide a StringIO object for the authentication link
    :param await_completion: If True, authenticate now; if False, only start the flow
    """
    self.client_id = client_id
    self.auth_endpoint = auth_endpoint
    self.region = region
    self._token_info: Optional[OAuthTokenResponse] = None
    self._persistence: Optional[BasePersistence] = None
    self._flow: Optional[DeviceTokenResponse] = None
    self._token_path = Path(Constants.home, f'{client_id}.token.dat').expanduser()

    if enable_cache:
        self._persistence = _build_token_persistence(str(self._token_path), fallback_to_plaintext=True)
        self._token_info = self._load_token_info()

    # Discard a saved token that was issued for a different client ID
    if self._token_info and self._token_info.get('client_id') != client_id:
        logger.debug('Different client ID found, clearing saved token info')
        self._clear_token_info()

    # Discard a saved token whose refresh token expires within the next 12 hours.
    # NOTE(review): 'refresh_expires_in' is read as an absolute epoch timestamp here — confirm
    if self._token_info and self._token_info.get('refresh_expires_in'):
        refresh_expiry_threshold = datetime.fromtimestamp(self._token_info.get('refresh_expires_in'))\
                                   - timedelta(hours=12)
        if refresh_expiry_threshold < datetime.now():
            logger.debug('Refresh token expiry is too soon, re-authenticating')
            self._clear_token_info()

    if not self._token_info:
        if await_completion:
            self._token_info = _authenticate(client_id=client_id, auth_endpoint=auth_endpoint, auth_io=auth_io)
        else:
            # Only start the device flow; the caller finishes it with await_completion()
            self._flow = _initialize_auth_flow(client_id=client_id, auth_endpoint=auth_endpoint)

    if self._token_info:
        self._save_token_info()
        self._update_token_metadata()
        self._get_token_lock = threading.Lock()
client_id
auth_endpoint
region
auth_message
@property
def auth_message(self):
    """
    If the DeviceCodeAuth was instantiated with await_completion=False,
    then the authorization message will be populated by this property.

    :raises ValueError: if no device flow was started (e.g. a saved token was used)
    """
    if self._flow is None:
        raise ValueError("The DeviceTokenResponse is not available")
    return self._flow["message"]

If the DeviceCodeAuth was instantiated with await_completion=False, then the authorization message will be populated by this property.

auth_message_markdown
@property
def auth_message_markdown(self):
    """
    Markdown syntax for the authorization message, so that links are rendered appropriately.

    Every space-separated word starting with "http" becomes ``[url](url)``;
    all other words, and the original spacing, are kept unchanged.
    """
    rendered = []
    for word in self.auth_message.split(" "):
        if word.startswith("http"):
            rendered.append(f"[{word}]({word})")
        else:
            rendered.append(word)
    return " ".join(rendered)

Markdown syntax for the authorization message, so that links are rendered appropriately.

def await_completion(self):
def await_completion(self):
    """
    Block until the user completes the authorization process.

    Returns immediately when a token is already available, e.g. a saved
    token was loaded by the constructor or this method already completed.

    :raises ValueError: if no device flow was started and no token is available
    """
    if self._token_info:
        # Already authenticated; there is no pending flow to wait on
        return
    if self._flow is None:
        raise ValueError("The DeviceTokenResponse is not available")
    self._token_info = _await_completion(
        client_id=self.client_id,
        auth_endpoint=self.auth_endpoint,
        flow=self._flow
    )
    self._save_token_info()
    self._update_token_metadata()
    self._get_token_lock = threading.Lock()

Block until the user completes the authorization process.

def get_auth_method(self) -> cirro_api_client.cirro_auth.AuthMethod:
def get_auth_method(self) -> AuthMethod:
    """Return an auth method that fetches the access token via _get_token() on demand."""
    return RefreshableTokenAuth(token_getter=lambda: self._get_token()['access_token'])
def get_current_user(self) -> str:
def get_current_user(self) -> str:
    """Return the 'username' claim decoded from the current access token."""
    return self._username
class AccessTokenAuth(cirro.auth.base.AuthInfo):
class AccessTokenAuth(AuthInfo):
    """
    Authenticates to Cirro with a static access token

    The token is used as-is; this class contains no refresh logic.

    :param token: Access token
    """

    def __init__(self, token: str):
        self._token = token
        # Both populated from the token's claims by _update_token_metadata()
        self._username = None
        self._access_token_expiry = None
        self._update_token_metadata()

    def get_current_user(self) -> str:
        """Return the 'username' claim decoded from the access token."""
        return self._username

    def get_auth_method(self) -> AuthMethod:
        """Return a TokenAuth wrapping the static access token."""
        return TokenAuth(token=self._token)

    def _update_token_metadata(self):
        """Read expiry and username from the token's claims (signature not verified)."""
        decoded_access_token = jwt.decode(self._token,
                                          options={"verify_signature": False})
        self._access_token_expiry = datetime.fromtimestamp(decoded_access_token['exp'])
        self._username = decoded_access_token['username']

Authenticates to Cirro with a static access token

Parameters
  • token: Access token
AccessTokenAuth(token: str)
def __init__(self, token: str):
    """
    Store the static access token and decode its claims.

    :param token: Access token
    """
    self._token = token
    # Both populated from the token's claims by _update_token_metadata()
    self._username = None
    self._access_token_expiry = None
    self._update_token_metadata()
def get_current_user(self) -> str:
def get_current_user(self) -> str:
    """Return the 'username' claim decoded from the access token."""
    return self._username
def get_auth_method(self) -> cirro_api_client.cirro_auth.AuthMethod:
def get_auth_method(self) -> AuthMethod:
    """Return a TokenAuth wrapping the static access token."""
    return TokenAuth(token=self._token)