pyspark开发总结笔记

2023-01-01 08:15:36

本文记录spark开发过程中遇到的小知识点,使用pyspark开发,由于使用大多数场景为DataFrame,介绍也多为DataFrame。本文比较长,在学习过程中摘了一些博客和资料,如果有描述的不对的地方请指出。

Spark是分布式内存计算,能够依据各类操作创建一个计算DAG图,数据通过DAG处理后生成结果。

对spark的数据操作分为两类,一类是转换(transformation)操作,比如Filter、map、flatMap、reduce等,但是这些操作是懒转换,只在action的时候才真正的对数据做处理;另一类是action为操作,比如collect、show、count、first等,它们能够触发数据的计算得到结果。

数据集分为RDD、DataFrame、DataSet,其中DataFrame可以看做带格式的RDD,(因为格式确定,所以处理计算效率高于RDD),对DataFrame的操作可以视为对一张数据表操作,由于数据集的不可变特性,不能够修改原有DataFrame,而只能创建新的DataFrame,如增加列生成新的DF,删减列生成新的DF等。由于对DataFrame类似于表,spark提供了SQL的方式进行计算和操作,很多计算可以直接通过SQL的方式解决掉。

注意,虽然PySpark中也叫DataFrame,但是其与Pandas的DataFrame的很多操作不一样,虽然有很多类似之处。

1. 常用的引入包

from pyspark.sql import Row
from pyspark.sql.functions import col, isnan, isnull
from pyspark.sql import SparkSession  # SparkConf、SparkContext 和 SQLContext 都已经被封装在 SparkSession
from pyspark.sql.types import *

2. 创建spark

# 创建spark 如果是pypsark的话,直接用内置的spark变量
spark = SparkSession.builder.appName('test pyspark').getOrCreate()

3. 创建DataFrame

# 通过读取数据集来创建DataFrame
# 1. 参考本文档读取操作

# 2. 通过RDD
client_rdd = spark.sparkContext.parallelize([
    ('20180701', '1111', 0.1),
    ('20180801', "1111", 0.2),
    ('20180901', "1111", 0.3),
])

client_schema = StructType([
    StructField("date", StringType(), True),
    StructField("client_id", StringType(), True),
    StructField("cash", DoubleType(), True)
])
client_df = spark.createDataFrame(client_rdd, client_schema)

4. Pandas与PySpark的DataFrame

# pandas DataFrame 与 pyspark的DataFrame相互转化
pandas_df = spark_df.toPandas()  # spark df转pandas df
spark_df = spark.createDataFrame(pandas_df.values.tolist(), list(pandas_df.columns))  # pandas df转spark df
spark_df = spark.createDataFrame(pandas_df)  # pandas df转spark df

5. DataFrame基础操作

5.1查看DataFrame

# 查看DataFrame
df.show(5, False)
df.first()
df.head(5)
df.take(5)
df.collect()  # 将df整体以list的形式返回,不要在大数据集的情况用这种方法
df.collectAsMap()  # ?
df.count()  # 统计df的行数
df.schema  # df的结构
df.printSchema()  # 树的形式打印df的结构
df.columns  # 查看列名

5.2缓存DataFrame

# 缓存df,对于多次调用一些小的数据集,如果不缓存,则在计算的时候会多次加载,缓存能提高效率
df.cache()
df.persist()

5.3列操作

def essure_asset_type(stock_code):
    '''随机赋予证券代码的类别属性'''
    return random.choice(["equity", "fixincome", "cash"])
 
asset_type = udf(essure_asset_type, StringType())  # 创建用户自定义函数

df = df.withColumn('asset_type', asset_type(df.symbol))  # 增加新的列,处理的函数必须是处理列的
df = df.withColumn("differ", df.col1 - df.col2)  # 增加新的列
df = df.withColumnRenamed('age', 'age2')  # 更改列名
df = df.drop(df.col1)  # 删除列
df = df.drop("col1")

5.4其他

df.rdd  # df 转 rdd
df.rdd.getNumPartitions()  # 查看分区数

6. DataFrame SQL操作

## DataFrame的类SQL操作 DataFrame通过select、where、groupBy、sum等实现了类SQL的操作
client_df.groupBy(client_df.date, client_df.client_id).sum('cash')

# 两个DF的join操作,不得不说join的存在是在太方便了
join_df = client_df.join(bench_df, client_df.date == bench_df.date, 'left')

# DataFrame的列选择操作
df = df.select("orders", "traders")  # 通过选择列生成新的df
df = df.select(df.orders, df.traders)

# DataFrame的过滤操作
df = df.filter("trader=1111")
df = df.where("trader=1111")
df = df.filter(df.trader == '1111')
df = df.where(df.trader == '1111')
df = df.filter(col('trader').like('%1111%'))
df = df.filter(isnull("trader"))

# DataFrame的SQL操作,支持常用的sql操作,如select、where、group by、count、like等
df.registerTempTable("df")  # 注册了一个名叫 df 的表
df = spark.sql("select orders from df where trader='1111'")
df = spark.sql("select count(orders) from df where trader='1111'")

7. 文件数据读取

# 读取json
df = spark.read.json(json_file_path)

# 读取csv
df = spark.read.csv(csv_file_path, header=True, inferSchema=True)

# 写入csv
df.write.csv(path=csv_file_path, header=True, sep=",", mode='overwrite')

# 读取MySQL
df = spark.read.format('jdbc').options(url='jdbc:oracle:thin:@ip:port:database',
                                       dbtable='table_name or select sql as table', user='user_name',
                                       password='password').load()

# 读取Oracle, 需要在提交时指定JDBC
df = spark.read.format('jdbc').options(url='jdbc:oracle:thin:@ip:port:database',
                                       dbtable='table_name or select sql as table', user='user_name',
                                       password='password').load()

# 读取Hive,注:该方法我没试过
from pyspark.sql import SparkSession
spark = SparkSession.builder.enableHiveSupport().master("172.31.100.170:7077").appName("my_first_app_name").getOrCreate()
df=spark.sql("select * from hive_tb_name")

# 读取HDFS文件,如parquet
df = spark.read.parquet(parquet_file_path)
# 写入parquet文件
df.write.parquet(path=parquet_file_path, mode='overwrite')

# 读取Impala
from impala.dbapi import connect
conn = connect(host='ip', port=port, user='user_name', password='password')
cur = conn.cursor(user='user_name')
cur.execute("select sql;")
rdd = spark.sparkContext.parallelize(cur.fetchall())



# 写入MySQL,需要在提交时指定JDBC
df.write.mode("append").format("jdbc").options(url='jdbc:mysql://ip:port/database',
                                               user='user_name', password='password',
                                               dbtable='table', batchsize="1000").save()

8. 例子

8.1word count例子

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('test pyspark').getOrCreate()
lines_df = spark.read.text("/user/spark/test/test.txt")  # z这个是一个DataFrame
lines = lines_df.rdd.map(lambda x: x[0])  # df.rdd  可以将数据由DataFrame类型转化为RDD类型
counts = lines.flatMap(lambda x: x.split(' ')).map(lambda x: (x, 1)).reduceByKey(lambda a, b: a + b)
print(counts.collect())
counts.saveAsTextFile("/user/spark/test/count")

8.2pagerank

def compute_contribs(urls, rank):
    '''
    计算每个节点新的pr分值
    '''
    num_urls = len(urls)
    for url in urls:
        yield (url, rank * 1.0 / num_urls)  # (节点,节点权重*占比)


# 初始化,"1 2"表示节点1有一条路径(连接)到2
lines = sc.parallelize(["1 2", "1 3", "1 4", "2 4", "2 1", "3 1", "4 3", "4 2"])
# links表示每个节点连接的节点列表[('2', ['4', '1']), ('1', ['2', '3', '4']), ('4', ['3', '2']), ('3', ['1'])]
links = lines.map(lambda line: line.split()).groupByKey().mapValues(lambda x: list(x)).cache()
# 每个节点初始化pr值为1,[('2', 1), ('1', 1), ('4', 1), ('3', 1)]
ranks = links.keys().map(lambda x: (x, 1))

for i in range(10):
    # 将各个节点初始的分数分发给相邻的每个节点
    contribs = links.join(ranks).flatMap(lambda r: compute_contribs(r[1][0], r[1][1]))
    # 将每个节点的分数值汇总
    ranks = contribs.reduceByKey(lambda x, y: x + y).mapValues(lambda x: x)

for link, rank in ranks.collect():
    print("%s has rank %s." % (link, rank))

9. 程序提交

# 提交pyspark程序,可以指定运行配置,注:无论是pyspark shell模式,还是提交程序到yarn上运行,如果需要用到相关JDBC等jar包,需要指定
spark2-submit spark_learn.py
spark2-submit --master local[2] spark_learn.py  # 本地模式,2个节点并行
spark2-submit --master yarn spark_learn.py  # 提交到yarn上去执行

spark2-submit \
    --master yarn \
    --deploy-mode cluster \
    --jars thirdparty/jars/ojdbc6.jar,thirdparty/jars/mysql-connector-java-5.1.42-bin.jar,thirdparty/jars/hive-metastore-1.1.0-cdh5.13.0.jar,thirdparty/jars/hive-service-1.1.0-cdh5.13.0.jar,thirdparty/jars/ImpalaJDBC41.jar \
    --driver-class-path ojdbc6.jar:mysql-connector-java-5.1.42-bin.jar:hive-metastore-1.1.0-cdh5.13.0.jar:hive-service-1.1.0-cdh5.13.0.jar:ImpalaJDBC41.jar \
    --executor-memory 14G \
    --driver-memory 6G \
    --conf spark.app.name='test' \
    --conf spark.default.parallelism=50 \
    --conf spark.memory.fraction=0.85 \
    --conf spark.memory.storageFraction=0.5 \
    --conf spark.yarn.executor.memoryOverhead=2048 \
    --conf spark.yarn.driver.memoryOverhead=1024 \
    --conf spark.serializer=org.apache.spark.serializer.KryoSerializer \
    --conf spark.yarn.maxAppAttempts=1 \
    spark_learn.py \

pyspark2 \
    --jars thirdparty/jars/ojdbc6.jar,thirdparty/jars/mysql-connector-java-5.1.42-bin.jar \
    --driver-class-path thirdparty/jars/ojdbc6.jar:thirdparty/jars/mysql-connector-java-5.1.42-bin.jar

参考文档

  1. 官方文档:http://spark.apache.org/docs/latest/api/python/pyspark.sql.html
  2. pyspark系列–读写dataframe: https://blog.csdn.net/suzyu12345/article/details/79673473
  3. 《Spark快速大数据分析》
  • 作者:白熊花田
  • 原文链接:https://blog.csdn.net/whiterbear/article/details/84629256
    更新时间:2023-01-01 08:15:36