diff --git a/src/main/scala/com/lucidworks/spark/SolrRelation.scala b/src/main/scala/com/lucidworks/spark/SolrRelation.scala
index 198e93e1..1ef90f78 100644
--- a/src/main/scala/com/lucidworks/spark/SolrRelation.scala
+++ b/src/main/scala/com/lucidworks/spark/SolrRelation.scala
@@ -64,6 +64,8 @@ class SolrRelation(
 
   checkRequiredParams()
 
+  SolrSupport.doBasicAuthByOptsIfUsed(parameters)
+
   lazy val solrVersion : String = SolrSupport.getSolrVersion(conf.getZkHost.get)
   lazy val initialQuery: SolrQuery = SolrRelation.buildQuery(conf)
   // we don't need the baseSchema for streaming expressions, so we wrap it in an optional
diff --git a/src/main/scala/com/lucidworks/spark/util/SolrSupport.scala b/src/main/scala/com/lucidworks/spark/util/SolrSupport.scala
index f7b33943..3d168bb4 100644
--- a/src/main/scala/com/lucidworks/spark/util/SolrSupport.scala
+++ b/src/main/scala/com/lucidworks/spark/util/SolrSupport.scala
@@ -19,9 +19,10 @@ import org.apache.solr.client.solrj.impl._
 import org.apache.solr.client.solrj.request.UpdateRequest
 import org.apache.solr.client.solrj.response.QueryResponse
 import org.apache.solr.client.solrj.{SolrClient, SolrQuery, SolrServerException}
+import org.apache.solr.client.solrj.impl.HttpClientUtil
 import org.apache.solr.common.cloud._
 import org.apache.solr.common.{SolrDocument, SolrException, SolrInputDocument}
-import org.apache.solr.common.params.ModifiableSolrParams
+import org.apache.solr.common.params.{MapSolrParams, ModifiableSolrParams}
 import org.apache.solr.common.util.SimpleOrderedMap
 import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.dstream.DStream
@@ -78,6 +79,7 @@ object SolrSupport extends LazyLogging {
   val AUTH_CONFIGURER_CLASS = "auth.configurer.class"
 
   val SOLR_VERSION_PATTERN = Pattern.compile("^(\\d+)\\.(\\d+)(\\.(\\d+))?.*")
+  var basicAuthBySparkOpts = false // false: auth via system properties (or no auth); true: auth via Spark options (or no auth)
 
   def getSolrVersion(zkHost: String): String = {
     val sysQuery = new SolrQuery
@@ -733,6 +735,45 @@ object SolrSupport extends LazyLogging {
     splits.toList
   }
 
+  /**
+   * Supports basic auth configured through Spark options:
+   * sets the username and password from the Spark options on PreemptiveBasicAuthClientBuilderFactory, then immediately fetches the cached cloud client.
+   * Using Spark options and system properties together for basic auth is not supported;
+   * the username and password in the Spark options override any credentials in the system properties.
+   *
+   * @param params Spark options
+   */
+  def doBasicAuthByOptsIfUsed(params: Map[String, String]): Unit = {
+    val usernameInOpts = params.get(HttpClientUtil.PROP_BASIC_AUTH_USER)
+    val passwordInOpts = params.get(HttpClientUtil.PROP_BASIC_AUTH_PASS)
+
+    if (!basicAuthBySparkOpts) {
+      if (usernameInOpts.isEmpty) {
+        return
+      } else {
+        val credentials = System.getProperty(PreemptiveBasicAuthClientBuilderFactory.SYS_PROP_BASIC_AUTH_CREDENTIALS)
+        val configFile = System.getProperty(PreemptiveBasicAuthClientBuilderFactory.SYS_PROP_HTTP_CLIENT_CONFIG)
+        if (null != credentials || null != configFile) {
+          logger.warn(s"""Found "${PreemptiveBasicAuthClientBuilderFactory.SYS_PROP_BASIC_AUTH_CREDENTIALS}" or "${PreemptiveBasicAuthClientBuilderFactory.SYS_PROP_HTTP_CLIENT_CONFIG}" """ +
+            s"""in the system properties and "${HttpClientUtil.PROP_BASIC_AUTH_USER}" in the Spark options. """ +
+            s"""Credentials in the Spark options will override the credentials in the system properties.""")
+        }
+        basicAuthBySparkOpts = true
+      }
+    }
+
+    if (usernameInOpts.isEmpty || passwordInOpts.isEmpty) {
+      System.clearProperty(PreemptiveBasicAuthClientBuilderFactory.SYS_PROP_BASIC_AUTH_CREDENTIALS)
+      PreemptiveBasicAuthClientBuilderFactory.setDefaultSolrParams(null)
+    } else {
+      logger.info(s"Using basic auth from Spark options, zkHost:${params.getOrElse("zkhost", "")} username:${usernameInOpts.get}")
+      System.setProperty(PreemptiveBasicAuthClientBuilderFactory.SYS_PROP_BASIC_AUTH_CREDENTIALS, usernameInOpts.get + ":" + passwordInOpts.get)
+      val authParams = Map(HttpClientUtil.PROP_BASIC_AUTH_USER -> usernameInOpts.get,
+        HttpClientUtil.PROP_BASIC_AUTH_PASS -> passwordInOpts.get)
+      PreemptiveBasicAuthClientBuilderFactory.setDefaultSolrParams(new MapSolrParams(authParams))
+    }
+    getCachedCloudClient(params.getOrElse("zkhost", ""))
+  }
+
   case class WorkerShardSplit(query: SolrQuery, replica: SolrReplica)
   case class ExportHandlerSplit(query: SolrQuery, replica: SolrReplica, numWorkers: Int, workerId: Int)
 }
diff --git a/src/test/scala/com/lucidworks/spark/TestBasicAuthBySparkOpts.scala b/src/test/scala/com/lucidworks/spark/TestBasicAuthBySparkOpts.scala
new file mode 100644
index 00000000..2d691462
--- /dev/null
+++ b/src/test/scala/com/lucidworks/spark/TestBasicAuthBySparkOpts.scala
@@ -0,0 +1,66 @@
+package com.lucidworks.spark
+
+import java.util.UUID
+
+import com.lucidworks.spark.util.{SolrCloudUtil, SolrSupport}
+import org.apache.spark.sql.SaveMode.Overwrite
+import org.apache.spark.sql._
+import org.apache.spark.sql.types._
+import org.apache.zookeeper.{WatchedEvent, Watcher, ZooKeeper}
+
+
+class TestBasicAuthBySparkOpts extends TestSuiteBuilder {
+  val securityJson = "{\n\"authentication\":{ \n \"blockUnknown\": true, \n \"class\":\"solr.BasicAuthPlugin\",\n \"credentials\":{\"solr\":\"IV0EHq1OnNrj6gvRCwvFwTrZ1+z1oBbnQdiVC3otuq0= Ndd7LKvVBAaZIF0QAVi1ekCfAJXr1GGfLtRUXhgrF8c=\"} \n},\n\"authorization\":{\n \"class\":\"solr.RuleBasedAuthorizationPlugin\",\n \"permissions\":[{\"name\":\"security-edit\",\n \"role\":\"admin\"}], \n \"user-role\":{\"solr\":\"admin\"} \n}}"
+
+  test("auth by spark options") {
+    val collectionName = "testBasicAuth-" + UUID.randomUUID().toString
+    SolrCloudUtil.buildCollection(zkHost, collectionName, null, 1, cloudClient, sc)
+
+    // Enable basic authentication by replacing /security.json, keeping the original bytes to restore later
+    val zk = new ZooKeeper(zkHost, 500000, new Watcher() {
+      override def process(event: WatchedEvent): Unit = {}
+    })
+    val bytes: Array[Byte] = zk.getData("/security.json", false, null)
+    zk.setData("/security.json", securityJson.getBytes, -1)
+
+    try {
+      val csvDF = buildTestData()
+      val solrOpts = Map("zkhost" -> zkHost,
+        "httpBasicAuthUser" -> "solr",
+        "httpBasicAuthPassword" -> "SolrRocks",
+        "collection" -> collectionName)
+      csvDF.write.format("solr").options(solrOpts).mode(Overwrite).save()
+
+      // Explicit commit to make sure all docs are visible
+      val solrCloudClient = SolrSupport.getCachedCloudClient(zkHost)
+      solrCloudClient.commit(collectionName, true, true)
+
+      val solrDF = sparkSession.read.format("solr").options(solrOpts).load()
+      assert(solrDF.count == 3)
+
+    } finally {
+      zk.setData("/security.json", bytes, -1)
+      zk.close()
+      SolrCloudUtil.deleteCollection(collectionName, cluster)
+    }
+  }
+
+  def buildTestData() : DataFrame = {
+    val testDataSchema : StructType = StructType(
+      StructField("id", IntegerType, true) ::
+      StructField("one_txt", StringType, false) ::
+      StructField("two_txt", StringType, false) ::
+      StructField("three_s", StringType, false) :: Nil)
+
+    val rows = Seq(
+      Row(1, "A", "B", "C"),
+      Row(2, "C", "D", "E"),
+      Row(3, "F", "G", "H")
+    )
+
+    val csvDF : DataFrame = sparkSession.createDataFrame(sparkSession.sparkContext.makeRDD(rows, 1), testDataSchema)
+    assert(csvDF.count == 3)
+    csvDF
+  }
+
+}
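
For reference, a minimal usage sketch of the options this change introduces, mirroring the options map exercised by the test above. The ZooKeeper connect string, collection name, credentials, and input data are placeholders; "httpBasicAuthUser" and "httpBasicAuthPassword" are the values of the SolrJ constants HttpClientUtil.PROP_BASIC_AUTH_USER and HttpClientUtil.PROP_BASIC_AUTH_PASS read by doBasicAuthByOptsIfUsed.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SaveMode.Overwrite

object BasicAuthBySparkOptsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("solr-basic-auth-example").getOrCreate()

    // Credentials travel in the Spark options; SolrRelation calls
    // SolrSupport.doBasicAuthByOptsIfUsed(parameters) before any Solr request is made.
    val solrOpts = Map(
      "zkhost" -> "localhost:9983",           // placeholder ZooKeeper connect string
      "collection" -> "myCollection",         // placeholder collection name
      "httpBasicAuthUser" -> "solr",          // HttpClientUtil.PROP_BASIC_AUTH_USER
      "httpBasicAuthPassword" -> "SolrRocks"  // HttpClientUtil.PROP_BASIC_AUTH_PASS
    )

    // Write a small DataFrame to Solr and read it back, authenticating via the options.
    val df = spark.range(3).toDF("id")        // placeholder data
    df.write.format("solr").options(solrOpts).mode(Overwrite).save()
    spark.read.format("solr").options(solrOpts).load().show()
  }
}

Note that once credentials have been supplied through the Spark options (basicAuthBySparkOpts is true), building a later relation without them clears the stored credentials: doBasicAuthByOptsIfUsed removes the system property and resets the default params to null. Each job should therefore carry the auth options it needs.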
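
For contrast, a sketch of the pre-existing system-property route that the Spark options now take precedence over. The credential string and file path are placeholders; the property names come from the PreemptiveBasicAuthClientBuilderFactory constants already referenced in the diff, and they must be set (e.g. via spark.driver.extraJavaOptions / spark.executor.extraJavaOptions) before the first Solr client is built.

import org.apache.solr.client.solrj.impl.PreemptiveBasicAuthClientBuilderFactory

object BasicAuthBySysPropsExample {
  def main(args: Array[String]): Unit = {
    // Process-wide basic auth: "user:password" in a single system property...
    System.setProperty(PreemptiveBasicAuthClientBuilderFactory.SYS_PROP_BASIC_AUTH_CREDENTIALS, "solr:SolrRocks")
    // ...or, alternatively, an external http client config file holding the credentials:
    // System.setProperty(PreemptiveBasicAuthClientBuilderFactory.SYS_PROP_HTTP_CLIENT_CONFIG, "/path/to/httpclient.config")
  }
}

If either property is set and the Spark options also carry a username, doBasicAuthByOptsIfUsed logs a warning and the option values win.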