Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions patches/kaggle_gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def init_bigquery():
return bigquery

# If this Kernel has bigquery integration on startup, preload the Kaggle Credentials
# object for magics to work.
# object for magics to work.
if get_integrations().has_bigquery():
from google.cloud.bigquery import magics
magics.context.credentials = KaggleKernelCredentials()
Expand All @@ -139,7 +139,7 @@ def monkeypatch_bq(bq_client, *args, **kwargs):
from kaggle_gcp import get_integrations, PublicBigqueryClient, KaggleKernelCredentials
specified_credentials = kwargs.get('credentials')
has_bigquery = get_integrations().has_bigquery()
# Prioritize passed in project id, but if it is missing look for env var.
# Prioritize passed in project id, but if it is missing look for env var.
arg_project = kwargs.get('project')
explicit_project_id = arg_project or os.environ.get(environment_vars.PROJECT)
# This is a hack to get around the bug in google-cloud library.
Expand Down Expand Up @@ -176,6 +176,18 @@ def monkeypatch_bq(bq_client, *args, **kwargs):
bq_client, *args, **kwargs)
return bigquery

def monkeypatch_client(client_klass, kaggle_kernel_credentials):
client_init = client_klass.__init__
def patched_init(self, *args, **kwargs):
specified_credentials = kwargs.get('credentials')
if specified_credentials is None:
Log.info("No credentials specified, using KaggleKernelCredentials.")
kwargs['credentials'] = kaggle_kernel_credentials
return client_init(self, *args, **kwargs)

if (not has_been_monkeypatched(client_klass.__init__)):
client_klass.__init__ = patched_init

def init_gcs():
is_user_secrets_token_set = "KAGGLE_USER_SECRETS_TOKEN" in os.environ
from google.cloud import storage
Expand All @@ -188,21 +200,42 @@ def init_gcs():

from kaggle_secrets import GcpTarget
from kaggle_gcp import KaggleKernelCredentials
gcs_client_init = storage.Client.__init__
def monkeypatch_gcs(self, *args, **kwargs):
specified_credentials = kwargs.get('credentials')
if specified_credentials is None:
Log.info("No credentials specified, using KaggleKernelCredentials.")
kwargs['credentials'] = KaggleKernelCredentials(target=GcpTarget.GCS)
return gcs_client_init(self, *args, **kwargs)

if (not has_been_monkeypatched(storage.Client.__init__)):
storage.Client.__init__ = monkeypatch_gcs
monkeypatch_client(
storage.Client,
KaggleKernelCredentials(target=GcpTarget.GCS))
return storage

def init_automl():
is_user_secrets_token_set = "KAGGLE_USER_SECRETS_TOKEN" in os.environ
from google.cloud import automl_v1beta1 as automl
if not is_user_secrets_token_set:
return automl

from kaggle_gcp import get_integrations
if not get_integrations().has_automl():
return automl

from kaggle_secrets import GcpTarget
from kaggle_gcp import KaggleKernelCredentials
kaggle_kernel_credentials = KaggleKernelCredentials(target=GcpTarget.AUTOML)

# The AutoML client library exposes 4 different client classes (AutoMlClient,
# TablesClient, PredictionServiceClient and GcsClient), so patch each of them.
# The same KaggleKernelCredentials are passed to all of them.
monkeypatch_client(automl.AutoMlClient, kaggle_kernel_credentials)
monkeypatch_client(automl.TablesClient, kaggle_kernel_credentials)
monkeypatch_client(automl.PredictionServiceClient, kaggle_kernel_credentials)
# TODO(markcollins): The GcsClient in the AutoML client library version
# 0.5.0 doesn't handle credentials properly. I wrote PR:
# https://github.com/googleapis/google-cloud-python/pull/9299
# to address this issue. Add patching for GcsClient when we get a version of
# the library that includes the fixes.
return automl

def init():
init_bigquery()
init_gcs()
init_automl()

# We need to initialize the monkeypatching of the client libraries
# here since there is a circular dependency between our import hook version
Expand Down
5 changes: 3 additions & 2 deletions patches/sitecustomize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import importlib.machinery

class GcpModuleFinder(importlib.abc.MetaPathFinder):
_MODULES = ['google.cloud.bigquery', 'google.cloud.storage']
_MODULES = ['google.cloud.bigquery', 'google.cloud.storage', 'google.cloud.automl_v1beta1']
_KAGGLE_GCP_PATH = 'kaggle_gcp.py'
def __init__(self):
pass
Expand Down Expand Up @@ -39,7 +39,8 @@ def create_module(self, spec):
import kaggle_gcp
_LOADERS = {
'google.cloud.bigquery': kaggle_gcp.init_bigquery,
'google.cloud.storage': kaggle_gcp.init_gcs
'google.cloud.storage': kaggle_gcp.init_gcs,
'google.cloud.automl_v1beta1': kaggle_gcp.init_automl,
}
monkeypatch_gcp_module = _LOADERS[spec.name]()
return monkeypatch_gcp_module
Expand Down
81 changes: 80 additions & 1 deletion tests/test_automl.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,88 @@
import unittest

from google.cloud import automl_v1beta1 as automl
from unittest.mock import Mock, patch

from kaggle_gcp import KaggleKernelCredentials, init_automl
from test.support import EnvironmentVarGuard
from google.cloud import storage, automl_v1beta1 as automl

def _make_credentials():
import google.auth.credentials
return Mock(spec=google.auth.credentials.Credentials)

class TestAutoMl(unittest.TestCase):

def test_version(self):
self.assertIsNotNone(automl.auto_ml_client._GAPIC_LIBRARY_VERSION)
version_parts = automl.auto_ml_client._GAPIC_LIBRARY_VERSION.split('.')
version = float('.'.join(version_parts[0:2]));
self.assertGreaterEqual(version, 0.5);

class FakeClient:
def __init__(self, credentials=None):
self.credentials = credentials

@patch("google.cloud.automl_v1beta1.AutoMlClient", new=FakeClient)
def test_user_provided_credentials(self):
credentials = _make_credentials()
env = EnvironmentVarGuard()
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'AUTOML')
with env:
init_automl()
client = automl.AutoMlClient(credentials=credentials)
self.assertNotIsInstance(client.credentials, KaggleKernelCredentials)
self.assertIsNotNone(client.credentials)


def test_tables_gcs_client(self):
# The GcsClient can't currently be monkeypatched for default
# credentials because it requires a project which can't be set.
# Verify that creating an automl.GcsClient given an actual
# storage.Client sets the client properly.
gcs_client = storage.Client(project="xyz", credentials=_make_credentials())
tables_gcs_client = automl.GcsClient(client=gcs_client)
self.assertIs(tables_gcs_client.client, gcs_client)


@patch("google.cloud.automl_v1beta1.AutoMlClient", new=FakeClient)
def test_default_credentials_automl_client(self):
env = EnvironmentVarGuard()
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'AUTOML')
with env:
init_automl()
automl_client = automl.AutoMlClient()
self.assertIsNotNone(automl_client.credentials)
self.assertIsInstance(automl_client.credentials, KaggleKernelCredentials)

@patch("google.cloud.automl_v1beta1.TablesClient", new=FakeClient)
def test_default_credentials_tables_client(self):
env = EnvironmentVarGuard()
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'AUTOML')
with env:
init_automl()
tables_client = automl.TablesClient()
self.assertIsNotNone(tables_client.credentials)
self.assertIsInstance(tables_client.credentials, KaggleKernelCredentials)

@patch("google.cloud.automl_v1beta1.PredictionServiceClient", new=FakeClient)
def test_default_credentials_prediction_client(self):
env = EnvironmentVarGuard()
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'AUTOML')
with env:
prediction_client = automl.PredictionServiceClient()
self.assertIsNotNone(prediction_client.credentials)
self.assertIsInstance(prediction_client.credentials, KaggleKernelCredentials)

def test_monkeypatching_idempotent(self):
env = EnvironmentVarGuard()
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'GCS')
with env:
client1 = automl.AutoMlClient.__init__
init_automl()
client2 = automl.AutoMlClient.__init__
self.assertEqual(client1, client2)