GoogleCloudPlatform / cloud-sql-jdbc-socket-factory

A collection of Java libraries for connecting securely to Cloud SQL
Apache License 2.0

How to use a custom Credential Factory with Spark JDBC? #2040

Open sl-nicolasmoteley opened 1 month ago

sl-nicolasmoteley commented 1 month ago

Question

Hi there!

This has been giving us headaches for weeks...

We don't want to use the GOOGLE_APPLICATION_CREDENTIALS variable pointing at a JSON file; instead, we want to build GoogleCredentials from AWS Secrets Manager.

So we implemented a custom CredentialFactory and use it in our task, which reads a PostgreSQL database through Spark JDBC.

But we get the error message below. It looks like the workers don't receive the credentials (the master does)...

24/07/11 14:27:13 WARN TaskSetManager: Lost task 0.0 in stage 0.0 (TID 0) (ip-10-32-25-50.eu-west-1.compute.internal executor 3): org.postgresql.util.PSQLException: Something unusual has occurred to cause the driver to fail. Please report this exception.
    at org.postgresql.Driver.connect(Driver.java:285)
    at org.apache.spark.sql.execution.datasources.jdbc.connection.BasicConnectionProvider.getConnection(BasicConnectionProvider.scala:49)
    at org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProviderBase.create(ConnectionProvider.scala:102)
    at org.apache.spark.sql.jdbc.JdbcDialect.$anonfun$createConnectionFactory$1(JdbcDialects.scala:122)
    at org.apache.spark.sql.jdbc.JdbcDialect.$anonfun$createConnectionFactory$1$adapted(JdbcDialects.scala:118)
    at org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD.compute(JDBCRDD.scala:273)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:138)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1516)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:750)
Caused by: java.lang.RuntimeException: Unable to obtain credentials to communicate with the Cloud SQL API
    at com.shaded.com.google.cloud.sql.core.ApplicationDefaultCredentialFactory.getCredentials(ApplicationDefaultCredentialFactory.java:42)
    at com.shaded.com.google.cloud.sql.core.InternalConnectorRegistry.createConnector(InternalConnectorRegistry.java:311)
    at com.shaded.com.google.cloud.sql.core.InternalConnectorRegistry.lambda$getConnector$0(InternalConnectorRegistry.java:300)
    at java.util.concurrent.ConcurrentHashMap.computeIfAbsent(ConcurrentHashMap.java:1660)
    at com.shaded.com.google.cloud.sql.core.InternalConnectorRegistry.getConnector(InternalConnectorRegistry.java:299)
    at com.shaded.com.google.cloud.sql.core.InternalConnectorRegistry.connect(InternalConnectorRegistry.java:179)
    at com.shaded.com.google.cloud.sql.postgres.SocketFactory.createSocket(SocketFactory.java:81)
    at org.postgresql.core.PGStream.createSocket(PGStream.java:223)
    at org.postgresql.core.PGStream.<init>(PGStream.java:95)
    at org.postgresql.core.v3.ConnectionFactoryImpl.tryConnect(ConnectionFactoryImpl.java:98)
    at org.postgresql.core.v3.ConnectionFactoryImpl.openConnectionImpl(ConnectionFactoryImpl.java:213)
    at org.postgresql.core.ConnectionFactory.openConnection(ConnectionFactory.java:51)
    at org.postgresql.jdbc.PgConnection.<init>(PgConnection.java:223)
    at org.postgresql.Driver.makeConnection(Driver.java:465)
    at org.postgresql.Driver.connect(Driver.java:264)
    ... 21 more
Caused by: java.io.IOException: Your default credentials were not found. To set up Application Default Credentials for your environment, see https://cloud.google.com/docs/authentication/external/set-up-adc.
    at com.shaded.com.google.auth.oauth2.DefaultCredentialsProvider.getDefaultCredentials(DefaultCredentialsProvider.java:127)
    at com.shaded.com.google.auth.oauth2.GoogleCredentials.getApplicationDefault(GoogleCredentials.java:152)
    at com.shaded.com.google.auth.oauth2.GoogleCredentials.getApplicationDefault(GoogleCredentials.java:124)
    at com.shaded.com.google.cloud.sql.core.ApplicationDefaultCredentialFactory.getCredentials(ApplicationDefaultCredentialFactory.java:40)
    ... 35 more

Code

package com.aviv.data.spark.datalake.tasks.ma_www

import com.amazonaws.services.secretsmanager.AWSSecretsManager
import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest
import com.google.api.client.http.HttpRequestInitializer
import com.google.auth.http.HttpCredentialsAdapter
import com.google.auth.oauth2.GoogleCredentials
import com.google.cloud.sql.CredentialFactory
import com.seloger.data.spark.boot.utils.SerializableConfiguration
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.Serialization.write
import org.slf4j.{Logger, LoggerFactory}

import java.io.ByteArrayInputStream
import java.nio.charset.StandardCharsets
import java.nio.file.Files

class CustomCredentialFactory extends CredentialFactory {

  val logger: Logger = LoggerFactory.getLogger(getClass)

  override def create(): HttpRequestInitializer = {
    val credentials = getCredentials()
    new HttpCredentialsAdapter(credentials)
  }

  override def getCredentials(): GoogleCredentials = {

    logger.info("Entered getCredentials of CustomCredentialFactory")

    val secretName = "XXXXXX-secret"
    val client: AWSSecretsManager = SerializableConfiguration.defaultClient
    val getSecretValueRequest = new GetSecretValueRequest()
    getSecretValueRequest.setSecretId(secretName)
    val secretString = client.getSecretValue(getSecretValueRequest).getSecretString
    logger.info(s"Secret string : $secretString")
    implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats
    val secretMap = JsonMethods.parse(secretString).extract[Map[String, String]]
    logger.info(s"Secret map : ${secretMap.toString()}")
    val credentialSource = CredentialSource(
      environment_id = secretMap("credential_source_environment_id"),
      region_url = secretMap("credential_source_region_url"),
      url = secretMap("credential_source_url"),
      regional_cred_verification_url = secretMap("credential_source_regional_cred_verification_url")
    )
    val googleCredentialsParam = GoogleCredentialsParam(
      `type` = secretMap("type"),
      audience = secretMap("audience"),
      subject_token_type = secretMap("subject_token_type"),
      credential_source = credentialSource,
      service_account_impersonation_url = secretMap("service_account_impersonation_url")
    )
    val secret = write(googleCredentialsParam)
    GoogleCredentials.fromStream(new ByteArrayInputStream(secret.getBytes(StandardCharsets.UTF_8)))
  }
}

package com.aviv.data.spark.datalake.tasks.ma_www

import com.google.auth.oauth2.GoogleCredentials
import com.google.cloud.sql.CredentialFactory
import com.seloger.data.spark.boot.api.SparkTask
import com.seloger.data.spark.boot.utils.SerializableConfiguration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.json4s.jackson.Serialization.write
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods

import java.io.{BufferedWriter, File, OutputStreamWriter}
import java.nio.charset.StandardCharsets
import java.util.Properties

class CloudSqlProxyTask(config: SerializableConfiguration) extends SparkTask {

  override def run(): Unit = {

    implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats

    System.setProperty(
      CredentialFactory.CREDENTIAL_FACTORY_PROPERTY,
      "com.aviv.data.spark.datalake.tasks.ma_www.CustomCredentialFactory"
    )

    val dbName = "postgres"
    val instanceConnectionName = "pg8000"
    val jdbcUrl = s"jdbc:postgresql:///$dbName"

    val connProps = new Properties()
    connProps.setProperty("driver", "org.postgresql.Driver")
    connProps.setProperty("user", "XXXXXX")
    connProps.setProperty("sslmode", "disable")
    connProps.setProperty("socketFactory", "com.google.cloud.sql.postgres.SocketFactory")
    connProps.setProperty("cloudSqlInstance", "XXXXX")
    connProps.setProperty("enableIamAuth", "true")

    val query = "(SELECT * FROM information_schema.tables) tables"

    val df = session.read.jdbc(jdbcUrl, query, connProps)

    df.show(false)
  }
}

case class CredentialSource(
  environment_id: String,
  region_url: String,
  url: String,
  regional_cred_verification_url: String
)

case class GoogleCredentialsParam(
  `type`: String,
  audience: String,
  subject_token_type: String,
  credential_source: CredentialSource,
  service_account_impersonation_url: String
)

    <dependency>
      <groupId>com.google.cloud.sql</groupId>
      <artifactId>postgres-socket-factory</artifactId>
      <version>1.18.0</version>
    </dependency>

Additional Details

No response

jackwotherspoon commented 1 month ago

Hi @sl-nicolasmoteley , thanks for asking a question on the Cloud SQL Java Connector!

Yes, our library uses Application Default Credentials (ADC) to source credentials from the environment.
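
For reference, here is a minimal sketch (in Scala, matching the ApplicationDefaultCredentialFactory frames in your stack trace) of what the connector's credential lookup effectively does in any JVM where no custom factory is installed:

import com.google.auth.oauth2.GoogleCredentials

// When no custom CredentialFactory is registered in a JVM, the connector falls
// back to Application Default Credentials, roughly equivalent to:
val adc: GoogleCredentials = GoogleCredentials.getApplicationDefault()
// This is the call that throws "Your default credentials were not found" on each
// Spark executor, because only the driver JVM saw the custom factory.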

There are several different ways that ADC credentials can be set. I wonder if workload identity federation would be a good candidate for your use case, since you mention AWS Secrets Manager.

https://cloud.google.com/iam/docs/workload-identity-federation-with-other-clouds
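
As a rough illustration (all values below are placeholders; the linked docs show how to generate the exact file), an AWS workload identity federation config is an external_account JSON that ADC or GoogleCredentials.fromStream can consume directly, much like the one your GoogleCredentialsParam assembles:

import java.io.ByteArrayInputStream
import java.nio.charset.StandardCharsets

import com.google.auth.oauth2.GoogleCredentials

// Rough shape of an AWS workload identity federation config (placeholders only).
val wifConfigJson =
  """{
    |  "type": "external_account",
    |  "audience": "//iam.googleapis.com/projects/PROJECT_NUMBER/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID",
    |  "subject_token_type": "urn:ietf:params:aws:token-type:aws4_request",
    |  "token_url": "https://sts.googleapis.com/v1/token",
    |  "service_account_impersonation_url": "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/SA_EMAIL:generateAccessToken",
    |  "credential_source": {
    |    "environment_id": "aws1",
    |    "region_url": "http://169.254.169.254/latest/meta-data/placement/availability-zone",
    |    "url": "http://169.254.169.254/latest/meta-data/iam/security-credentials",
    |    "regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"
    |  }
    |}""".stripMargin

// Credentials built this way can then be returned from a custom factory or supplier.
val wifCredentials = GoogleCredentials.fromStream(
  new ByteArrayInputStream(wifConfigJson.getBytes(StandardCharsets.UTF_8)))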

@hessjcg may be more familiar with Spark and may have suggestions for you.

hessjcg commented 1 month ago

Hi @sl-nicolasmoteley,

It seems likely that this bit of code from your sample:

    System.setProperty(
      CredentialFactory.CREDENTIAL_FACTORY_PROPERTY,
      "com.aviv.data.spark.datalake.tasks.ma_www.CustomCredentialFactory"
    )

is only setting the system property on the task master, not on the worker instances for some reason. Thus, the master uses your custom credential factory but the workers do not.
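
(For completeness, one possible workaround, sketched below under the assumption that you can set Spark conf before the executors launch, is to pass the same system property to every executor JVM via spark.executor.extraJavaOptions; the property name is the value of CredentialFactory.CREDENTIAL_FACTORY_PROPERTY. Whether your cluster honors this depends on how the application is submitted.)

import org.apache.spark.sql.SparkSession

// Sketch: propagate the credential-factory system property to the executor JVMs.
// "cloudSql.socketFactory.credentialFactory" is the value of
// CredentialFactory.CREDENTIAL_FACTORY_PROPERTY (double-check against your connector version).
// This must be set at submit/session-build time, before executors are started.
val spark = SparkSession.builder()
  .config("spark.executor.extraJavaOptions",
    "-DcloudSql.socketFactory.credentialFactory=" +
      "com.aviv.data.spark.datalake.tasks.ma_www.CustomCredentialFactory")
  .getOrCreate()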

Here is a more reliable way to install a custom credential factory: configure your connector programmatically through the connector API. See Registering a named connector.

You would update your sample to look something like this:

import java.util.function.Supplier

import com.google.auth.oauth2.GoogleCredentials
import com.google.cloud.sql.{ConnectorConfig, ConnectorRegistry}

// ...

class CloudSqlProxyTask(config: SerializableConfiguration) extends SparkTask {

  override def run(): Unit = {

    implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats

    // Supply credentials directly instead of relying on the credential-factory
    // system property; the existing factory can be reused as the supplier.
    val credentialSupplier: Supplier[GoogleCredentials] = () => {
      // your custom credential code goes here
      new CustomCredentialFactory().getCredentials()
    }

    val connectorConfig = new ConnectorConfig.Builder()
      .withGoogleCredentialsSupplier(credentialSupplier)
      .build()

    ConnectorRegistry.register("spark-pg", connectorConfig)

    val dbName = "postgres"
    val instanceConnectionName = "pg8000"
    val jdbcUrl = s"jdbc:postgresql:///$dbName"

    val connProps = new Properties()

    // Configure the JDBC properties to use the named connector configuration
    connProps.setProperty("cloudSqlNamedConnector", "spark-pg")

    connProps.setProperty("driver", "org.postgresql.Driver")
    connProps.setProperty("user", "XXXXXX")
    connProps.setProperty("sslmode", "disable")
    connProps.setProperty("socketFactory", "com.google.cloud.sql.postgres.SocketFactory")
    connProps.setProperty("cloudSqlInstance", "XXXXX")
    connProps.setProperty("enableIamAuth", "true")

    val query = "(SELECT * FROM information_schema.tables) tables"

    val df = session.read.jdbc(jdbcUrl, query, connProps)

    df.show(false)
  }
}

sl-nicolasmoteley commented 1 month ago

Hi @hessjcg, thanks for your answer. I've just added the named connector using your code sample, but it's not found...

Caused by: java.lang.IllegalArgumentException: Named connection spark-pg does not exist.
    at com.shaded.com.google.cloud.sql.core.InternalConnectorRegistry.getNamedConnector(InternalConnectorRegistry.java:367) ~[spark-boot-app.jar:?]
    at com.shaded.com.google.cloud.sql.core.InternalConnectorRegistry.connect(InternalConnectorRegistry.java:169) ~[spark-boot-app.jar:?]
    at com.shaded.com.google.cloud.sql.postgres.SocketFactory.createSocket(SocketFactory.java:81) ~[spark-boot-app.jar:?]
    at org.postgresql.core.PGStream.createSocket(PGStream.java:223) ~[spark-boot-app.jar:?]
    at org.postgresql.core.PGStream.<init>(PGStream.java:95) ~[spark-boot-app.jar:?]
    at org.postgresql.core.v3.ConnectionFactoryImpl.tryConnect(ConnectionFactoryImpl.java:98) ~[spark-boot-app.jar:?]
    at org.postgresql.core.v3.ConnectionFactoryImpl.openConnectionImpl(ConnectionFactoryImpl.java:213) ~[spark-boot-app.jar:?]
    at org.postgresql.core.ConnectionFactory.openConnection(ConnectionFactory.java:51) ~[spark-boot-app.jar:?]
    at org.postgresql.jdbc.PgConnection.<init>(PgConnection.java:223) ~[spark-boot-app.jar:?]
    at org.postgresql.Driver.makeConnection(Driver.java:465) ~[spark-boot-app.jar:?]
    at org.postgresql.Driver.connect(Driver.java:264) ~[spark-boot-app.jar:?]
    at org.apache.spark.sql.execution.datasources.jdbc.connection.BasicConnectionProvider.getConnection(BasicConnectionProvider.scala:49) ~[spark-sql_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProviderBase.create(ConnectionProvider.scala:102) ~[spark-sql_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.sql.jdbc.JdbcDialect.$anonfun$createConnectionFactory$1(JdbcDialects.scala:122) ~[spark-sql_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.sql.jdbc.JdbcDialect.$anonfun$createConnectionFactory$1$adapted(JdbcDialects.scala:118) ~[spark-sql_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD.compute(JDBCRDD.scala:273) ~[spark-sql_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:329) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:329) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:329) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.scheduler.Task.run(Task.scala:138) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1516) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551) ~[spark-core_2.12-3.3.2-amzn-0.jar:3.3.2-amzn-0]
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) ~[?:1.8.0_412]
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) ~[?:1.8.0_412]
    at java.lang.Thread.run(Thread.java:750) ~[?:1.8.0_412]