DeepLearningExamples/TensorFlow2/Recommendation/WideAndDeep/data/outbrain/spark/preproc1.py
#!/usr/bin/env python
# coding: utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

OUTPUT_BUCKET_FOLDER = "/tmp/spark/preprocessed/"
DATA_BUCKET_FOLDER = "/outbrain/orig/"
SPARK_TEMP_FOLDER = "/tmp/spark/spark-temp/"
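
# Local-mode Spark sized for a single large machine: use all cores, give the
# executor and driver generous heaps, and keep shuffle/temp files under
# SPARK_TEMP_FOLDER.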
conf = SparkConf() \
    .setMaster('local[*]') \
    .set('spark.executor.memory', '40g') \
    .set('spark.driver.memory', '200g') \
    .set('spark.local.dir', SPARK_TEMP_FOLDER)
sc = SparkContext(conf=conf)
spark = SparkSession(sc)

print('Loading data...')
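
# Explicit schema for events.csv (one row per page view). Columns carry an
# _event suffix so names stay unambiguous after the joins below.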
events_schema = StructType([
    StructField("display_id", IntegerType(), True),
    StructField("uuid_event", StringType(), True),
    StructField("document_id_event", IntegerType(), True),
    StructField("timestamp_event", IntegerType(), True),
    StructField("platform_event", IntegerType(), True),
    StructField("geo_location_event", StringType(), True),
])
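
# timestamp_event is a millisecond offset from the start of the data
# collection period, so ms / 1000 / 60 / 60 / 24 gives a zero-based day index
# (day_event) used for the time-based train/validation split below.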
events_df = spark.read.schema(events_schema) \
    .options(header='true', inferschema='false', nullValue='\\N') \
    .csv(DATA_BUCKET_FOLDER + "events.csv") \
    .withColumn('day_event', (col('timestamp_event') / 1000 / 60 / 60 / 24).cast("int")) \
    .alias('events')
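
# The bare count() calls below force full materialization after each step;
# their return values are intentionally discarded (likely a leftover of this
# script's notebook origin).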
events_df.count()

print('Drop rows with empty "geo_location"...')
events_df = events_df.dropna(subset="geo_location_event")
events_df.count()

print('Drop rows with empty "platform"...')
events_df = events_df.dropna(subset="platform_event")
events_df.count()
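
# promoted_content.csv maps each ad_id to its promoted (landing) document,
# campaign, and advertiser.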
promoted_content_schema = StructType([
    StructField("ad_id", IntegerType(), True),
    StructField("document_id_promo", IntegerType(), True),
    StructField("campaign_id", IntegerType(), True),
    StructField("advertiser_id", IntegerType(), True),
])
promoted_content_df = spark.read.schema(promoted_content_schema) \
    .options(header='true', inferschema='false', nullValue='\\N') \
    .csv(DATA_BUCKET_FOLDER + "promoted_content.csv") \
    .alias('promoted_content')
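
# clicks_train.csv holds the training labels: for each display_id, the
# candidate ad_ids shown and a 0/1 clicked flag.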
clicks_train_schema = StructType([
    StructField("display_id", IntegerType(), True),
    StructField("ad_id", IntegerType(), True),
    StructField("clicked", IntegerType(), True),
])
clicks_train_df = spark.read.schema(clicks_train_schema) \
    .options(header='true', inferschema='false', nullValue='\\N') \
    .csv(DATA_BUCKET_FOLDER + "clicks_train.csv") \
    .alias('clicks_train')
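
# Denormalize the training clicks: attach ad metadata, then the page-view
# context of the display event. Left joins keep click records even when
# metadata is missing.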
clicks_train_joined_df = clicks_train_df \
    .join(promoted_content_df, on='ad_id', how='left') \
    .join(events_df, on='display_id', how='left')
clicks_train_joined_df.createOrReplaceTempView('clicks_train_joined')
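
# Pick validation display_ids with a stratified sample over day_event: ~20%
# of display_ids from each of days 0-10, and all of days 11-12; seed=0 keeps
# the split reproducible.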
validation_display_ids_df = clicks_train_joined_df.select('display_id', 'day_event') \
    .distinct() \
    .sampleBy("day_event",
              fractions={0: 0.2, 1: 0.2, 2: 0.2, 3: 0.2, 4: 0.2, 5: 0.2,
                         6: 0.2, 7: 0.2, 8: 0.2, 9: 0.2, 10: 0.2,
                         11: 1.0, 12: 1.0},
              seed=0)
validation_display_ids_df.createOrReplaceTempView("validation_display_ids")
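
# Keep only the clicks whose display_id was sampled for validation; the
# EXISTS subquery acts as a semi join against the sampled ids.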
validation_set_df = spark.sql('''
    SELECT display_id, ad_id, uuid_event, day_event,
           timestamp_event, document_id_promo, platform_event, geo_location_event
    FROM clicks_train_joined t
    WHERE EXISTS (SELECT display_id FROM validation_display_ids
                  WHERE display_id = t.display_id)
''')
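
# Persist the validation split as Parquet for the later preprocessing stages,
# overwriting any previous run's output.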
validation_set_output = "validation_set.parquet"
validation_set_df.write.parquet(OUTPUT_BUCKET_FOLDER + validation_set_output, mode='overwrite')
print(validation_set_df.take(5))
spark.stop()