SparkHiveUtil

object SparkHiveUtil {

  /**
   * Appends `res` to every Hive table listed in `outputTables`, using `partitions` as the
   * partition columns.
   *
   * Before writing, the session is configured for dynamic partitioning (Hive non-strict mode,
   * up to 9000 dynamic partitions, bucketing enforced) and Spark's built-in Parquet reader is
   * disabled for metastore Parquet tables (`spark.sql.hive.convertMetastoreParquet=false`).
   * NOTE(review): `spark.sql.sources.partitionOverwriteMode=dynamic` only affects overwrite
   * writes; this method always appends.
   *
   * @param spark        active session; its configuration is modified as described above
   * @param outputTables target table names; entries are trimmed and blank entries are skipped
   * @param partitions   partition column names, in order
   * @param res          data to append
   * @param dupliCols    NOTE(review): not used by this method — no de-duplication happens here
   * @param paraNum      NOTE(review): not used by this method
   * @param isPersist    when true, `res.unpersist()` is called after writing, even if a write fails
   */
  def output[T](spark: SparkSession, outputTables: Array[String], partitions: Array[String], res: Dataset[T], dupliCols: Seq[String], paraNum: Int = 16, isPersist: Boolean = true): Unit = {

    spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
    spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
    spark.sql("set hive.exec.max.dynamic.partitions=9000")
    spark.sql("set hive.enforce.bucketing=true")
    spark.sqlContext.setConf("spark.sql.hive.convertMetastoreParquet", "false")

    try {
      // Blank names (e.g. from "a, ,b".split(",")) would make saveAsTable fail after earlier
      // tables had already been written, so they are dropped up front.
      outputTables.map(_.trim).filter(_.nonEmpty).foreach { table =>
        println("输出到hive的表【" + table + "】")
        res.write
          .mode(SaveMode.Append)
          .format("Hive")
          .partitionBy(partitions: _*)
          .saveAsTable(table)
      }
    } finally {
      // Release the cached data even when a write throws.
      if (isPersist) res.unpersist()
    }

    println("当前时间:" + DateTime.now().toString)
  }


}

// Append `output` to each comma-separated table in `outTables`, partitioned by
// source_company / city / year / month / day.
// NOTE(review): `dupliCols` is passed here, but SparkHiveUtil.output as shown does not use it
// for de-duplication — confirm whether duplicates on (loc_time, vehicle_id) must be removed first.
SparkHiveUtil.output[STD_GPS_DRIVE](spark, outTables.split(","), Array("source_company", "city", "year", "month", "day"), output, dupliCols = Seq("loc_time", "vehicle_id"))


PhoenixJdbcUtils

object PhoenixJdbcUtils extends Serializable {

  /** Binds column `pos` of a Row to JDBC parameter `pos + 1` of a prepared statement. */
  private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
  private val logger = LoggerFactory.getLogger("")

  /**
   * Upserts every row of `df` into the Phoenix table `table`.
   *
   * Each Spark partition opens its own connection via `getConn()` and writes through
   * `savePartition`. Any exception that reaches the driver is logged with its stack trace
   * and not rethrown.
   *
   * @param df        rows to write; its column names are used verbatim as the UPSERT column list
   * @param table     Phoenix table name, interpolated into the UPSERT statement
   * @param batchSize rows per `executeBatch` call; a value <= 0 sends each partition as one batch
   */
  def saveTable(df: DataFrame, table: String, batchSize: Int): Unit = {
    println("输出phoenix")

    try {
      val columns = df.columns.mkString(",")
      val rddSchema = df.schema
      // One positional placeholder per column, whatever the column type.
      val placeholders = Array.fill(df.columns.length)("?").mkString(",")

      // Explicit parameter type avoids the Scala/Java foreachPartition overload ambiguity.
      df.foreachPartition((iterator: Iterator[Row]) => {
        savePartition(getConn(), table, iterator, columns, placeholders, rddSchema, batchSize)
      })
    } catch {
      case e: Exception => logger.error(e.toString, e)
    }
  }

  /** Opens a new JDBC connection using `DBUrls.phoenix_url` and `DBUrls.phoenix_prop`. */
  def getConn(): Connection = {
    DriverManager.getConnection(DBUrls.phoenix_url, DBUrls.phoenix_prop)
  }

  /**
   * Upserts all rows of `iterator` through `conn`, then commits and closes `conn`.
   *
   * @param conn         connection to use; always closed before this method returns
   * @param table        target table name
   * @param iterator     rows to write, laid out as described by `rddSchema`
   * @param columns      comma-separated column list for the UPSERT statement
   * @param placeholders comma-separated "?" list, one per column
   * @param rddSchema    schema of the rows; selects the setter and NULL type per column
   * @param batchSize    rows per `executeBatch`; a value <= 0 sends everything as one batch
   * @return an empty iterator
   * @throws SQLException when a statement or the commit fails; a chained next-exception is
   *                      attached as cause or suppressed exception
   */
  def savePartition(
                     conn: Connection,
                     table: String,
                     iterator: Iterator[Row],
                     columns: String,
                     placeholders: String,
                     rddSchema: StructType,
                     batchSize: Int): Iterator[Byte] = {
    try {
      val stmt = insertStatement(conn, table, columns, placeholders)
      val setters = rddSchema.fields.map(field => makeSetter(field.dataType))
      // setNull needs the column's JDBC type code, not its position.
      val nullTypes = rddSchema.fields.map(field => sqlTypeOf(field.dataType))
      val numFields = rddSchema.fields.length

      try {
        var pending = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i += 1
          }
          stmt.addBatch()
          pending += 1
          // Guarding on batchSize > 0 avoids the division by zero of `% batchSize`.
          if (batchSize > 0 && pending >= batchSize) {
            stmt.executeBatch()
            pending = 0
          }
        }
        if (pending > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }

      conn.commit()
      Iterator.empty
    } catch {
      case e: SQLException =>
        // Surface the driver's chained "next exception" in the stack trace.
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          if (e.getCause == null) {
            e.initCause(cause)
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      // Single close point for the connection.
      conn.close()
    }
  }

  /** Prepares `UPSERT INTO table (columns) VALUES (placeholders)` — Phoenix's insert/update syntax. */
  def insertStatement(conn: Connection, table: String, columns: String, placeholders: String)
  : PreparedStatement = {
    val sql = s"UPSERT INTO $table ($columns) VALUES ($placeholders)"
    conn.prepareStatement(sql)
  }

  /**
   * JDBC type code (`java.sql.Types`) for a Spark column type; used for `setNull` and for
   * naming array element types. Unmapped types fall back to `Types.NULL`.
   */
  private def sqlTypeOf(dataType: DataType): Int = dataType match {
    case IntegerType => java.sql.Types.INTEGER
    case LongType => java.sql.Types.BIGINT
    case DoubleType => java.sql.Types.DOUBLE
    case FloatType => java.sql.Types.FLOAT
    case ShortType => java.sql.Types.SMALLINT
    case ByteType => java.sql.Types.TINYINT
    case BooleanType => java.sql.Types.BOOLEAN
    case StringType => java.sql.Types.VARCHAR
    case TimestampType => java.sql.Types.TIMESTAMP
    case DateType => java.sql.Types.DATE
    case _: DecimalType => java.sql.Types.DECIMAL
    case _: ArrayType => java.sql.Types.ARRAY
    case _ => java.sql.Types.NULL
  }

  /** Returns a setter that binds a non-null value of `dataType`; unsupported types throw when invoked. */
  private def makeSetter(dataType: DataType): JDBCValueSetter = dataType match {

    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))

    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))

    case DoubleType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDouble(pos + 1, row.getDouble(pos))

    case FloatType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setFloat(pos + 1, row.getFloat(pos))

    case ShortType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getShort(pos))

    case ByteType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getByte(pos))

    case BooleanType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos))

    case StringType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setString(pos + 1, row.getString(pos))

    case TimestampType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))

    case DateType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))

    case t: DecimalType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

    case ArrayType(elementType, _) =>
      // Spark Rows store array values as a Seq, not a java.sql.Array, so build one
      // from the statement's connection using the element's JDBC type name.
      val elementTypeName = java.sql.JDBCType.valueOf(sqlTypeOf(elementType)).getName
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setArray(pos + 1,
          stmt.getConnection.createArrayOf(elementTypeName, row.getSeq[AnyRef](pos).toArray))

    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }
}

SparkHbaseUtil

class SparkHbaseUtil(zookeeper_address: String, zookeeper_port: String, tableName: String) extends Serializable {

  /**
   * Reads comma-separated text records from `/data/heavy/*` and puts one HBase row per record
   * into `tableName` (column family "cf": vechid, cartype, date, lat, lng, time).
   *
   * Row key: "YUE_" + vehicle id (with a leading CJK ideograph dropped) + "_" + date + "_" + time.
   * Fields are read by position: 3 = vehicle id, 4 = lng, 5 = lat.
   *
   * NOTE(review): `ds` is not used — the written data comes from the hard-coded HDFS path, and
   * every row gets the fixed date "20190927" and an empty time. Confirm this is intended.
   * Errors while writing to HBase are printed and swallowed; if the table does not exist,
   * the partition's rows are silently dropped. Parse errors on a record fail the task.
   */
  def write[T](spark: SparkSession, ds: Dataset[T]): Unit = {

    val sc = spark.sparkContext

    // Vehicle-category codes for the "cartype" column; this job loads "heavy" data only.
    val carTypeCodes = Map(
      "taxi" -> 2,
      "bus" -> 32,
      "jiaolian" -> 8,
      "keyun" -> 1,
      "tour" -> 16,
      "huoyun" -> 4,
      "private" -> 516,
      "danger" -> 64,
      "heavy" -> 256
    )
    val cartType = carTypeCodes("heavy").toString

    sc.textFile("/data/heavy/*")
      .foreachPartition(iterator => {

        val f = Bytes.toBytes("cf")
        val date = "20190927"
        val time = ""

        // Build all puts before opening a connection, so a malformed record cannot leak it.
        val putlist = new util.ArrayList[Put]()
        iterator.foreach(x => {
          val arr = x.split(",")
          val vechid = arr(3)
          val lng = arr(4).trim
          val lat = arr(5).trim

          // A leading CJK ideograph (U+4E00..U+9FA5) is replaced by the "YUE_" prefix.
          // (The old upper bound 0x29FA5 exceeds Char.MaxValue and matched every char >= U+4E00.)
          val v_head = vechid.charAt(0)
          val rowkey_vechid =
            if (v_head >= 0x4E00 && v_head <= 0x9FA5) "YUE_" + vechid.substring(1) else "YUE_" + vechid
          val rowkey: String = rowkey_vechid + "_" + date + "_" + time

          // Bytes.toBytes encodes UTF-8, matching the cell values (String.getBytes used the platform charset).
          val put: Put = new Put(Bytes.toBytes(rowkey))
          put.addColumn(f, Bytes.toBytes("vechid"), Bytes.toBytes(vechid))
          put.addColumn(f, Bytes.toBytes("cartype"), Bytes.toBytes(cartType))
          put.addColumn(f, Bytes.toBytes("date"), Bytes.toBytes(date))
          put.addColumn(f, Bytes.toBytes("lat"), Bytes.toBytes(lat))
          put.addColumn(f, Bytes.toBytes("lng"), Bytes.toBytes(lng))
          put.addColumn(f, Bytes.toBytes("time"), Bytes.toBytes(time))

          putlist.add(put)
        })

        val conf = HBaseConfiguration.create
        conf.set("hbase.zookeeper.quorum", zookeeper_address)
        conf.set("hbase.zookeeper.property.clientPort", zookeeper_port)

        // Connection, Admin and Table are each closed on every path.
        val conn = ConnectionFactory.createConnection(conf)
        try {
          val admin = conn.getAdmin
          try {
            val tableNameObj = TableName.valueOf(tableName)
            if (admin.tableExists(tableNameObj)) {
              val table = conn.getTable(tableNameObj)
              try table.put(putlist) finally table.close()
            }
          } catch {
            case e: Exception =>
              e.printStackTrace()
          } finally {
            admin.close()
          }
        } finally {
          conn.close()
        }
      })
  }

}


SparkKafkaUtil

class SparkKafkaUtil(brokers: String, topic: String) extends Serializable {

  /**
   * Creates a direct Kafka stream of String key/value records for the comma-separated topics in
   * `topic` (entries trimmed, blanks ignored), consuming as group `groupId` from the latest offset.
   *
   * Kafka auto-commit is disabled. Instead, an output operation registered here passes each
   * batch's offset ranges to `commitAsync`.
   * NOTE(review): that operation is registered before any output the caller attaches to the
   * returned stream, so a batch's offsets may be committed before the caller's processing of that
   * batch completes (at-most-once). Committing after the caller's output would need an API change.
   */
  def input(ssc: StreamingContext, groupId: String): DStream[ConsumerRecord[String, String]] = {

    // Each key appears once. Auto-commit is off, so the commit interval has no effect.
    val kafkaParams = Map[String, Object](
      ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG -> brokers,
      ConsumerConfig.GROUP_ID_CONFIG -> groupId,
      ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG -> "false",
      ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG -> "1000",
      ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG -> "100000",
      ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG -> "30000",
      ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG -> "org.apache.kafka.common.serialization.StringDeserializer",
      ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG -> "org.apache.kafka.common.serialization.StringDeserializer",
      ConsumerConfig.AUTO_OFFSET_RESET_CONFIG -> "latest",
      ConsumerConfig.FETCH_MAX_BYTES_CONFIG -> "52428800"
    )

    val topics = topic.split(",").map(_.trim).filter(_.nonEmpty).toSet

    val inputStream = KafkaUtils.createDirectStream[String, String](ssc,
      LocationStrategies.PreferConsistent,
      ConsumerStrategies.Subscribe[String, String](topics, kafkaParams))

    // Queue each batch's offsets for asynchronous commit.
    inputStream.foreachRDD { rdd =>
      val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
      inputStream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges)
    }

    inputStream
  }

  /**
   * Sends `record.toString` of every element of `ds` to `topic`, via a `KafkaSink` broadcast
   * from the driver. Also prints the first rows of `ds` (`show(false)`), which runs an extra job.
   */
  def output[T](spark: SparkSession, ds: Dataset[T]): Unit = {

    println("输出到kafka")
    ds.show(false)

    val kafkaProducer: Broadcast[KafkaSink[String, String]] = {
      val kafkaProducerConfig = {
        val p = new Properties()
        p.setProperty("bootstrap.servers", brokers)
        p.setProperty(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, "30000")
        p.setProperty("key.serializer", classOf[StringSerializer].getName)
        p.setProperty("value.serializer", classOf[StringSerializer].getName)
        p
      }
      spark.sparkContext.broadcast(KafkaSink[String, String](kafkaProducerConfig))
    }

    // Explicit parameter type avoids the Scala/Java foreachPartition overload ambiguity.
    ds.foreachPartition((it: Iterator[T]) => {
      val sink = kafkaProducer.value
      it.foreach(record => sink.send(topic, record.toString))
    })
  }

  /**
   * Sends every element of `rdd` to `topic`, using one `KafkaSink` per partition; that sink's
   * producer is closed when the partition finishes, even if a send fails.
   * NOTE(review): values go through StringSerializer, so this presumably only works when T is String.
   */
  def outputRDD[T](spark: SparkSession, rdd: RDD[T]): Unit = {
    println("输出到kafka")

    val out_config = spark.sparkContext.broadcast(
      Map[String, Object](
        "bootstrap.servers" -> brokers,
        ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG -> "30000",
        "key.serializer" -> "org.apache.kafka.common.serialization.StringSerializer",
        "value.serializer" -> "org.apache.kafka.common.serialization.StringSerializer"))

    rdd.foreachPartition { it =>
      val sink = KafkaSink[String, T](out_config.value)
      try it.foreach(v => sink.send(topic, v))
      finally sink.producer.close()
    }
  }
}

数据标准化

object StdUtils {

  /**
   * True when the point lies in China's rough bounding box:
   * longitude 73.66..135.05 and latitude 3.86..53.55, both inclusive.
   */
  def isCnLngLat(lng: Double, lat: Double): Boolean = {
    val lngInRange = 73.66 <= lng && lng <= 135.05
    val latInRange = 3.86 <= lat && lat <= 53.55
    lngInRange && latInRange
  }

  /** Ad-hoc demo: prints the normalized form of a timestamp that carries milliseconds. */
  def main(args: Array[String]): Unit =
    println(fillTime("2018-10-01 00:02:35.000"))

  /**
   * Normalizes `date` into a 19-character "yyyy-MM-dd HH:mm:ss"-shaped string.
   *
   * - If `date` already satisfies `isTime` for `pattern`, it is returned unchanged.
   * - Otherwise, if it is longer than 19 characters, its first 19 characters are returned.
   * - Otherwise it is padded on the right with the matching tail of "1971-01-01 00:00:00".
   */
  def fillTime(date: String, pattern: String = "yyyy-MM-dd HH:mm:ss"): String = {
    val template = "1971-01-01 00:00:00"
    if (StdUtils.isTime(date, Array(pattern))) date
    else if (date.length > 19) date.substring(0, 19)
    else date + template.substring(date.length)
  }

  /** True when `str` is a non-future month in "yyyy/MM", "yyyy-MM" or "yyyyMM" form. */
  def isMonth(str: String): Boolean =
    isTime(str, Array("yyyy/MM", "yyyy-MM", "yyyyMM"))

  /**
   * True when `str` parses with one of `parsePatterns` (via `DateUtils.parseDate`, null locale)
   * and the parsed instant is not later than now. Any parse exception yields false.
   */
  def isTime(str: String, parsePatterns: Array[String]): Boolean =
    try {
      val parsedMillis = DateUtils.parseDate(str, null, parsePatterns: _*).getTime
      parsedMillis <= DateTime.now().getMillis
    } catch {
      case _: Exception => false
    }

  /** True when `str` is a non-future day in "yyyy/MM/dd", "yyyy-MM-dd" or "yyyyMMdd" form. */
  def isDate(str: String): Boolean =
    isTime(str, Array("yyyy/MM/dd", "yyyy-MM-dd", "yyyyMMdd"))

  /** True when `str` is a non-future hour in "yyyy/MM/dd/HH", "yyyy-MM-dd HH" or "yyyyMMdd HH" form. */
  def isHour(str: String): Boolean =
    isTime(str, Array("yyyy/MM/dd/HH", "yyyy-MM-dd HH", "yyyyMMdd HH"))

  /** Removes every space character (only ' ', not other whitespace). */
  def replaceBlank(str: String): String =
    str.replace(" ", "")

  /**
   * Temporary substitute for the data's year: replaces the first four characters of `date`
   * with `year` unless `date` already starts with it.
   *
   * @param date date string whose first four characters are the year
   * @param year replacement year, "2017" by default
   * @return the rewritten date
   */
  def replaceYear(date: String, year: String = "2017"): String =
    if (date.startsWith(year)) date else year + date.substring(4)

  /**
   * Tongji-specific requirement: replaces the first six characters of `date` with `yearMonth`
   * unless `date` already starts with it. Any date beginning with "20190312" becomes exactly
   * "20171023".
   *
   * @param date      date string whose first six characters are year and month
   * @param yearMonth replacement prefix, "201710" by default
   * @return the rewritten date
   */
  def replaceYearMonth(date: String, yearMonth: String = "201710"): String =
    if (date.startsWith("20190312")) "20171023"
    else if (date.startsWith(yearMonth)) date
    else yearMonth + date.substring(6)

  /**
   * Left-pads the trimmed `data` with '0' up to `fillsize` characters; longer input is
   * returned trimmed but otherwise unchanged.
   *
   * @param fillsize target length
   * @param data     value to pad
   * @return the zero-padded string
   */
  def fillZero(fillsize: Int, data: String): String = {
    val trimmed = data.trim
    "0" * (fillsize - trimmed.length) + trimmed
  }

  /**
   * True when no element is null or empty (per `StringUtils.isNotEmpty`); stops at the
   * first failing element. An empty array yields true.
   */
  def allNotEmpty(arr: Array[String]): Boolean =
    arr.forall(s => StringUtils.isNotEmpty(s))

}

UDFUtils

/**
 * Registers ad-hoc SQL UDFs on a HiveContext.
 *
 * NOTE(review): "concat" and "regexp_extract" reuse the names of Spark SQL built-ins and will
 * shadow them for the session in which they are registered.
 */
class UDFUtils {

  /** Registers strLen, concat, concat4, regexp_extract and getHost on `sqlContext`. */
  def registerUdf(sqlContext: HiveContext): Unit = {

    // Length of the string.
    sqlContext.udf.register("strLen", (s: String) => s.length())

    // Plain concatenation of three / four strings.
    sqlContext.udf.register("concat", (a: String, b: String, c: String) => a + b + c)
    sqlContext.udf.register("concat4", (a: String, b: String, c: String, d: String) => a + b + c + d)

    // Returns the LAST match of `pattern` in `str`, or "" when nothing matches.
    sqlContext.udf.register("regexp_extract", (str: String, pattern: String) => {
      val m = Pattern.compile(pattern, Pattern.UNIX_LINES).matcher(str)
      var lastMatch = ""
      while (m.find()) lastMatch = m.group()
      lastMatch
    })

    // Host part of a URL: strips a scheme whose "://" starts before index 6 (only if more than
    // one character follows it), then everything from the first '/' and the first ':'.
    // A null url prints "registerUdf Exception" and yields "".
    sqlContext.udf.register("getHost", (url: String) => {
      var host = ""
      try {
        host = url.toString()

        val schemeAt = host.indexOf("://")
        if (schemeAt >= 0 && schemeAt < 6 && host.length() > schemeAt + 4) {
          host = host.substring(schemeAt + 3)
        }

        val slashAt = host.indexOf("/")
        if (slashAt >= 0) {
          host = host.substring(0, slashAt)
        }

        val colonAt = host.indexOf(":")
        if (colonAt >= 0) {
          host = host.substring(0, colonAt)
        }
      } catch {
        case _: Exception => println("registerUdf Exception")
      }
      host
    })
  }
}

object UDFUtils {
  /** Creates a new [[UDFUtils]]. */
  def apply(): UDFUtils = new UDFUtils
}



Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐