From 3f82b7f982a1d9a9c5392bd110a5b93dfa7e80eb Mon Sep 17 00:00:00 2001 From: kkudrynski Date: Thu, 25 Mar 2021 11:30:48 +0100 Subject: [PATCH] [DLRM/TF2] Initial Release --- Tensorflow2/Recommendation/DLRM/Dockerfile | 34 + .../Recommendation/DLRM/Dockerfile_spark | 44 + Tensorflow2/Recommendation/DLRM/README.md | 634 +++ Tensorflow2/Recommendation/DLRM/dataloader.py | 73 + .../Recommendation/DLRM/distributed_utils.py | 122 + Tensorflow2/Recommendation/DLRM/embedding.py | 111 + .../DLRM/img/columnwise_split.svg | 1 + .../DLRM/img/dlrm_histograms.svg | 5071 +++++++++++++++++ .../DLRM/img/hybrid_parallel.svg | 1 + .../DLRM/img/singlegpu_architecture.svg | 1 + .../Recommendation/DLRM/interaction.py | 67 + .../Recommendation/DLRM/lr_scheduler.py | 64 + Tensorflow2/Recommendation/DLRM/main.py | 315 + Tensorflow2/Recommendation/DLRM/model.py | 606 ++ .../DLRM/preproc/dgx2_config.sh | 38 + .../DLRM/preproc/gpu/get_gpu_resources.sh | 4 + .../DLRM/preproc/gpu/spark-defaults.conf | 30 + .../DLRM/preproc/parquet_to_binary.py | 90 + .../DLRM/preproc/prepare_dataset.sh | 79 + .../Recommendation/DLRM/preproc/run_spark.sh | 34 + .../DLRM/preproc/run_spark_cpu.sh | 162 + .../DLRM/preproc/run_spark_gpu.sh | 195 + .../DLRM/preproc/spark_data_utils.py | 507 ++ .../DLRM/preproc/split_dataset.py | 127 + .../DLRM/preproc/verify_criteo_downloaded.sh | 34 + .../Recommendation/DLRM/requirements.txt | 7 + .../Recommendation/DLRM/slurm_multinode.sh | 39 + .../DLRM/split_binary_dataset.py | 215 + .../tensorflow-dot-based-interact/LICENSE | 201 + .../tensorflow-dot-based-interact/MANIFEST.in | 1 + .../tensorflow-dot-based-interact/Makefile | 57 + .../tensorflow-dot-based-interact/README.md | 101 + .../build_pip_pkg.sh | 31 + .../tensorflow-dot-based-interact/setup.py | 78 + .../tensorflow_dot_based_interact/__init__.py | 18 + .../ampere/dot_based_interact_ampere.cu.cc | 528 ++ .../ampere/dot_based_interact_ampere.h | 53 + .../dot_based_interact_ampere_fp32.cu.inl | 280 + .../dot_based_interact_ampere_half.cu.inl | 570 ++ .../dot_based_interact_ampere_tf32.cu.inl | 346 ++ .../dot_based_interact_grad_kernels.cc | 155 + .../cc/kernels/dot_based_interact_kernels.cc | 147 + .../dot_based_interact_shared_utils.cu.h | 41 + .../volta/dot_based_interact_volta.cu.cc | 337 ++ .../volta/dot_based_interact_volta.cu.inl | 822 +++ .../kernels/volta/dot_based_interact_volta.h | 53 + .../cc/ops/dot_based_interact_ops.cc | 50 + .../python/__init__.py | 0 .../python/ops/__init__.py | 0 .../python/ops/dot_based_interact_ops.py | 31 + .../python/ops/dot_based_interact_ops_test.py | 115 + Tensorflow2/Recommendation/DLRM/utils.py | 127 + 52 files changed, 12847 insertions(+) create mode 100644 Tensorflow2/Recommendation/DLRM/Dockerfile create mode 100644 Tensorflow2/Recommendation/DLRM/Dockerfile_spark create mode 100644 Tensorflow2/Recommendation/DLRM/README.md create mode 100644 Tensorflow2/Recommendation/DLRM/dataloader.py create mode 100644 Tensorflow2/Recommendation/DLRM/distributed_utils.py create mode 100644 Tensorflow2/Recommendation/DLRM/embedding.py create mode 100644 Tensorflow2/Recommendation/DLRM/img/columnwise_split.svg create mode 100644 Tensorflow2/Recommendation/DLRM/img/dlrm_histograms.svg create mode 100644 Tensorflow2/Recommendation/DLRM/img/hybrid_parallel.svg create mode 100644 Tensorflow2/Recommendation/DLRM/img/singlegpu_architecture.svg create mode 100644 Tensorflow2/Recommendation/DLRM/interaction.py create mode 100644 Tensorflow2/Recommendation/DLRM/lr_scheduler.py create mode 100644 Tensorflow2/Recommendation/DLRM/main.py create mode 100644 Tensorflow2/Recommendation/DLRM/model.py create mode 100755 Tensorflow2/Recommendation/DLRM/preproc/dgx2_config.sh create mode 100644 Tensorflow2/Recommendation/DLRM/preproc/gpu/get_gpu_resources.sh create mode 100644 Tensorflow2/Recommendation/DLRM/preproc/gpu/spark-defaults.conf create mode 100644 Tensorflow2/Recommendation/DLRM/preproc/parquet_to_binary.py create mode 100755 Tensorflow2/Recommendation/DLRM/preproc/prepare_dataset.sh create mode 100755 Tensorflow2/Recommendation/DLRM/preproc/run_spark.sh create mode 100755 Tensorflow2/Recommendation/DLRM/preproc/run_spark_cpu.sh create mode 100755 Tensorflow2/Recommendation/DLRM/preproc/run_spark_gpu.sh create mode 100644 Tensorflow2/Recommendation/DLRM/preproc/spark_data_utils.py create mode 100644 Tensorflow2/Recommendation/DLRM/preproc/split_dataset.py create mode 100755 Tensorflow2/Recommendation/DLRM/preproc/verify_criteo_downloaded.sh create mode 100644 Tensorflow2/Recommendation/DLRM/requirements.txt create mode 100644 Tensorflow2/Recommendation/DLRM/slurm_multinode.sh create mode 100644 Tensorflow2/Recommendation/DLRM/split_binary_dataset.py create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/LICENSE create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/MANIFEST.in create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/Makefile create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/README.md create mode 100755 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/build_pip_pkg.sh create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/setup.py create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/__init__.py create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere.cu.cc create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere.h create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere_fp32.cu.inl create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere_half.cu.inl create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere_tf32.cu.inl create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/dot_based_interact_grad_kernels.cc create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/dot_based_interact_kernels.cc create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/dot_based_interact_shared_utils.cu.h create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/volta/dot_based_interact_volta.cu.cc create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/volta/dot_based_interact_volta.cu.inl create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/volta/dot_based_interact_volta.h create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/ops/dot_based_interact_ops.cc create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/__init__.py create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/ops/__init__.py create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/ops/dot_based_interact_ops.py create mode 100644 Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/ops/dot_based_interact_ops_test.py create mode 100644 Tensorflow2/Recommendation/DLRM/utils.py diff --git a/Tensorflow2/Recommendation/DLRM/Dockerfile b/Tensorflow2/Recommendation/DLRM/Dockerfile new file mode 100644 index 00000000..17ec0921 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/Dockerfile @@ -0,0 +1,34 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# author: Tomasz Grel (tgrel@nvidia.com) + + +ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:21.02-tf2-py3 +FROM ${FROM_IMAGE_NAME} + +RUN pip install -e git+https://github.com/NVIDIA/dllogger#egg=dllogger + +ENV HOROVOD_CYCLE_TIME=0.1 + +WORKDIR /dlrm + +ADD . . + +RUN mkdir -p /usr/local/lib/python3.8/dist-packages/tensorflow/include/third_party/gpus/cuda/ \ + && ln -s /usr/local/cuda/include /usr/local/lib/python3.8/dist-packages/tensorflow/include/third_party/gpus/cuda/ \ + && cd tensorflow-dot-based-interact \ + && make \ + && make pkg \ + && pip install ./artifacts/tensorflow_dot_based_interact-*.whl diff --git a/Tensorflow2/Recommendation/DLRM/Dockerfile_spark b/Tensorflow2/Recommendation/DLRM/Dockerfile_spark new file mode 100644 index 00000000..d146bebf --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/Dockerfile_spark @@ -0,0 +1,44 @@ +# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ARG FROM_IMAGE_NAME=nvcr.io/nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04 +FROM ${FROM_IMAGE_NAME} + +RUN apt update && \ + apt install -y openjdk-8-jdk && \ + apt install -y curl && \ + curl https://archive.apache.org/dist/spark/spark-3.0.1/spark-3.0.1-bin-hadoop2.7.tgz -o /opt/spark.tgz && \ + tar zxf /opt/spark.tgz -C /opt/ && \ + mv /opt/spark-3.0.1-bin-hadoop3.2 /opt/spark && \ + rm /opt/spark.tgz && \ + curl https://repo1.maven.org/maven2/ai/rapids/cudf/0.14/cudf-0.14-cuda10-2.jar -o /opt/cudf.jar && \ + curl https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/0.1.0/rapids-4-spark_2.12-0.1.0.jar -o /opt/rapids-4-spark.jar && \ + apt install -y git + +ADD requirements.txt . +RUN apt install -y python3-pip && pip3 install -r requirements.txt + +WORKDIR /workspace/dlrm + +COPY . . + +RUN mv /opt/cudf.jar /opt/spark/jars && \ + mv /opt/rapids-4-spark.jar /opt/spark/jars/ && \ + mv /workspace/dlrm/preproc/gpu/get_gpu_resources.sh /opt/spark/conf/ && \ + mv /workspace/dlrm/preproc/gpu/spark-defaults.conf /opt/spark/conf/ && \ + rm -fr /workspace/dlrm/preproc/gpu + +RUN chmod +x /opt/spark/conf/get_gpu_resources.sh +RUN /bin/bash -c "echo export PYSPARK_PYTHON=/usr/bin/python3 >> /etc/bash.bashrc; update-alternatives --install /usr/bin/python python /usr/bin/python3 10" + diff --git a/Tensorflow2/Recommendation/DLRM/README.md b/Tensorflow2/Recommendation/DLRM/README.md new file mode 100644 index 00000000..6066f322 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/README.md @@ -0,0 +1,634 @@ +# DLRM For TensorFlow 2 + +This repository provides a script and recipe to train the Deep Learning Recommendation Model (DLRM) to achieve state-of-the-art accuracy is tested and maintained by NVIDIA. + +## Table Of Contents + + * [Model overview](#model-overview) + * [Model architecture](#model-architecture) + * [Default configuration](#default-configuration) + * [Feature support matrix](#feature-support-matrix) + * [Features](#features) + * [Mixed precision training](#mixed-precision-training) + * [Enabling mixed precision](#enabling-mixed-precision) + * [Enabling TF32](#enabling-tf32) + * [Hybrid-parallel multi-GPU with all-2-all communication](#hybrid-parallel-multi-gpu-with-all-2-all-communication) + * [Embedding table placement and load balancing (default mode)](#embedding-table-placement-and-load-balancing-default-mode) + * [Training very large embedding tables (experimental mode)](#training-very-large-embedding-tables-experimental-mode) + * [Multi-node training](#multi-node-training) + * [Preprocessing on GPU with Spark 3](#preprocessing-on-gpu-with-spark-3) + * [Setup](#setup) + * [Requirements](#requirements) + * [Quick Start Guide](#quick-start-guide) + * [Advanced](#advanced) + * [Scripts and sample code](#scripts-and-sample-code) + * [Parameters](#parameters) + * [Command-line options](#command-line-options) + * [Getting the data](#getting-the-data) + * [Dataset guidelines](#dataset-guidelines) + * [Multi-dataset](#multi-dataset) + * [Preprocess with Spark](#preprocess-with-spark) + * [Training process](#training-process) + * [Performance](#performance) + * [Benchmarking](#benchmarking) + * [Training performance benchmark](#training-performance-benchmark) + * [Inference performance benchmark](#inference-performance-benchmark) + * [Results](#results) + * [Training accuracy results](#training-accuracy-results) + * [Training accuracy: NVIDIA DGX A100 (8x A100 80GB)](#training-accuracy-nvidia-dgx-a100-8x-a100-80gb) + * [Training accuracy: NVIDIA DGX-1 (8x V100 32GB)](#training-accuracy-nvidia-dgx-1-8x-v100-32gb) + * [Training accuracy: NVIDIA DGX-2 (16x V100 32GB)](#training-accuracy-nvidia-dgx-2-16x-v100-32gb) + * [Training stability test](#training-stability-test) + * [Training performance results](#training-performance-results) + * [Training performance: NVIDIA DGX A100 (8x A100 80GB)](#training-performance-nvidia-dgx-a100-8x-a100-80gb) + * [Training performance: NVIDIA DGX-1 (8x V100 32GB)](#training-performance-nvidia-dgx-1-8x-v100-32gb) + * [Training performance: NVIDIA DGX-2 (16x V100 32GB)](#training-performance-nvidia-dgx-2-16x-v100-32gb) + * [Inference performance results](#inference-performance-results) + * [Inference performance: NVIDIA DGX A100 (8x A100 80GB)](#inference-performance-nvidia-dgx-a100-8x-a100-80gb) + * [Inference performance: NVIDIA DGX1V-32GB (8x V100 32GB)](#inference-performance-nvidia-dgx1v-32gb-8x-v100-32gb) + * [Inference performance: NVIDIA DGX2 (16x V100 16GB)](#inference-performance-nvidia-dgx2-16x-v100-16gb) + * [Release notes](#release-notes) + * [Changelog](#changelog) + * [Known issues](#known-issues) + * [Horovod issues](#horovod-issues) + * [Checkpointing](#checkpointing) + + +## Model overview + +The Deep Learning Recommendation Model (DLRM) is a recommendation model designed to make use of both categorical and numerical inputs. It was first described in [Deep Learning Recommendation Model for Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091). +This repository provides a reimplementation of the code-base provided originally [here](https://github.com/facebookresearch/dlrm). +The scripts enable you to train DLRM on the [Criteo Terabyte Dataset](https://labs.criteo.com/2013/12/download-terabyte-click-logs/). + +Using the scripts provided here, you can efficiently train models that are too large to fit into a single GPU. This is because we use a hybrid-parallel approach, which combines model parallelism with data parallelism for different parts of the neural network. This is explained in details in the [next section](#hybrid-parallel-multi-gpu-with-all-2-all-communication). + +This model uses a slightly different preprocessing procedure than the one found in the original implementation. You can find a detailed description of the preprocessing steps in the [Dataset guidelines](#dataset-guidelines) section. + +Using DLRM, you can train a high-quality general model for recommendations. + +This model is trained with mixed precision using Tensor Cores on Volta, Turing and NVIDIA Ampere GPU architectures. Therefore, researchers can get results 2x faster than training without Tensor Cores while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time. + + + +### Model architecture + +DLRM accepts two types of features: categorical and numerical. For each categorical feature, an embedding table is used to provide dense representation to each unique value. The dense features enter the model and are transformed by a simple neural network referred to as "bottom MLP". + +This part of the network consists of a series +of linear layers with ReLU activations. The output of the bottom MLP and the embedding vectors are then fed into the "dot interaction" operation. The output of "dot interaction" is then concatenated with the features resulting from bottom MLP and fed into the "top MLP" which is a series of dense layers with activations. +The model outputs a single number which can be interpreted as a likelihood of a certain user clicking an ad. + + + +

+ +
+Figure 1. The architecture of DLRM. +

+ +### Default configuration + +The following features were implemented in this model: +- general + - static loss scaling for Tensor Cores (mixed precision) training + - hybrid-parallel multi-GPU training +- preprocessing + - dataset preprocessing using Spark 3 on GPUs + +### Feature support matrix + +The following features are supported by this model: + +| Feature | DLRM +|----------------------|-------------------------- +|Automatic mixed precision (AMP) | Yes +|XLA | Yes +|Hybrid-parallel multiGPU with Horovod all-to-all| Yes +|Preprocessing on GPU with Spark 3| Yes +|Multi-node training | Yes + +#### Features + +**Automatic Mixed Precision (AMP)** +Enables mixed precision training without any changes to the code-base by performing automatic graph rewrites and loss scaling controlled by an environmental variable. + +**XLA** + +The training script supports a `--xla` flag. It can be used to enable XLA JIT compilation. Currently, we use [XLA Lite](https://docs.nvidia.com/deeplearning/frameworks/tensorflow-user-guide/index.html#xla-lite). It delivers a steady 10-30% performance boost depending on your hardware platform, precision, and the number of GPUs. It is turned off by default. + +**Horovod** +Horovod is a distributed training framework for TensorFlow, Keras, PyTorch, and MXNet. The goal of Horovod is to make distributed deep learning fast and easy to use. For more information about how to get started with Horovod, see the Horovod [official repository](https://github.com/horovod/horovod). + +**Hybrid-parallel multiGPU with Horovod all-to-all** +Our model uses Horovod to implement efficient multi-GPU training with NCCL. For details, see example sources in this repository or see the TensorFlow tutorial. For the detailed description of our multi-GPU approach, visit this [section](hybrid-parallel-multi-gpu-with-all-2-all-communication). + +**Multi-node training** +This repository supports multinode training. For more information refer to the [multinode section](#multi-node-training) + + +### Mixed precision training + +Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in Volta, and following with both the Turing and Ampere architectures, significant training speedups are experienced by switching to mixed precision -- up to 3.4x overall speedup on the most arithmetically intense model architectures. Using mixed precision training requires two steps: +1. Porting the model to use the FP16 data type where appropriate. +2. Adding loss scaling to preserve small gradient values. + +The ability to train deep learning networks with lower precision was introduced in the Pascal architecture and first supported in [CUDA 8](https://devblogs.nvidia.com/parallelforall/tag/fp16/) in the NVIDIA Deep Learning SDK. + +For information about: +- How to train using mixed precision, see the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) documentation. +- Techniques used for mixed precision training, see the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog. + +#### Enabling mixed precision + +Mixed precision training is turned off by default. To turn it on, issue the `--amp` flag to the `main.py` script. + + +#### Enabling TF32 + +TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on Volta GPUs. + +TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models which require high dynamic range for weights or activations. + +For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) blog post. + +TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by default. + + +### Hybrid-parallel multi-GPU with all-2-all communication + +Many recommendation models contain very large embedding tables. As a result, the model is often too large to fit onto a single device. This could be easily solved by training in a model-parallel way, using either the CPU or other GPUs as "memory donors". However, this approach is suboptimal as the "memory donor" devices' compute is not utilized. In this repository, we use the model-parallel approach for the bottom part of the model (Embedding Tables + bottom MLP) while using a usual data parallel approach for the top part of the model (Dot Interaction + top MLP). This way, we can train models much larger than what would normally fit into a single GPU while at the same time making the training faster by using multiple GPUs. We call this approach hybrid-parallel training. + +The transition from model-parallel to data-parallel in the middle of the neural net needs a specific multi-GPU communication pattern called [all-2-all](https://en.wikipedia.org/wiki/All-to-all_\(parallel_pattern\)) which is available in our [TensorFlow 2 21.02-py3](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow/tags) NGC Docker container. In the [original DLRM whitepaper](https://arxiv.org/abs/1906.00091) this has been referred to as "butterfly shuffle". + + +

+ +
+Figure 2. The default multi-GPU mode. +

+ + +As the example shows, in this repository we train models of two sizes: "small" (~15 GB) and "large" (~82 GB). The "large" model cannot be trained in a single GPU paradigm as it will not fit into a single GPU memory. + +#### Embedding table placement and load balancing (default mode) + +By default, we use the following heuristic for dividing the work between the GPUs: +- The bottom MLP is placed on GPU-0 and no embedding tables are placed on this device. +- The tables are sorted from the largest to the smallest. +- Set `max_tables_per_gpu := ceil(number_of_embedding_tables / number_of_available_gpus)`. +- Repeat until all embedding tables have an assigned device: + - Out of all the available GPUs, find the one with the largest amount of unallocated memory. + - Place the largest unassigned embedding table on this GPU. Raise an exception if it does not fit. + - If the number of embedding tables on this GPU is now equal to `max_tables_per_gpu`, remove this GPU from the list of available GPUs, so that no more embedding tables will be placed on this GPU. This ensures the all-2-all communication is well-balanced between all devices. + +#### Training very large embedding tables (experimental mode) + +The default multi-GPU paradigm described above has a constraint – each individual table has to fit into a single device's memory. If that's not met, then an Out-of-Memory error will be raised. To enable experimentation with very large models, we provide a way of circumventing this constraint by passing the `--experimental_columnwise_split --data_parallel_bottom_mlp` command-line flags. As a result, each table will be split across the latent space dimension. Some dimensions of the latent space will be placed on one GPU and the rest of them are stored on other GPUs. This means that a table that originally encoded C unique categories into D dense dimensions will now become N separate tables of shape `[C, D / N]` each stored on a different GPU, where N is the number of GPUs used. Symbolically, the computations are exactly equivalent. + +The figure below illustrates this paradigm for a model with 2 embedding tables distributed across two GPUs. Note that this approach is currently slower than the default mode described above. + +

+ +
+Figure 3. The "columnwise split" multi-GPU mode. +

+ + +We tested this approach by training a DLRM model on the Criteo Terabyte dataset with the frequency limiting option turned off (set to zero). The weights of the resulting model take 421 GB. The largest table weighs 140 GB. Here are the commands you can use to reproduce this: + +``` +# build and run the preprocessing container as in the Quick Start Guide +# then when preprocessing set the frequency limit to 0: +./prepare_dataset.sh DGX2 0 + +# build and run the training container same as in the Quick Start Guide +# then append options necessary for training very large embedding tables: +horovodrun -np 8 -H localhost:8 --mpi-args=--oversubscribe numactl --interleave=all -- python -u main.py --dataset_path /data/dlrm/ --amp --tf_gpu_memory_limit_gb 73 --experimental_columnwise_split --data_parallel_bottom_mlp --xla +``` + +When using this method on a DGX A100 with 8 A100-80GB GPUs and a large-enough dataset, it is possible to train a single embedding table of up to 600 GB. You can also use multi-node training (described below) to train even larger recommender systems. + +#### Multi-node training + +Multi-node training is supported. Depending on the exact interconnect hardware and model configuration, you might experience only a modest speedup with multi-node. Multi-node training can also be used to train larger models. For example, to train a 1.68 TB variant of DLRM on multi-node, you can run: + +``` +cmd='numactl --interleave=all -- python -u main.py --dataset_path /data/dlrm/full_criteo_data --amp \ +--tf_gpu_memory_limit_gb 73 --experimental_columnwise_split --data_parallel_bottom_mlp \ +--embedding_dim 512 --bottom_mlp_dims 512,256,512' \ +srun_flags='--mpi=pmix' \ +cont=nvidia_dlrm_tf \ +mounts=/data/dlrm:/data/dlrm \ +sbatch -n 32 -N 4 -t 00:20:00 slurm_multinode.sh +``` + +### Preprocessing on GPU with Spark 3 + +Refer to the ["Preprocessing with Spark" section](#preprocess-with-spark) for a detailed description of the Spark 3 GPU functionality. + +## Setup + +The following section lists the requirements for training DLRM. + +### Requirements + +This repository contains Dockerfile which extends the TensorFlow 2 NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components: +- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker) +- [TensorFlow 2 21.02-py3](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow/tags) NGC container +- Supported GPUs: + - [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) + - [NVIDIA Turing architecture](https://www.nvidia.com/en-us/geforce/turing/) + - [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/) + + +For more information about how to get started with NGC containers, see the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation: +- [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html) +- [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry) +- [Running TensorFlow](https://docs.nvidia.com/deeplearning/frameworks/tensorflow-release-notes/running.html#running) + +For those unable to use the TensorFlow NGC container, to set up the required environment or create your own container, see the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html). + +## Quick Start Guide + +To train your model using mixed or TF32 precision with Tensor Cores or using FP32, perform the following steps using +the default parameters of DLRM on the Criteo Terabyte dataset. For the specifics concerning training and inference, +see the [Advanced](#advanced) section. + +1. Clone the repository. +``` +git clone https://github.com/NVIDIA/DeepLearningExamples +cd DeepLearningExamples/TensorFlow2/Recommendation/DLRM +``` + +2. Build a DLRM Docker container. +```bash +docker build -t nvidia_dlrm_tf . +docker build -t nvidia_dlrm_spark -f Dockerfile_spark . +``` + +3. Start an interactive session in the NGC container to run preprocessing. +The DLRM TensorFlow container can be launched with: +```bash +mkdir -p data +docker run --runtime=nvidia -it --rm --ipc=host -v ${PWD}/data:/data nvidia_dlrm_spark bash +``` + +4. Download and preprocess the dataset. + +You can download the data by following the instructions at: http://labs.criteo.com/2013/12/download-terabyte-click-logs/. + +When you have successfully downloaded the dataset, put it in the `/data/dlrm/criteo/` directory in the container (`$PWD/data/dlrm/criteo` in the host system). + +Here are a few examples of different preprocessing commands. For the details on how those scripts work and detailed description of all the parameters, consult the [preprocess with spark section](#preprocess-with-spark). + +```bash +cd preproc + +# to run on a DGX-2 with a frequency limit of 3 (will need 8xV100-32GB to fit the model in GPU memory) +./prepare_dataset.sh DGX2 3 + +# to run on a DGX-2 with a frequency limit of 15 (should fit on a single V100-32GB): +./prepare_dataset.sh DGX2 15 +# +# to run on CPU with a frequency limit of 15: +./prepare_dataset.sh CPU 15 +``` + +5. Start training. + +First, start the Docker container: +```bash +docker run --runtime=nvidia -it --rm --ipc=host -v ${PWD}/data:/data nvidia_dlrm_tf bash +``` + +- single-GPU A100-80GB: +```bash +horovodrun -np 1 -H localhost:1 --mpi-args=--oversubscribe numactl --interleave=all -- python -u main.py --dataset_path /data/dlrm/ --amp --tf_gpu_memory_limit_gb 73 --xla --save_checkpoint_path /data/dlrm/checkpoint/dlrm +``` + +- single-GPU V100-32GB: +```bash +horovodrun -np 1 -H localhost:1 --mpi-args=--oversubscribe numactl --interleave=all -- python -u main.py --dataset_path /data/dlrm/ --xla --save_checkpoint_path /data/dlrm/checkpoint/dlrm +``` + +- multi-GPU for DGX A100: +```bash +horovodrun -np 8 -H localhost:8 --mpi-args=--oversubscribe numactl --interleave=all -- python -u main.py --dataset_path /data/dlrm/ --amp --tf_gpu_memory_limit_gb 73 --xla --save_checkpoint_path /data/dlrm/checkpoint/dlrm +``` + +- multi-GPU for DGX2: +```bash +horovodrun -np 16 -H localhost:16 --mpi-args=--oversubscribe numactl --interleave=all -- python -u main.py --dataset_path /data/dlrm/ --amp --xla --save_checkpoint_path /data/dlrm/checkpoint/dlrm +``` + +- multi-GPU for DGX1V-32GB: +```bash +horovodrun -np 8 -H localhost:8 --mpi-args=--oversubscribe numactl --interleave=all -- python -u main.py --dataset_path /data/dlrm/ --amp --xla --save_checkpoint_path /data/dlrm/checkpoint/dlrm +``` + +6. Start evaluation. + +To evaluate a previously trained checkpoint, append `--restore_checkpoint_path --mode eval` to the command used for training. For example, to test a checkpoint trained on 8x A100 80GB, run: + +```bash +horovodrun -np 8 -H localhost:8 --mpi-args=--oversubscribe numactl --interleave=all -- python -u main.py --dataset_path /data/dlrm/ --amp --tf_gpu_memory_limit_gb 73 --xla --restore_checkpoint_path /data/dlrm/checkpoint/dlrm --mode eval +``` + +## Advanced + +The following sections provide greater details of the dataset, running training and inference, and the training results. + +### Scripts and sample code + +These are the important modules in this repository: +- `main.py` - The main entrypoint script for training, evaluating, and benchmarking. +- `model.py` - Defines the DLRM model and some auxiliary functions used to train it. +- `dataloader.py` - Handles defining the dataset objects based on command-line flags. +- `split_binary_dataset.py` - Defines the `RawBinaryDataset` class responsible for storing and loading the training data. +- `distributed_utils.py` - Contains the heuristic used for placing the embedding tables across devices. Additionally, defines some data classes describing this placement and some utilities for multi-GPU and multi-node training. +- `slurm_multinode.sh` - Example batch script for multi-node training on SLURM clusters. +- `lr_scheduler.py` - Defines a TensorFlow learning rate scheduler that supports both learning rate warmup and polynomial decay. +- `embedding.py` - Implementations of the embedding layers. +- `interaction.py` - Implementation of the dot-interaction layer using TensorFlow operations. +- `tensorflow-dot-based-interact` - A directory with a set of custom CUDA kernels. They provide fast implementations of the dot-interaction operation for various precisions and hardware platforms. +- `utils.py` - General utilities, such as a timer used for taking performance measurements. + + + +### Parameters + +The table below lists the most important command-line parameters of the `main.py` script. + +| Scope| parameter| Comment| Default Value | +| ----- | --- | ---- | ---- | +|datasets|dataset_path|Path to the JSON file with the sizes of embedding tables| +|function|mode| Choose "train" to train the model, "inference" to benchmark inference and "eval" to run validation| train| +|optimizations|amp| Enable automatic mixed precision| False +|optimizations|xla| Enable XLA| False| +|hyperparameters|batch_size| Batch size used for training|65536| +|hyperparameters|epochs| Number of epochs to train for|1| +|hyperparameters|optimizer| Optimization algorithm for training |SGD| +|hyperparameters|evals_per_epoch| Number of evaluations per epoch|1| +|hyperparameters|valid_batch_size| Batch size used for validation|65536| +|hyperparameters|max_steps| Stop the training/inference after this many optimization steps|-1| +|checkpointing|restore_checkpoint_path| Path from which to restore a checkpoint before training|None| +|checkpointing|save_checkpoint_path| Path to which to save a checkpoint file at the end of the training|None| +|debugging|run_eagerly| Disable all tf.function decorators for debugging|False| +|debugging|print_freq| Number of steps between debug prints|1000| + + +### Command-line options + +The `main.py` script supports a number of command-line flags. You can get the descriptions of those by running `python main.py --help`. + +### Getting the data + +This example uses the [Criteo Terabyte Dataset](https://labs.criteo.com/2013/12/download-terabyte-click-logs/). +The first 23 days are used as the training set. The last day is split in half. The first part is used as a validation set and the second set is used as a hold-out test set. + + +#### Dataset guidelines + +The preprocessing steps applied to the raw data include: +- Replacing the missing values with `0`. +- Replacing the categorical values that exist fewer than 15 times with a special value. +- Converting the hash values to consecutive integers. +- Adding 2 to all the numerical features so that all of them are greater or equal to 1. +- Taking a natural logarithm of all numerical features. + +#### Multi-dataset + +Our preprocessing scripts are designed for the Criteo Terabyte Dataset and should work with any other dataset with the same format. The data should be split into text files. Each line of those text files should contain a single training example. An example should consist of multiple fields separated by tabulators: +- The first field is the label – `1` for a positive example and `0` for negative. +- The next `N` tokens should contain the numerical features separated by tabs. +- The next `M` tokens should contain the hashed categorical features separated by tabs. + + +#### Preprocess with Spark + +The preprocessing scripts provided in this repository support running both on CPU and on DGX-2 using [Apache Spark 3.0](https://www.nvidia.com/en-us/deep-learning-ai/solutions/data-science/apache-spark-3/). +It should be possible to change the values in `preproc/dgx2_config.sh` +so that they'll work on other hardware platforms such as DGX-1. + +Note that the preprocessing will require about 4TB of disk storage. + +The syntax for the preprocessing script is as follows: +```bash +cd preproc +./prepare_dataset.sh +``` + +The first argument is the hardware platform to use (either DGX-2 or pure-CPU). The second argument means the frequency +threshold to apply to the categorical variables. For a frequency threshold `T`, the categorical values that occur less +often than `T` will be replaced with a special embedding. Thus, a larger value of `T` will require smaller embedding tables +and will substantially reduce the overall size of the model. + +For the Criteo Terabyte dataset we recommend a frequency threshold of `T=3` if you intend to run the hybrid-parallel mode +on multiple GPUs. If you want to make the model fit into a single NVIDIA Tesla V100-32GB, you can set `T=15`. + +The preprocessing scripts makes use of the following environment variables to configure the data directory paths: +- `download_dir` – this directory should contain the original Criteo Terabyte CSV files +- `spark_output_path` – directory to which the parquet data will be written +- `conversion_intermediate_dir` – directory used for storing intermediate data used to convert from parquet to train-ready format +- `final_output_dir` – directory to store the final results of the preprocessing which can then be used to train DLRM + +The script `spark_data_utils.py` is a PySpark application, which is used to preprocess the Criteo Terabyte Dataset. In the Docker image, we have installed Spark 3.0.1, which will start a standalone cluster of Spark. The scripts `run_spark_cpu.sh` and `run_spark_gpu.sh` start Spark, then runs several PySpark jobs with `spark_data_utils.py`, for example: +generates the dictionary +- transforms the train dataset +- transforms the test dataset +- transforms the validation dataset + + Change the variables in the `run-spark.sh` script according to your environment. + Configure the paths. +``` +export SPARK_LOCAL_DIRS=/data/spark-tmp +export INPUT_PATH=/data/criteo +export OUTPUT_PATH=/data/output +``` +Note that the Spark job requires about 3TB disk space used for data shuffle. + +Where: +`SPARK_LOCAL_DIRS` is the path where Spark uses to write shuffle data. +`INPUT_PATH` is the path of the Criteo Terabyte Dataset, including uncompressed files like day_0, day_1… +`OUTPUT_PATH` is where the script writes the output data. It will generate the following subdirectories of `models`, `train`, `test`, and `validation`. +- The `model` is the dictionary folder. +- The `train` is the train dataset transformed from day_0 to day_22. +- The `test` is the test dataset transformed from the prior half of day_23. +- The `validation` is the dataset transformed from the latter half of day_23. + +Configure the resources which Spark will use. +``` +export TOTAL_CORES=80 +export TOTAL_MEMORY=800 +``` + +Where: +`TOTAL_CORES` is the total CPU cores you want Spark to use. + +`TOTAL_MEMORY` is the total memory Spark will use. + +Configure frequency limit. +``` +USE_FREQUENCY_LIMIT=15 +``` +The frequency limit is used to filter out the categorical values which appear less than n times in the whole dataset, and make them be 0. Change this variable to 1 to enable it. The default frequency limit is 15 in the script. You also can change the number as you want by changing the line of `OPTS="--frequency_limit 8"`. + + +### Training process + +The main training script resides in `main.py`. The speed of training is measured by throughput i.e., the number +of samples processed per second. We use mixed precision training with static loss scaling for the bottom and top MLPs while embedding tables are stored in FP32 format. + +## Performance + +### Benchmarking + +The following section shows how to run benchmarks measuring the model performance in training and inference modes. + +#### Training performance benchmark + +To benchmark the training performance on a specific batch size, follow the instructions +in the [Quick Start Guide](#quick-start-guide). You can also add the `--max_steps 1000` +if you want to get a reliable throughput measurement without running the entire training. + +You can also use synthetic data by running with the `--dataset_type synthetic` option if you haven't downloaded the dataset yet. + +#### Inference performance benchmark + +To benchmark the inference performance on a specific batch size, run: + +``` +horovodrun -np 1 -H localhost:1 --mpi-args=--oversubscribe numactl --interleave=all -- python -u main.py --dataset_path /data/dlrm/ --amp --restore_checkpoint_path --mode inference +``` + +### Results + +The following sections provide details on how we achieved our performance and accuracy in training and inference. + +We used two model size variants to show memory scalability in multi-GPU setup: +- small - refers to model trained on Criteo dataset with frequency thresholding set to 15 resulting in smaller embedding tables - total model size: ~15 GB +- large - refers to model trained on Criteo dataset with frequency thresholding set to 3 resulting in larger embedding tables - total model size: ~82 GB + +#### Training accuracy results + + +##### Training accuracy: NVIDIA DGX A100 (8x A100 80GB) + +Our results were obtained by running training scripts as described in the Quick Start Guide in the DLRM Docker container. + +| GPUs | Model size | Batch size / GPU | Accuracy (AUC) - TF32 | Accuracy (AUC) - mixed precision | Time to train - TF32 [minutes] | Time to train - mixed precision [minutes] | Time to train speedup (TF32 to mixed precision) +|----:|----|----|----:|----:|---:|---:|---:| +| 1 | small | 64k | 0.8026 | 0.8026 | 34.78| 25.07| 1.39| +| 8 | large | 8k | 0.8026 | 0.8026 | 9.33| 7.30| 1.28| + + +##### Training accuracy: NVIDIA DGX-1 (8x V100 32GB) + +Our results were obtained by running training scripts as described in the Quick Start Guide in the DLRM Docker container. + +| GPUs | Model size | Batch size / GPU | Accuracy (AUC) - FP32 | Accuracy (AUC) - mixed precision | Time to train - FP32 [minutes] | Time to train - mixed precision [minutes] | Time to train speedup (FP32 to mixed precision) +|----:|----|----|----:|----:|---:|---:|---:| +| 1 | small | 64k | 0.8026 | 0.8026 | 127.53| 50.55| 2.52| +| 8 | large | 8k | 0.8026 | 0.8026 | 31.73| 14.92| 2.13| + + +##### Training accuracy: NVIDIA DGX-2 (16x V100 32GB) + +Our results were obtained by running training scripts as described in the Quick Start Guide in the DLRM Docker container. + +| GPUs | Model size | Batch size / GPU | Accuracy (AUC) - FP32 | Accuracy (AUC) - mixed precision | Time to train - FP32 [minutes] | Time to train - mixed precision [minutes] | Time to train speedup (FP32 to mixed precision) +|----:|----|----|----:|----:|---:|---:|---:| +| 1 | small | 64k | 0.8026 | 0.8026 | 112.78| 43.20| 2.61| +| 8 | large | 8k | 0.8026 | 0.8026 | 25.28| 11.65| 2.17| +| 16 | large | 4k | 0.8026 | 0.8026 |20.93 | 11.90| 1.76| + + +##### Training stability test + +The histograms below show the distribution of ROC AUC results achieved at the end of the training for each precision/hardware platform tested. There are no statistically significant differences between precision, number of GPUs or hardware platform. Using the larger dataset has a modest, positive impact on final AUC score. + + +

+ +
+Figure 4. Results of stability tests for DLRM. +

+ + +#### Training performance results + + +We used throughput in items processed per second as the performance metric. + + +##### Training performance: NVIDIA DGX A100 (8x A100 80GB) + +Our results were obtained by following the commands from the Quick Start Guide +in the DLRM Docker container on NVIDIA DGX A100 (8x A100 80GB) GPUs. Performance numbers (in items per second) were averaged over 1000 training steps. + +| GPUs | Model size | Batch size / GPU | Throughput - TF32 | Throughput - mixed precision | Throughput speedup (TF32 - mixed precision) +|----:|----|----|---:|---:|---:| +| 1 | small | 64k | 2.09M | 2.93M | 1.40| +| 8 | large | 8k | 8.93M | 12.05M | 1.35| + +To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide). + + +##### Training performance: NVIDIA DGX-1 (8x V100 32GB) + +| GPUs | Model size | Batch size / GPU | Throughput - FP32 | Throughput - mixed precision | Throughput speedup (FP32 - mixed precision) | +|----:|----|----|---:|---:|---:| +| 1 | small | 64k | 0.56M| 1.46M| 2.60| +| 8 | large | 8k | 2.44M| 5.95M | 2.44| + +To achieve the same results, follow the steps in the [Quick Start Guide](#quick-start-guide). + + +##### Training performance: NVIDIA DGX-2 (16x V100 32GB) + +| GPUs | Model size | Batch size / GPU | Throughput - FP32 | Throughput - mixed precision | Throughput speedup (FP32 - mixed precision) +|----:|----|---|---:|---:|---:| +| 1 | small | 64k | 0.64M| 1.70M | 2.68| +| 8 | large | 8k | 3.16M| 8.18M| 2.59| +| 16 | large | 4k | 4.37M| 9.52M| 2.18| + + +To achieve the same results, follow the steps in the [Quick Start Guide](#quick-start-guide). + +#### Inference performance results + +##### Inference performance: NVIDIA DGX A100 (8x A100 80GB) +| GPUs | Model size | Batch size / GPU | Throughput - TF32 | Throughput - mixed precision | Average latency - TF32 [ms] | Average latency - mixed precision [ms] | Throughput speedup (mixed precision to TF32) +|----:|----|---|---:|---:|---:|---:|---:| +| 1| small| 2048| 755k|828k |2.71|2.47|1.10 | + +##### Inference performance: NVIDIA DGX1V-32GB (8x V100 32GB) +| GPUs | Model size | Batch size / GPU | Throughput - FP32 | Throughput - mixed precision | Average latency - FP32 [ms] | Average latency - mixed precision [ms] | Throughput speedup (mixed precision to FP32) +|----:|----|---|---:|---:|---:|---:|---:| +| 1| small| 2048| 441k| 497k |4.65|4.12|1.13 | + +##### Inference performance: NVIDIA DGX2 (16x V100 16GB) +| GPUs | Model size | Batch size / GPU | Throughput - FP32 | Throughput - mixed precision | Average latency - FP32 [ms] | Average latency - mixed precision [ms] | Throughput speedup (mixed precision to FP32) +|----:|----|---|---:|---:|---:|---:|---:| +| 1| small| 2048| 558k| 774k |3.67|2.65|1.39| + + +## Release notes +We’re constantly refining and improving our performance on AI and HPC workloads even on the same hardware with frequent updates to our software stack. For our latest performance data please refer to these pages for [AI](https://developer.nvidia.com/deep-learning-performance-training-inference) and [HPC](https://developer.nvidia.com/hpc-application-performance) benchmarks. + +### Changelog + +March 2021 +- Initial release + +### Known issues + +#### Horovod issues +In certain cases, TensorFlow can structure the graph in such a way that the rank-0 GPU has a different order of Horovod all-2-all calls then the other ranks. This will cause a deadlock. It does not happen in the default settings, but there's a chance it will, especially if you make heavy modifications to the bottom part of the model. To circumvent this, you can run the bottom MLP in data-parallel mode. This causes the computational graphs of each GPU to be very similar, thus eliminating the chance of a deadlock. Note this mode will be up to 10% slower than the default mode. + +#### Checkpointing +TensorFlow runs into issues when trying to save model checkpoints for extremely large variables. We circumvent this by using a custom checkpoint format that splits the variables into pieces and stores each piece independently. However, this custom format cannot be used by the standard inference deployment frameworks such as ONNX. + +#### Inference performance +Current inference performance was evaluated in python using TensorFlow 2.4.0. This provides ease of use and flexibility but is suboptimal in terms of performance. If you're interested in state-of-the-art performance for recommender system inference, please review our results in [the MLPerf v0.7 benchmark](https://mlperf.org/inference-results/) where we used [TensorRT](https://developer.nvidia.com/tensorrt). You might also want to check [the source code of our MLPerf Inference submission](https://github.com/mlcommons/inference_results_v0.7/tree/master/closed/NVIDIA/code/dlrm/tensorrt). + diff --git a/Tensorflow2/Recommendation/DLRM/dataloader.py b/Tensorflow2/Recommendation/DLRM/dataloader.py new file mode 100644 index 00000000..f6c16fae --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/dataloader.py @@ -0,0 +1,73 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# author: Tomasz Grel (tgrel@nvidia.com) + + +from distributed_utils import get_device_mapping +import horovod.tensorflow as hvd +from split_binary_dataset import RawBinaryDataset, DatasetMetadata, DummyDataset + + +def create_input_pipelines(FLAGS): + if FLAGS.dataset_type == 'synthetic': + dataset_metadata = DummyDataset.get_metadata(FLAGS) + elif FLAGS.dataset_type == 'raw': + dataset_metadata = RawBinaryDataset.get_metadata(FLAGS.dataset_path, FLAGS.num_numerical_features) + + multi_gpu_metadata = get_device_mapping(embedding_sizes=dataset_metadata.categorical_cardinalities, + num_gpus=hvd.size(), + data_parallel_bottom_mlp=FLAGS.data_parallel_bottom_mlp, + experimental_columnwise_split=FLAGS.experimental_columnwise_split, + num_numerical_features=FLAGS.num_numerical_features) + + local_tables = multi_gpu_metadata.rank_to_categorical_ids[hvd.rank()] + local_table_sizes = [dataset_metadata.categorical_cardinalities[i] for i in local_tables] + + numerical_features = dataset_metadata.num_numerical_features if hvd.rank() in multi_gpu_metadata.bottom_mlp_ranks else 0 + + if FLAGS.dataset_type == 'synthetic': + train_dataset = DummyDataset(batch_size=FLAGS.batch_size, + num_numerical_features=numerical_features, + num_categorical_features=len(local_table_sizes), + num_batches=FLAGS.synthetic_dataset_train_batches) + + test_dataset = DummyDataset(batch_size=FLAGS.valid_batch_size, + num_numerical_features=numerical_features, + num_categorical_features=len(local_table_sizes), + num_batches=FLAGS.synthetic_dataset_valid_batches) + + elif FLAGS.dataset_type == 'raw': + train_dataset = RawBinaryDataset(data_path=FLAGS.dataset_path, + valid=False, + batch_size=FLAGS.batch_size, + numerical_features=numerical_features, + categorical_features=local_tables, + categorical_feature_sizes=dataset_metadata.categorical_cardinalities, + prefetch_depth=FLAGS.prefetch_batches, + drop_last_batch=True) + + test_dataset = RawBinaryDataset(data_path=FLAGS.dataset_path, + valid=True, + batch_size=FLAGS.valid_batch_size, + numerical_features=numerical_features, + categorical_features=local_tables, + categorical_feature_sizes=dataset_metadata.categorical_cardinalities, + prefetch_depth=FLAGS.prefetch_batches, + drop_last_batch=True) + + else: + raise ValueError(f'Unsupported dataset type: {FLAGS.dataset_type}') + + return train_dataset, test_dataset, dataset_metadata, multi_gpu_metadata diff --git a/Tensorflow2/Recommendation/DLRM/distributed_utils.py b/Tensorflow2/Recommendation/DLRM/distributed_utils.py new file mode 100644 index 00000000..5b1373be --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/distributed_utils.py @@ -0,0 +1,122 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# author: Tomasz Grel (tgrel@nvidia.com) + + +import tensorflow as tf +from collections import deque +import math +import horovod.tensorflow as hvd +from collections import namedtuple + + +class BroadcastingInitializer(tf.keras.initializers.Initializer): + def __init__(self, wrapped): + self.wrapped = wrapped + + def __call__(self, *args, **kwargs): + weights = self.wrapped(*args, **kwargs) + weights = hvd.broadcast(weights, root_rank=0, name='BroadcastingInitializer') + return weights + + def get_config(self): + return {} + + +def argsort(sequence, reverse: bool = False): + idx_pairs = [(x, i) for i, x in enumerate(sequence)] + sorted_pairs = sorted(idx_pairs, key=lambda pair: pair[0], reverse=reverse) + return [i for _, i in sorted_pairs] + + +def distribute_to_buckets(sizes, buckets_num): + def sum_sizes(indices): + return sum(sizes[i] for i in indices) + + max_bucket_size = math.ceil(len(sizes) / buckets_num) + idx_sorted = deque(argsort(sizes, reverse=True)) + buckets = [[] for _ in range(buckets_num)] + final_buckets = [] + + while idx_sorted: + bucket = buckets[0] + bucket.append(idx_sorted.popleft()) + + if len(bucket) == max_bucket_size: + final_buckets.append(buckets.pop(0)) + + buckets.sort(key=sum_sizes) + + final_buckets += buckets + + return final_buckets + + +MultiGpuMetadata = namedtuple('MultiGpuMetadata', + ['bottom_mlp_ranks','rank_to_categorical_ids','rank_to_feature_count']) + + +def get_device_mapping(embedding_sizes, num_gpus, data_parallel_bottom_mlp, + experimental_columnwise_split, num_numerical_features): + """Get device mappings for hybrid parallelism + + Bottom MLP running on device 0. Embeddings will be distributed across among all the devices. + + Optimal solution for partitioning set of N embedding tables into K devices to minimize maximal subset sum + is an NP-hard problem. Additionally, embedding tables distribution should be nearly uniform due to the performance + constraints. Therefore, suboptimal greedy approach with max bucket size is used. + + Args: + embedding_sizes (Sequence[int]): embedding tables sizes + num_gpus (int): Default 8. + + Returns: + device_mapping (dict): + """ + if num_numerical_features == 0: + bottom_mlp_ranks = [] + elif data_parallel_bottom_mlp: + bottom_mlp_ranks = list(range(num_gpus)) + else: + bottom_mlp_ranks = [0] + + if experimental_columnwise_split: + gpu_buckets = num_gpus * [list(range(len(embedding_sizes)))] + + vectors_per_gpu = [len(bucket) for bucket in gpu_buckets] + + if num_numerical_features > 0: + vectors_per_gpu[0] += 1 # count bottom mlp + + return MultiGpuMetadata(bottom_mlp_ranks=bottom_mlp_ranks, + rank_to_categorical_ids=gpu_buckets, + rank_to_feature_count=vectors_per_gpu) + + if num_gpus > 4 and not data_parallel_bottom_mlp and num_numerical_features > 0: + # for higher no. of GPUs, make sure the one with bottom mlp has no embeddings + gpu_buckets = distribute_to_buckets(embedding_sizes, num_gpus - 1) # leave one device out for the bottom MLP + gpu_buckets.insert(0, []) + else: + gpu_buckets = distribute_to_buckets(embedding_sizes, num_gpus) + + vectors_per_gpu = [len(bucket) for bucket in gpu_buckets] + + if not data_parallel_bottom_mlp: + for rank in bottom_mlp_ranks: + vectors_per_gpu[rank] += 1 # count bottom mlp + + return MultiGpuMetadata(bottom_mlp_ranks=bottom_mlp_ranks, + rank_to_categorical_ids=gpu_buckets, + rank_to_feature_count=vectors_per_gpu) \ No newline at end of file diff --git a/Tensorflow2/Recommendation/DLRM/embedding.py b/Tensorflow2/Recommendation/DLRM/embedding.py new file mode 100644 index 00000000..94d0d445 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/embedding.py @@ -0,0 +1,111 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# author: Tomasz Grel (tgrel@nvidia.com) + + +import tensorflow as tf +import math + + +class EmbeddingInitializer(tf.keras.initializers.Initializer): + def __call__(self, shape, dtype=tf.float32): + with tf.device('/CPU:0'): + maxval = tf.sqrt(tf.constant(1.) / tf.cast(shape[0], tf.float32)) + maxval = tf.cast(maxval, dtype=dtype) + minval = -maxval + + weights = tf.random.uniform(shape, minval=minval, maxval=maxval, dtype=dtype) + weights = tf.cast(weights, dtype=tf.float32) + return weights + + def get_config(self): + return {} + + +def _divisors(n): + large_divisors = [] + for i in range(1, int(math.sqrt(n) + 1)): + if n % i == 0: + yield i + if i*i != n: + large_divisors.append(n / i) + for divisor in reversed(large_divisors): + yield int(divisor) + + +def _get_n_chunks(input_dim, output_dim, max_chunk_size): + for n_chunks in _divisors(output_dim): + chunk_output_dim = output_dim / n_chunks + chunk_size = input_dim * chunk_output_dim + if chunk_size < max_chunk_size: + return n_chunks + raise ValueError(f'Unable to split embedding table: [{input_dim}, {output_dim}]') + + +class SplitEmbedding(tf.keras.layers.Layer): + def __init__(self, input_dim, output_dim, trainable=True, max_chunk_size=2**31): + super(SplitEmbedding, self).__init__(dtype=tf.float32) + self.input_dim = input_dim + self.output_dim = output_dim + self.embedding_tables = [] + self.trainable = trainable + + self.n_chunks = _get_n_chunks(input_dim, output_dim, max_chunk_size) + + self.chunk_output_dim = self.output_dim // self.n_chunks + + if self.n_chunks > output_dim: + raise ValueError('Unable to perform a column-wise split of an embedding table!') + + if self.n_chunks > 1: + print(f'Splitting the embedding table: [{input_dim} x {output_dim} into {self.n_chunks}' + f' [{input_dim} x {self.chunk_output_dim}] chunks') + + def build(self, input_shape): + for i in range(self.n_chunks): + w = self.add_weight(f"embedding_table_chunk_{i}", + shape=[self.input_dim, self.chunk_output_dim], + dtype=tf.float32, + initializer=EmbeddingInitializer(), + trainable=self.trainable, + ) + self.embedding_tables.append(w) + + def call(self, indices): + outputs = [] + for embedding_table in self.embedding_tables: + out = tf.gather(params=embedding_table, indices=indices) + outputs.append(out) + return tf.concat(outputs, axis=2) + + +class Embedding(tf.keras.layers.Layer): + def __init__(self, input_dim, output_dim, trainable=True): + super(Embedding, self).__init__(dtype=tf.float32) + self.input_dim = input_dim + self.output_dim = output_dim + self.embedding_table = None + self.trainable = trainable + + def build(self, input_shape): + self.embedding_table = self.add_weight("embedding_table", + shape=[self.input_dim, self.output_dim], + dtype=tf.float32, + initializer=EmbeddingInitializer(), + trainable=self.trainable, + ) + + def call(self, indices): + return tf.gather(params=self.embedding_table, indices=indices) diff --git a/Tensorflow2/Recommendation/DLRM/img/columnwise_split.svg b/Tensorflow2/Recommendation/DLRM/img/columnwise_split.svg new file mode 100644 index 00000000..1ca9b3cf --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/img/columnwise_split.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/Tensorflow2/Recommendation/DLRM/img/dlrm_histograms.svg b/Tensorflow2/Recommendation/DLRM/img/dlrm_histograms.svg new file mode 100644 index 00000000..3895eedf --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/img/dlrm_histograms.svg @@ -0,0 +1,5071 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/Tensorflow2/Recommendation/DLRM/img/hybrid_parallel.svg b/Tensorflow2/Recommendation/DLRM/img/hybrid_parallel.svg new file mode 100644 index 00000000..87ba1a7e --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/img/hybrid_parallel.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/Tensorflow2/Recommendation/DLRM/img/singlegpu_architecture.svg b/Tensorflow2/Recommendation/DLRM/img/singlegpu_architecture.svg new file mode 100644 index 00000000..fe36034e --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/img/singlegpu_architecture.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/Tensorflow2/Recommendation/DLRM/interaction.py b/Tensorflow2/Recommendation/DLRM/interaction.py new file mode 100644 index 00000000..4714c70c --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/interaction.py @@ -0,0 +1,67 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import tensorflow as tf + + +def dot_interact(concat_features, bottom_mlp_out=None, skip_gather=False): + # Interact features, select lower-triangular portion, and re-shape. + interactions = tf.matmul(concat_features, concat_features, transpose_b=True) + + ones = tf.ones_like(interactions, dtype=tf.float32) + upper_tri_mask = tf.linalg.band_part(ones, 0, -1) + + feature_dim = tf.shape(interactions)[-1] + + if skip_gather: + upper_tri_bool = tf.cast(upper_tri_mask, tf.bool) + activations = tf.where( + condition=upper_tri_bool, x=tf.zeros_like(interactions), y=interactions) + out_dim = feature_dim * feature_dim + else: + lower_tri_mask = ones - upper_tri_mask + activations = tf.boolean_mask(interactions, lower_tri_mask) + out_dim = feature_dim * (feature_dim - 1) // 2 + + activations = tf.reshape(activations, shape=[-1, out_dim]) + + if bottom_mlp_out is not None: + bottom_mlp_out = tf.squeeze(bottom_mlp_out) + activations = tf.concat([activations, bottom_mlp_out], axis=1) + + return activations + + +def dummy_dot_interact(concat_features, bottom_mlp_out=None): + batch_size = tf.shape(concat_features)[0] + num_features = tf.shape(concat_features)[1] + concat_features = tf.math.reduce_mean(concat_features, axis=[2], keepdims=True) + return dot_interact(concat_features, bottom_mlp_out) diff --git a/Tensorflow2/Recommendation/DLRM/lr_scheduler.py b/Tensorflow2/Recommendation/DLRM/lr_scheduler.py new file mode 100644 index 00000000..71d594e0 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/lr_scheduler.py @@ -0,0 +1,64 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# author: Tomasz Grel (tgrel@nvidia.com) + + +import tensorflow as tf + + +class LearningRateScheduler: + """ + LR Scheduler combining Polynomial Decay with Warmup at the beginning. + TF-based cond operations necessary for performance in graph mode. + """ + + def __init__(self, optimizers, base_lr, warmup_steps, decay_start_step, decay_steps): + self.optimizers = optimizers + self.warmup_steps = tf.constant(warmup_steps, dtype=tf.int32) + self.decay_start_step = tf.constant(decay_start_step, dtype=tf.int32) + self.decay_steps = tf.constant(decay_steps) + self.decay_end_step = decay_start_step + decay_steps + self.poly_power = 2 + self.base_lr = base_lr + with tf.device('/CPU:0'): + self.step = tf.Variable(0) + + @tf.function + def __call__(self): + with tf.device('/CPU:0'): + # used for the warmup stage + warmup_step = tf.cast(1 / self.warmup_steps, tf.float32) + lr_factor_warmup = 1 - tf.cast(self.warmup_steps - self.step, tf.float32) * warmup_step + lr_factor_warmup = tf.cast(lr_factor_warmup, tf.float32) + + # used for the constant stage + lr_factor_constant = tf.cast(1., tf.float32) + + # used for the decay stage + lr_factor_decay = (self.decay_end_step - self.step) / self.decay_steps + lr_factor_decay = tf.math.pow(lr_factor_decay, self.poly_power) + lr_factor_decay = tf.cast(lr_factor_decay, tf.float32) + + poly_schedule = tf.cond(self.step < self.decay_start_step, lambda: lr_factor_constant, + lambda: lr_factor_decay) + + lr_factor = tf.cond(self.step < self.warmup_steps, lambda: lr_factor_warmup, + lambda: poly_schedule) + + lr = self.base_lr * lr_factor + for optimizer in self.optimizers: + optimizer.lr.assign(lr) + + self.step.assign(self.step + 1) diff --git a/Tensorflow2/Recommendation/DLRM/main.py b/Tensorflow2/Recommendation/DLRM/main.py new file mode 100644 index 00000000..e65988ab --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/main.py @@ -0,0 +1,315 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# author: Tomasz Grel (tgrel@nvidia.com) + + +from absl import app, flags +import os +import sys + +# Define the flags first before importing TensorFlow. +# Otherwise, enabling XLA-Lite would be impossible with a command-line flag +def define_command_line_flags(): + flags.DEFINE_enum("mode", default="train", enum_values=['inference', 'eval', 'train'], + help='Choose "train" to train the model, "inference" to benchmark inference' + ' and "eval" to run validation') + flags.DEFINE_float("learning_rate", default=24, help="Learning rate") + flags.DEFINE_integer("batch_size", default=64 * 1024, help="Batch size used for training") + flags.DEFINE_integer("valid_batch_size", default=64 * 1024, help="Batch size used for validation") + flags.DEFINE_bool("run_eagerly", default=False, help="Disable all tf.function decorators for debugging") + + flags.DEFINE_bool("dummy_model", default=False, help="Use a dummy model for benchmarking and debugging") + flags.DEFINE_bool("dummy_embedding", default=False, help="") + + flags.DEFINE_list("top_mlp_dims", [1024, 1024, 512, 256, 1], "Linear layer sizes for the top MLP") + flags.DEFINE_list("bottom_mlp_dims", [512, 256, 128], "Linear layer sizes for the bottom MLP") + + flags.DEFINE_enum("optimizer", default="sgd", enum_values=['sgd', 'adam'], + help='The optimization algorithm to be used.') + + flags.DEFINE_string("save_checkpoint_path", default=None, + help="Path to which to save a checkpoint file at the end of the training") + flags.DEFINE_string("restore_checkpoint_path", default=None, + help="Path from which to restore a checkpoint before training") + + flags.DEFINE_enum("dataset_type", default="raw", enum_values=['raw', 'synthetic'], + help='The type of the dataset to use') + flags.DEFINE_integer("num_numerical_features", default=13, + help='Number of numerical features to be read from the dataset. ' + 'If set to 0, then no numerical features will be loaded ' + 'and the Bottom MLP will not be evaluated') + + flags.DEFINE_integer('synthetic_dataset_train_batches', default=64008, + help='Number of training batches in the synthetic dataset') + flags.DEFINE_integer('synthetic_dataset_valid_batches', default=1350, + help='Number of validation batches in the synthetic dataset') + flags.DEFINE_list('synthetic_dataset_cardinalities', default=26*[1000], + help='Number of categories for each embedding table of the synthetic dataset') + + flags.DEFINE_bool("amp", default=False, help="Enable automatic mixed precision") + flags.DEFINE_bool("xla", default=False, help="Enable XLA") + + flags.DEFINE_integer("loss_scale", default=1024, help="Static loss scale to use with mixed precision training") + + flags.DEFINE_integer("prefetch_batches", default=10, + help="The number of batches to prefetch for the dataloader") + + flags.DEFINE_integer("auc_thresholds", default=8000, + help="Number of thresholds for the AUC computation") + + flags.DEFINE_integer("epochs", default=1, help="Number of epochs to train for") + flags.DEFINE_integer("max_steps", default=-1, help="Stop the training/inference after this many optimiation steps") + + flags.DEFINE_string("embedding_type", default="split_embedding", + help="Embedding type to use, possible choices: embedding, split_embedding") + flags.DEFINE_bool("embedding_trainable", default=True, help="If True the embeddings will be trainable, otherwise frozen") + + flags.DEFINE_string("dot_interaction", default="custom_cuda", + help="Dot interaction implementation to use, possible choices: custom_cuda, tensorflow, dummy") + + flags.DEFINE_string("dataset_path", default=None, + help="Path to the JSON file with the sizes of embedding tables") + + flags.DEFINE_integer("embedding_dim", default=128, help='Number of columns in the embedding tables') + + flags.DEFINE_integer("evals_per_epoch", default=1, help='Number of evaluations per epoch') + flags.DEFINE_float("print_freq", default=1000, help='Number of steps between debug prints') + + flags.DEFINE_integer("warmup_steps", default=8000, + help='Number of steps over which to linearly increase the LR at the beginning') + flags.DEFINE_integer("decay_start_step", default=48000, help='Optimization step at which to start the poly LR decay') + flags.DEFINE_integer("decay_steps", default=24000, help='Number of steps over which to decay from base LR to 0') + + flags.DEFINE_integer("profiler_start_step", default=None, help='Step at which to start profiling') + flags.DEFINE_integer("profiled_rank", default=1, help='Rank to profile') + + flags.DEFINE_integer("inter_op_parallelism", default=None, help='Number of inter op threads') + flags.DEFINE_integer("intra_op_parallelism", default=None, help='Number of intra op threads') + + flags.DEFINE_integer("tf_gpu_memory_limit_gb", default=24, + help='Gigabytes of GPU memory reserved for TensorFlow. Only applied in multiGPU/multiNode to leave' + ' enough memory for NCCL to operate properly.') + + flags.DEFINE_bool("data_parallel_bottom_mlp", default=False, help="Run the bottom MLP in data-parallel mode") + flags.DEFINE_bool("experimental_columnwise_split", default=False, + help="Enable slicing individual embedding tables across multiple devices") + + flags.DEFINE_string("log_path", default='dlrm_tf_log.json', help="Path to JSON file for storing benchmark results") + + +define_command_line_flags() + +FLAGS = flags.FLAGS +app.define_help_flags() +app.parse_flags_with_usage(sys.argv) + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +if FLAGS.xla: + os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=fusible' + +import time +from lr_scheduler import LearningRateScheduler +import tensorflow as tf +import tensorflow_addons as tfa +import numpy as np +from utils import IterTimer, init_logging, dist_print +from dataloader import create_input_pipelines +from model import Dlrm, DummyDlrm, DlrmTrainer, evaluate, DataParallelSplitter +import horovod.tensorflow as hvd +from tensorflow.keras.mixed_precision import LossScaleOptimizer +import dllogger + +def init_tf(FLAGS): + """ + Set global options for TensorFlow + """ + + gpus = tf.config.experimental.list_physical_devices('GPU') + + if gpus: + tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU') + + if hvd.size() > 1: + memory_limit_mb = FLAGS.tf_gpu_memory_limit_gb * 1024 + print(f"Limiting TF memory to: {memory_limit_mb} MB") + + tf.config.set_logical_device_configuration(gpus[hvd.local_rank()], + [tf.config.LogicalDeviceConfiguration(memory_limit=memory_limit_mb)]) + tf.config.experimental.set_virtual_device_configuration( + gpus[hvd.local_rank()], + [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=memory_limit_mb)], + ) + + if FLAGS.amp: + policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16", loss_scale=FLAGS.loss_scale) + tf.keras.mixed_precision.experimental.set_policy(policy) + + tf.config.run_functions_eagerly(FLAGS.run_eagerly) + + if FLAGS.inter_op_parallelism: + tf.config.threading.set_inter_op_parallelism_threads(FLAGS.inter_op_parallelism) + + if FLAGS.intra_op_parallelism: + tf.config.threading.set_intra_op_parallelism_threads(FLAGS.intra_op_parallelism) + + if FLAGS.xla: + os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=fusible' + + +def compute_eval_points(train_batches, evals_per_epoch): + eval_points = np.linspace(0, train_batches - 1, evals_per_epoch + 1)[1:] + eval_points = np.round(eval_points).tolist() + return eval_points + + +def inference_benchmark(validation_pipeline, dlrm, timer, splitter, FLAGS): + if FLAGS.max_steps == -1: + FLAGS.max_steps = 1000 + + _, _, latencies = evaluate(validation_pipeline, dlrm, + timer, auc_thresholds=None, + data_parallel_splitter=splitter, + max_steps=FLAGS.max_steps) + + # don't benchmark the first few warmup steps + latencies = latencies[10:] + result_data = { + 'mean_inference_throughput': FLAGS.valid_batch_size / np.mean(latencies), + 'mean_inference_latency': np.mean(latencies) + } + + for percentile in [90, 95, 99]: + result_data[f'p{percentile}_inference_latency'] = np.percentile(latencies, percentile) + dllogger.log(data=result_data, step=tuple()) + + +def main(argv): + if FLAGS.experimental_columnwise_split and not FLAGS.data_parallel_bottom_mlp and FLAGS.num_numerical_features > 0: + raise ValueError('Currently you when using the --experimenal_columnwise_split option ' + 'you must either set --data_parallel_bottom_mlp or --num_numerical_features=0') + + if FLAGS.batch_size != FLAGS.valid_batch_size: + raise ValueError('For now, validation batch size must be the same as training batch size') + + hvd.init() + init_logging(log_path=FLAGS.log_path, FLAGS=FLAGS) + init_tf(FLAGS) + + train_pipeline, validation_pipeline, dataset_metadata, multi_gpu_metadata = create_input_pipelines(FLAGS) + + if FLAGS.dummy_model: + dlrm = DummyDlrm(FLAGS=FLAGS, dataset_metadata=dataset_metadata, + multi_gpu_metadata=multi_gpu_metadata) + else: + dlrm = Dlrm(FLAGS=FLAGS, dataset_metadata=dataset_metadata, + multi_gpu_metadata=multi_gpu_metadata) + + if FLAGS.optimizer == 'sgd': + embedding_optimizer = tf.keras.optimizers.SGD(lr=FLAGS.learning_rate, momentum=0) + if FLAGS.amp: + embedding_optimizer = LossScaleOptimizer(embedding_optimizer, + initial_scale=FLAGS.loss_scale, + dynamic=False) + mlp_optimizer = embedding_optimizer + optimizers = [mlp_optimizer] + + elif FLAGS.optimizer == 'adam': + embedding_optimizer = tfa.optimizers.LazyAdam(lr=FLAGS.learning_rate) + mlp_optimizer = tf.keras.optimizers.Adam(lr=FLAGS.learning_rate) + if FLAGS.amp: + embedding_optimizer = LossScaleOptimizer(embedding_optimizer, + initial_scale=FLAGS.loss_scale, + dynamic=False) + mlp_optimizer = LossScaleOptimizer(mlp_optimizer, + initial_scale=FLAGS.loss_scale, + dynamic=False) + optimizers = [mlp_optimizer, embedding_optimizer] + + scheduler = LearningRateScheduler(optimizers, + warmup_steps=FLAGS.warmup_steps, + base_lr=FLAGS.learning_rate, + decay_start_step=FLAGS.decay_start_step, + decay_steps=FLAGS.decay_steps) + + timer = IterTimer(train_batch_size=FLAGS.batch_size, test_batch_size=FLAGS.valid_batch_size, + optimizer=embedding_optimizer, print_freq=FLAGS.print_freq, enabled=hvd.rank() == 0) + + splitter = DataParallelSplitter(batch_size=FLAGS.batch_size) + + dlrm.maybe_restore_checkpoint(FLAGS.restore_checkpoint_path) + + if FLAGS.mode == 'inference': + inference_benchmark(validation_pipeline, dlrm, timer, splitter, FLAGS) + return + + elif FLAGS.mode == 'eval': + test_auc, test_loss, _ = evaluate(validation_pipeline, dlrm, + timer, auc_thresholds=FLAGS.auc_thresholds, + data_parallel_splitter=splitter) + dist_print(f'Evaluation completed, AUC: {test_auc:.6f}, test_loss: {test_loss:.6f}') + return + + eval_points = compute_eval_points(train_batches=len(train_pipeline), + evals_per_epoch=FLAGS.evals_per_epoch) + + trainer = DlrmTrainer(dlrm, embedding_optimizer=embedding_optimizer, + mlp_optimizer=mlp_optimizer, amp=FLAGS.amp, + lr_scheduler=scheduler) + + best_auc = 0 + train_begin = time.time() + for epoch in range(FLAGS.epochs): + for step, ((numerical_features, categorical_features), labels) in enumerate(train_pipeline): + if step == FLAGS.profiler_start_step and hvd.rank() == FLAGS.profiled_rank: + tf.profiler.experimental.start('logdir') + + if FLAGS.profiler_start_step and step == FLAGS.profiler_start_step + 100 and hvd.rank() == FLAGS.profiled_rank: + tf.profiler.experimental.stop() + + labels = splitter(labels) + if FLAGS.data_parallel_bottom_mlp: + numerical_features = splitter(numerical_features) + + loss = trainer.train_step(numerical_features, categorical_features, labels) + + timer.step_train(loss=loss) + + if FLAGS.max_steps != -1 and step > FLAGS.max_steps: + dist_print(f'Max steps of {FLAGS.max_steps} reached, exiting') + break + + if step in eval_points: + test_auc, test_loss, _ = evaluate(validation_pipeline, dlrm, + timer, FLAGS.auc_thresholds, + data_parallel_splitter=splitter) + dist_print(f'Evaluation completed, AUC: {test_auc:.6f}, test_loss: {test_loss:.6f}') + timer.test_idx = 0 + best_auc = max(best_auc, test_auc) + + elapsed = time.time() - train_begin + dlrm.maybe_save_checkpoint(FLAGS.save_checkpoint_path) + + if hvd.rank() == 0: + dist_print(f'Training run completed, elapsed: {elapsed:.0f} [s]') + results = { + 'throughput': FLAGS.batch_size / timer.mean_train_time(), + 'mean_step_time_ms': timer.mean_train_time() * 1000, + 'auc': best_auc + } + dllogger.log(data=results, step=tuple()) + + +if __name__ == '__main__': + app.run(main) diff --git a/Tensorflow2/Recommendation/DLRM/model.py b/Tensorflow2/Recommendation/DLRM/model.py new file mode 100644 index 00000000..f4c8b0d1 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/model.py @@ -0,0 +1,606 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# author: Tomasz Grel (tgrel@nvidia.com) + + +import tensorflow as tf +from embedding import Embedding, SplitEmbedding +import interaction +import tensorflow.keras.initializers as initializers +import math +import horovod.tensorflow as hvd +from distributed_utils import BroadcastingInitializer +import numpy as np +import time +from utils import dist_print + +try: + from tensorflow_dot_based_interact.python.ops import dot_based_interact_ops +except ImportError: + print('WARNING: Could not import the custom dot-interaction kernels') + + +def scale_grad(grad, factor): + if isinstance(grad, tf.IndexedSlices): + # sparse gradient + grad._values = grad._values * factor + return grad + else: + # dense gradient + return grad * factor + + +class DataParallelSplitter: + def __init__(self, batch_size): + local_batch_size = (batch_size // hvd.size()) + if local_batch_size % 1 != 0: + raise ValueError("Global batch size must be divisible by the number of workers!") + local_batch_size = int(local_batch_size) + + batch_sizes_per_gpu = [local_batch_size] * hvd.size() + indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu))) + self.begin_idx = indices[hvd.rank()] + self.end_idx = indices[hvd.rank() + 1] + + def __call__(self, x): + x = x[self.begin_idx:self.end_idx] + x = tf.cast(x, dtype=tf.float32) + return x + + +class DlrmTrainer: + def __init__(self, dlrm, embedding_optimizer, mlp_optimizer, amp, lr_scheduler): + self.dlrm = dlrm + self.embedding_optimizer = embedding_optimizer + self.mlp_optimizer = mlp_optimizer + self.amp = amp + self.lr_scheduler = lr_scheduler + self.bce = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE, + from_logits=True) + + def _bottom_part_weight_update(self, unscaled_gradients): + bottom_gradients = self.dlrm.extract_bottom_gradients(unscaled_gradients) + + if hvd.size() > 1: + # need to correct for allreduced gradients being averaged and model-parallel ones not + bottom_gradients = [scale_grad(g, 1 / hvd.size()) for g in bottom_gradients] + + if self.mlp_optimizer is self.embedding_optimizer: + self.mlp_optimizer.apply_gradients(zip(bottom_gradients, self.dlrm.bottom_variables)) + else: + bottom_grads_and_vars = list(zip(bottom_gradients, self.dlrm.bottom_variables)) + + embedding_grads_and_vars = [(g,v) for g,v in bottom_grads_and_vars if 'embedding' in v.name] + bottom_mlp_grads_and_vars = [(g,v) for g,v in bottom_grads_and_vars if 'embedding' not in v.name] + + self.mlp_optimizer.apply_gradients(bottom_mlp_grads_and_vars) + self.embedding_optimizer.apply_gradients(embedding_grads_and_vars) + + def _top_part_weight_update(self, unscaled_gradients): + top_gradients = self.dlrm.extract_top_gradients(unscaled_gradients) + + if hvd.size() > 1: + top_gradients = [hvd.allreduce(g, name="top_gradient_{}".format(i), op=hvd.Average, + compression=hvd.compression.NoneCompressor) for i, g in + enumerate(top_gradients)] + + self.mlp_optimizer.apply_gradients(zip(top_gradients, self.dlrm.top_variables)) + + @tf.function + def train_step(self, numerical_features, categorical_features, labels): + self.lr_scheduler() + + with tf.GradientTape() as tape: + predictions = self.dlrm(inputs=(numerical_features, categorical_features), + training=True) + + unscaled_loss = self.bce(labels, predictions) + # tf keras doesn't reduce the loss when using a Custom Training Loop + unscaled_loss = tf.math.reduce_mean(unscaled_loss) + scaled_loss = self.mlp_optimizer.get_scaled_loss(unscaled_loss) if self.amp else unscaled_loss + + scaled_gradients = tape.gradient(scaled_loss, self.dlrm.trainable_variables) + + if self.amp: + unscaled_gradients = self.mlp_optimizer.get_unscaled_gradients(scaled_gradients) + else: + unscaled_gradients = scaled_gradients + + self._bottom_part_weight_update(unscaled_gradients) + self._top_part_weight_update(unscaled_gradients) + + if hvd.size() > 1: + # compute mean loss for all workers for reporting + mean_loss = hvd.allreduce(unscaled_loss, name="mean_loss", op=hvd.Average) + else: + mean_loss = unscaled_loss + + return mean_loss + + +def evaluate(validation_pipeline, dlrm, timer, auc_thresholds, + data_parallel_splitter, max_steps=None): + + if auc_thresholds is not None: + auc_metric = tf.keras.metrics.AUC(num_thresholds=auc_thresholds, + curve='ROC', summation_method='interpolation', + name='my_auc') + else: + auc_metric = None + + bce_op = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE, + from_logits=False) + auc, test_loss = 0, 0 + latencies, all_test_losses = [], [] + distributed = hvd.size() != 1 + iterator = enumerate(validation_pipeline) + while True: + begin = time.time() + + try: + eval_step, ((numerical_features, categorical_features), labels) = next(iterator) + except StopIteration: + break + + if dlrm.data_parallel_bottom_mlp: + numerical_features = data_parallel_splitter(numerical_features) + + if max_steps is not None and eval_step >= max_steps: + break + + y_pred = dlrm((numerical_features, categorical_features), sigmoid=True) + end = time.time() + latency = end - begin + latencies.append(latency) + + if distributed: + y_pred = hvd.allgather(y_pred) + + timer.step_test() + if hvd.rank() == 0 and auc_metric is not None: + auc_metric.update_state(labels, y_pred) + test_loss = bce_op(labels, y_pred) + all_test_losses.append(test_loss) + + if hvd.rank() == 0 and auc_metric is not None: + auc = auc_metric.result().numpy().item() + test_loss = tf.reduce_mean(all_test_losses).numpy().item() + + return auc, test_loss, latencies + + +class Dlrm(tf.keras.Model): + def __init__(self, FLAGS, dataset_metadata, multi_gpu_metadata): + super(Dlrm, self).__init__() + local_table_ids = multi_gpu_metadata.rank_to_categorical_ids[hvd.rank()] + self.table_sizes = [dataset_metadata.categorical_cardinalities[i] for i in local_table_ids] + self.rank_to_feature_count = multi_gpu_metadata.rank_to_feature_count + self.distributed = hvd.size() > 1 + self.batch_size = FLAGS.batch_size + self.num_all_categorical_features = len(dataset_metadata.categorical_cardinalities) + + self.amp = FLAGS.amp + self.dataset_metadata = dataset_metadata + + self.embedding_dim = FLAGS.embedding_dim + + if FLAGS.dot_interaction == 'custom_cuda': + self.interact_op = dot_based_interact_ops.dot_based_interact + elif FLAGS.dot_interaction == 'tensorflow': + self.interact_op = interaction.dot_interact + elif FLAGS.dot_interaction == 'dummy': + self.interact_op = interaction.dummy_dot_interact + else: + raise ValueError(f'Unknown dot-interaction implementation {FLAGS.dot_interaction}') + + self.dummy_embedding = FLAGS.dummy_embedding + + self.experimental_columnwise_split = FLAGS.experimental_columnwise_split + self.data_parallel_bottom_mlp = FLAGS.data_parallel_bottom_mlp + + if self.experimental_columnwise_split: + self.local_embedding_dim = self.embedding_dim // hvd.size() + else: + self.local_embedding_dim = self.embedding_dim + + self.embedding_type = FLAGS.embedding_type + self.embedding_trainable = FLAGS.embedding_trainable + + self.bottom_mlp_dims = [int(d) for d in FLAGS.bottom_mlp_dims] + self.top_mlp_dims = [int(d) for d in FLAGS.top_mlp_dims] + + self.top_mlp_padding = None + self.bottom_mlp_padding = None + + self.variables_partitioned = False + self.running_bottom_mlp = (not self.distributed) or (hvd.rank() == 0) or self.data_parallel_bottom_mlp + + self.num_numerical_features = FLAGS.num_numerical_features + # override in case there's no numerical features in the dataset + if self.num_numerical_features == 0: + self.running_bottom_mlp = False + + if self.running_bottom_mlp: + self._create_bottom_mlp() + self._create_embeddings() + self._create_top_mlp() + + # write embedding checkpoints of 1M rows at a time + self.embedding_checkpoint_batch = 1024 * 1024 + + def _create_bottom_mlp_padding(self, multiple=8): + num_features = self.dataset_metadata.num_numerical_features + pad_to = tf.math.ceil(num_features / multiple) * multiple + pad_to = tf.cast(pad_to, dtype=tf.int32) + padding_features = pad_to - num_features + + batch_size = self.batch_size // hvd.size() if self.data_parallel_bottom_mlp else self.batch_size + + padding_shape = [batch_size, padding_features] + dtype=tf.float16 if self.amp else tf.float32 + self.bottom_mlp_padding = self.add_weight("bottom_mlp_padding", shape=padding_shape, dtype=dtype, + initializer=initializers.Zeros(), trainable=False) + + def _create_top_mlp_padding(self, multiple=8): + num_features = self.num_all_categorical_features + if self.num_numerical_features != 0: + num_features += 1 + num_features = num_features * (num_features - 1) + num_features = num_features // 2 + num_features = num_features + self.embedding_dim + + pad_to = tf.math.ceil(num_features / multiple) * multiple + pad_to = tf.cast(pad_to, dtype=tf.int32) + padding_features = pad_to - num_features + + padding_shape = [self.batch_size // hvd.size(), padding_features] + dtype=tf.float16 if self.amp else tf.float32 + self.top_mlp_padding = self.add_weight("top_mlp_padding", shape=padding_shape, dtype=dtype, + initializer=initializers.Zeros(), trainable=False) + + def _create_bottom_mlp(self): + self._create_bottom_mlp_padding() + self.bottom_mlp_layers = [] + for dim in self.bottom_mlp_dims: + kernel_initializer = initializers.GlorotNormal() + bias_initializer = initializers.RandomNormal(stddev=math.sqrt(1. / dim)) + + if self.data_parallel_bottom_mlp: + kernel_initializer = BroadcastingInitializer(kernel_initializer) + bias_initializer = BroadcastingInitializer(bias_initializer) + + l = tf.keras.layers.Dense(dim, activation='relu', + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer) + self.bottom_mlp_layers.append(l) + + def _create_top_mlp(self): + self._create_top_mlp_padding() + self.top_mlp = [] + for i, dim in enumerate(self.top_mlp_dims): + if i == len(self.top_mlp_dims) - 1: + # final layer + activation = 'linear' + else: + activation = 'relu' + + kernel_initializer = BroadcastingInitializer(initializers.GlorotNormal()) + bias_initializer = BroadcastingInitializer(initializers.RandomNormal(stddev=math.sqrt(1. / dim))) + + l = tf.keras.layers.Dense(dim, activation=activation, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer) + self.top_mlp.append(l) + + def _create_embeddings(self): + self.embedding_layers = [] + if self.embedding_type == 'embedding': + for i, table_size in enumerate(self.table_sizes): + l = Embedding(input_dim=table_size, + output_dim=self.local_embedding_dim, + trainable=self.embedding_trainable) + self.embedding_layers.append(l) + + elif self.embedding_type == 'split_embedding': + for i, table_size in enumerate(self.table_sizes): + l = SplitEmbedding(input_dim=table_size, + output_dim=self.local_embedding_dim, + trainable=self.embedding_trainable) + self.embedding_layers.append(l) + else: + raise ValueError(f'Unknown embedding type {self.embedding_type}') + + def _partition_variables(self): + self.bottom_variables = [v for v in self.trainable_variables if 'bottom_model' in v.name] + self.bottom_variable_indices = [i for i,v in enumerate(self.trainable_variables) if 'bottom_model' in v.name] + + self.top_variables = [v for v in self.trainable_variables if 'bottom_model' not in v.name] + self.top_variable_indices = [i for i, v in enumerate(self.trainable_variables) if 'bottom_model' not in v.name] + self.variables_partitioned = True + + def extract_bottom_gradients(self, all_gradients): + if not self.variables_partitioned: + self._partition_variables() + return [all_gradients[i] for i in self.bottom_variable_indices] + + def extract_top_gradients(self, all_gradients): + if not self.variables_partitioned: + self._partition_variables() + return [all_gradients[i] for i in self.top_variable_indices] + + def force_initialization(self): + if self.running_bottom_mlp: + if self.data_parallel_bottom_mlp: + numerical_features = tf.zeros(shape=[self.batch_size / hvd.size(), + self.dataset_metadata.num_numerical_features]) + else: + numerical_features = tf.zeros(shape=[self.batch_size, + self.dataset_metadata.num_numerical_features]) + else: + numerical_features = None + + categorical_features = [tf.zeros(shape=[self.batch_size, 1], dtype=tf.int32)] * len(self.table_sizes) + self((numerical_features, categorical_features)) + + @tf.function + def call(self, inputs, sigmoid=False): + numerical_features, cat_features = inputs + embedding_outputs = self._call_embeddings(cat_features) + + if self.running_bottom_mlp: + bottom_mlp_out = self._call_bottom_mlp(numerical_features) + else: + bottom_mlp_out = None + + if self.distributed: + if self.experimental_columnwise_split: + interaction_input = self._call_alltoall_experimental_columnwise( + embedding_outputs, + bottom_mlp_out) + else: + interaction_input = self._call_alltoall(embedding_outputs, bottom_mlp_out) + else: + if bottom_mlp_out is not None: + bottom_part_output = tf.concat([bottom_mlp_out] + embedding_outputs, + axis=1) + else: + bottom_part_output = tf.concat(embedding_outputs, axis=1) + + num_categorical_features = len(self.dataset_metadata.categorical_cardinalities) + interaction_input = tf.reshape(bottom_part_output, + [-1, num_categorical_features + 1, + self.embedding_dim]) + + if not self.data_parallel_bottom_mlp: + bottom_mlp_out = interaction_input[:, 0, :] + + x = self.interact_op(interaction_input, tf.squeeze(bottom_mlp_out)) + x = self._call_top_mlp(x) + + if sigmoid: + x = tf.math.sigmoid(x) + return x + + def _call_bottom_mlp(self, numerical_features): + if self.amp: + numerical_features = tf.cast(numerical_features, dtype=tf.float16) + x = tf.concat([numerical_features, self.bottom_mlp_padding], axis=1) + + name_scope = "bottom_mlp" if self.data_parallel_bottom_mlp else "bottom_model" + with tf.name_scope(name_scope): + for l in self.bottom_mlp_layers: + x = l(x) + x = tf.expand_dims(x, axis=1) + bottom_mlp_out = x + return bottom_mlp_out + + def _call_dummy_embeddings(self, cat_features): + batch_size = tf.shape(cat_features)[0] + num_embeddings = tf.shape(cat_features)[1] + dtype = tf.float16 if self.amp else tf.float32 + return [tf.zeros(shape=[batch_size, num_embeddings, self.embedding_dim], dtype=dtype)] + + def _call_embeddings(self, cat_features): + if self.dummy_embedding: + return self._call_dummy_embeddings(cat_features) + + with tf.name_scope("bottom_model"): + embedding_outputs = [] + if self.table_sizes: + for i, l in enumerate(self.embedding_layers): + indices = tf.cast(cat_features[i], tf.int32) + out = l(indices) + embedding_outputs.append(out) + if self.amp: + embedding_outputs = [tf.cast(e, dtype=tf.float16) for e in embedding_outputs] + return embedding_outputs + + def _call_alltoall(self, embedding_outputs, bottom_mlp_out=None): + num_tables = len(self.table_sizes) + if bottom_mlp_out is not None and not self.data_parallel_bottom_mlp: + bottom_part_output = tf.concat([bottom_mlp_out] + embedding_outputs, + axis=1) + num_tables += 1 + else: + bottom_part_output = tf.concat(embedding_outputs, axis=1) + + global_batch = tf.shape(bottom_part_output)[0] + world_size = hvd.size() + local_batch = global_batch // world_size + embedding_dim = self.embedding_dim + + alltoall_input = tf.reshape(bottom_part_output, + shape=[global_batch * num_tables, + embedding_dim]) + + splits = [tf.shape(alltoall_input)[0] // world_size] * world_size + + alltoall_output = hvd.alltoall(tensor=alltoall_input, splits=splits, ignore_name_scope=True) + + vectors_per_worker = [x * local_batch for x in self.rank_to_feature_count] + alltoall_output = tf.split(alltoall_output, + num_or_size_splits=vectors_per_worker, + axis=0) + interaction_input = [tf.reshape(x, shape=[local_batch, -1, embedding_dim]) for x in alltoall_output] + + if self.data_parallel_bottom_mlp: + interaction_input = [bottom_mlp_out] + interaction_input + + interaction_input = tf.concat(interaction_input, axis=1) # shape=[local_batch, num_vectors, vector_dim] + return interaction_input + + def _call_alltoall_experimental_columnwise(self, embedding_outputs, bottom_mlp_out): + bottom_part_output = tf.concat(embedding_outputs, axis=1) + + global_batch = tf.shape(bottom_part_output)[0] + world_size = hvd.size() + local_batch = global_batch // world_size + num_tables = len(self.table_sizes) + + alltoall_input = tf.transpose(bottom_part_output, perm=[0, 2, 1]) + alltoall_input = tf.reshape(alltoall_input, shape=[global_batch * self.local_embedding_dim, + num_tables]) + + splits = [tf.shape(alltoall_input)[0] // world_size] * world_size + + alltoall_output = hvd.alltoall(tensor=alltoall_input, splits=splits, ignore_name_scope=True) + + alltoall_output = tf.split(alltoall_output, + num_or_size_splits=hvd.size(), + axis=0) + interaction_input = [tf.reshape(x, shape=[local_batch, + self.local_embedding_dim, num_tables]) for x in alltoall_output] + + interaction_input = tf.concat(interaction_input, axis=1) # shape=[local_batch, vector_dim, num_tables] + interaction_input = tf.transpose(interaction_input, + perm=[0, 2, 1]) # shape=[local_batch, num_tables, vector_dim] + + if self.running_bottom_mlp: + interaction_input = tf.concat([bottom_mlp_out, + interaction_input], + axis=1) # shape=[local_batch, num_tables + 1, vector_dim] + return interaction_input + + def _call_top_mlp(self, x): + if self.interact_op != 'custom_cuda': + x = tf.concat([x, self.top_mlp_padding], axis=1) + + with tf.name_scope("top_model"): + for i, l in enumerate(self.top_mlp): + x = l(x) + x = tf.cast(x, dtype=tf.float32) + return x + + @staticmethod + def _get_variable_path(checkpoint_path, v, i=0): + checkpoint_path = checkpoint_path + f'_rank_{hvd.rank()}' + name = v.name.replace('/', '_').replace(':', '_') + return checkpoint_path + '_' + name + f'_{i}' + '.npy' + + def maybe_save_checkpoint(self, checkpoint_path): + if checkpoint_path is None: + return + + dist_print('Saving a checkpoint...') + for v in self.trainable_variables: + filename = self._get_variable_path(checkpoint_path, v) + if 'embedding' not in v.name: + np.save(arr=v.numpy(), file=filename) + continue + print(f'saving embedding {v.name}') + chunks = math.ceil(v.shape[0] / self.embedding_checkpoint_batch) + for i in range(chunks): + filename = self._get_variable_path(checkpoint_path, v, i) + end = min((i + 1) * self.embedding_checkpoint_batch, v.shape[0]) + + indices = tf.range(start=i * self.embedding_checkpoint_batch, + limit=end, + dtype=tf.int32) + + arr = tf.gather(params=v, indices=indices, axis=0) + arr = arr.numpy() + np.save(arr=arr, file=filename) + + dist_print('Saved a checkpoint to ', checkpoint_path) + + def maybe_restore_checkpoint(self, checkpoint_path): + if checkpoint_path is None: + return + + dist_print('Restoring a checkpoint...') + self.force_initialization() + + for v in self.trainable_variables: + filename = self._get_variable_path(checkpoint_path, v) + if 'embedding' not in v.name: + numpy_var = np.load(file=filename) + v.assign(numpy_var) + continue + + chunks = math.ceil(v.shape[0] / self.embedding_checkpoint_batch) + for i in range(chunks): + filename = self._get_variable_path(checkpoint_path, v, i) + start = i * self.embedding_checkpoint_batch + numpy_arr = np.load(file=filename) + indices = tf.range(start=start, + limit=start + numpy_arr.shape[0], + dtype=tf.int32) + update = tf.IndexedSlices(values=numpy_arr, indices=indices, dense_shape=v.shape) + v.scatter_update(sparse_delta=update) + + dist_print('Restored a checkpoint from', checkpoint_path) + + +# dummy model for profiling and debugging +class DummyDlrm(tf.keras.Model): + def __init__(self, FLAGS, dataset_metadata): + super(DummyDlrm, self).__init__() + self.dense = tf.keras.layers.Dense(1, activation='sigmoid', + kernel_initializer='glorot_normal', + bias_initializer=initializers.RandomNormal(stddev=math.sqrt(1. / 1)) + ) + self.dataset_metadata = dataset_metadata + self.top_variables = [v for v in self.trainable_variables if 'bottom_model' not in v.name] + self.variables_partitioned = False + self.batch_size = FLAGS.batch_size + self.data_parallel_bottom_mlp = FLAGS.data_parallel_bottom_mlp + + def call(self, inputs, sigmoid=False): + x = tf.zeros(shape=[self.batch_size // hvd.size(), + self.dataset_metadata.num_numerical_features], + dtype=tf.float32) + x = self.dense(x) + x = tf.cast(x, dtype=tf.float32) + if sigmoid: + x = tf.math.sigmoid(x) + return x + + def _partition_variables(self): + self.bottom_variables = [v for v in self.trainable_variables if 'bottom_model' in v.name] + self.bottom_variable_indices = [i for i,v in enumerate(self.trainable_variables) if 'bottom_model' in v.name] + + self.top_variables = [v for v in self.trainable_variables if 'bottom_model' not in v.name] + self.top_variable_indices = [i for i, v in enumerate(self.trainable_variables) if 'bottom_model' not in v.name] + self.variables_partitioned = True + + def extract_bottom_gradients(self, all_gradients): + if not self.variables_partitioned: + self._partition_variables() + return [all_gradients[i] for i in self.bottom_variable_indices] + + def extract_top_gradients(self, all_gradients): + if not self.variables_partitioned: + self._partition_variables() + return [all_gradients[i] for i in self.top_variable_indices] diff --git a/Tensorflow2/Recommendation/DLRM/preproc/dgx2_config.sh b/Tensorflow2/Recommendation/DLRM/preproc/dgx2_config.sh new file mode 100755 index 00000000..ba064cf3 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/preproc/dgx2_config.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# the environment variables to run spark job +# should modify below environment variables + +# below numbers should be adjusted according to the resource of your running environment +# set the total number of CPU cores, spark can use +export TOTAL_CORES=80 + +# set the number of executors +export NUM_EXECUTORS=16 + +# the cores for each executor, it'll be calculated +export NUM_EXECUTOR_CORES=$((${TOTAL_CORES}/${NUM_EXECUTORS})) + +# unit: GB, set the max memory you want to use +export TOTAL_MEMORY=800 + +# unit: GB, set the memory for driver +export DRIVER_MEMORY=32 + +# the memory per executor +export EXECUTOR_MEMORY=$(((${TOTAL_MEMORY}-${DRIVER_MEMORY})/${NUM_EXECUTORS}-16)) diff --git a/Tensorflow2/Recommendation/DLRM/preproc/gpu/get_gpu_resources.sh b/Tensorflow2/Recommendation/DLRM/preproc/gpu/get_gpu_resources.sh new file mode 100644 index 00000000..b6411dbf --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/preproc/gpu/get_gpu_resources.sh @@ -0,0 +1,4 @@ +#! /bin/bash + +ADDRS=`nvidia-smi --query-gpu=index --format=csv,noheader | sed -e ':a' -e 'N' -e'$!ba' -e 's/\n/","/g'` +echo {\"name\": \"gpu\", \"addresses\":[\"$ADDRS\"]} diff --git a/Tensorflow2/Recommendation/DLRM/preproc/gpu/spark-defaults.conf b/Tensorflow2/Recommendation/DLRM/preproc/gpu/spark-defaults.conf new file mode 100644 index 00000000..0e204eb7 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/preproc/gpu/spark-defaults.conf @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Default system properties included when running spark-submit. +# This is useful for setting default environmental settings. + +# Example: +# spark.master spark://master:7077 +# spark.eventLog.enabled true +# spark.eventLog.dir hdfs://namenode:8021/directory +# spark.serializer org.apache.spark.serializer.KryoSerializer +# spark.driver.memory 5g +# spark.executor.extraJavaOptions -XX:+PrintGCDetails -Dkey=value -Dnumbers="one two three" + +spark.worker.resource.gpu.amount 16 +spark.worker.resource.gpu.discoveryScript /opt/spark/conf/get_gpu_resources.sh diff --git a/Tensorflow2/Recommendation/DLRM/preproc/parquet_to_binary.py b/Tensorflow2/Recommendation/DLRM/preproc/parquet_to_binary.py new file mode 100644 index 00000000..65f734fc --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/preproc/parquet_to_binary.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import pandas as pd +import os +from joblib import Parallel, delayed +import glob +import argparse +import tqdm +import subprocess + +def process_file(f, dst): + + all_columns_sorted = [f'_c{i}' for i in range(0, 40)] + + data = pd.read_parquet(f) + data = data[all_columns_sorted] + + dense_columns = [f'_c{i}' for i in range(1, 14)] + data[dense_columns] = data[dense_columns].astype(np.float32) + + data = data.to_records(index=False) + data = data.tobytes() + + dst_file = dst + '/' + f.split('/')[-1] + '.bin' + with open(dst_file, 'wb') as dst_fd: + dst_fd.write(data) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--src_dir', type=str) + parser.add_argument('--intermediate_dir', type=str) + parser.add_argument('--dst_dir', type=str) + parser.add_argument('--parallel_jobs', default=40, type=int) + args = parser.parse_args() + + print('Processing train files...') + train_src_files = glob.glob(args.src_dir + '/train/*.parquet') + train_intermediate_dir = args.intermediate_dir + '/train' + os.makedirs(train_intermediate_dir, exist_ok=True) + + Parallel(n_jobs=args.parallel_jobs)(delayed(process_file)(f, train_intermediate_dir) for f in tqdm.tqdm(train_src_files)) + + print('Train files conversion done') + + print('Processing test files...') + test_src_files = glob.glob(args.src_dir + '/test/*.parquet') + test_intermediate_dir = args.intermediate_dir + '/test' + os.makedirs(test_intermediate_dir, exist_ok=True) + + Parallel(n_jobs=args.parallel_jobs)(delayed(process_file)(f, test_intermediate_dir) for f in tqdm.tqdm(test_src_files)) + print('Test files conversion done') + + print('Processing validation files...') + valid_src_files = glob.glob(args.src_dir + '/validation/*.parquet') + valid_intermediate_dir = args.intermediate_dir + '/valid' + os.makedirs(valid_intermediate_dir, exist_ok=True) + + Parallel(n_jobs=args.parallel_jobs)(delayed(process_file)(f, valid_intermediate_dir) for f in tqdm.tqdm(valid_src_files)) + print('Validation files conversion done') + + os.makedirs(args.dst_dir, exist_ok=True) + + print('Concatenating train files') + os.system(f'cat {train_intermediate_dir}/*.bin > {args.dst_dir}/train_data.bin') + + print('Concatenating test files') + os.system(f'cat {test_intermediate_dir}/*.bin > {args.dst_dir}/test_data.bin') + + print('Concatenating validation files') + os.system(f'cat {valid_intermediate_dir}/*.bin > {args.dst_dir}/val_data.bin') + print('Done') + + +if __name__ == '__main__': + main() diff --git a/Tensorflow2/Recommendation/DLRM/preproc/prepare_dataset.sh b/Tensorflow2/Recommendation/DLRM/preproc/prepare_dataset.sh new file mode 100755 index 00000000..ad4c0dd4 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/preproc/prepare_dataset.sh @@ -0,0 +1,79 @@ +#! /bin/bash + +# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Examples: +# to run on a DGX2 with a frequency limit of 3 (will need 8xV100-32GB to fit the model in GPU memory) +# ./prepare_dataset.sh DGX2 3 +# +# to run on a DGX2 with a frequency limit of 15 (should fit on a single V100-32GB): +# ./prepare_dataset.sh DGX2 15 +# +# to run on CPU with a frequency limit of 15: +# ./prepare_dataset.sh CPU 15 + + + +set -e +set -x + +ls -ltrash + +download_dir=${download_dir:-'/data/dlrm/criteo'} +./verify_criteo_downloaded.sh ${download_dir} + +spark_output_path=${spark_output_path:-'/data/dlrm/spark/output'} + + +if [ -f ${spark_output_path}/train/_SUCCESS ] \ + && [ -f ${spark_output_path}/validation/_SUCCESS ] \ + && [ -f ${spark_output_path}/test/_SUCCESS ]; then + + echo "Spark preprocessing already carried out" +else + echo "Performing spark preprocessing" + ./run_spark.sh $1 ${download_dir} ${spark_output_path} $2 +fi + +conversion_intermediate_dir=${conversion_intermediate_dir:-'/data/dlrm/intermediate_binary'} +final_output_dir=${final_output_dir:-'/data/dlrm/binary_dataset'} + + +if [ -d ${final_output_dir}/train ] \ + && [ -d ${final_output_dir}/val ] \ + && [ -d ${final_output_dir}/test ] \ + && [ -f ${final_output_dir}/model_sizes.json ]; then + + echo "Final conversion already done" +else + echo "Performing final conversion to a custom data format" + python parquet_to_binary.py --parallel_jobs 40 --src_dir ${spark_output_path} \ + --intermediate_dir ${conversion_intermediate_dir} \ + --dst_dir ${final_output_dir} + + cp "${spark_output_path}/model_size.json" "${final_output_dir}/model_size.json" + + python split_dataset.py --dataset "${final_output_dir}" --output "${final_output_dir}/split" + rm ${final_output_dir}/train_data.bin + rm ${final_output_dir}/val_data.bin + rm ${final_output_dir}/test_data.bin + + mv ${final_output_dir}/split/* ${final_output_dir} + rm -rf ${final_output_dir}/split +fi + +echo "Done preprocessing the Criteo Kaggle Dataset" +echo "You can now start the training with: " +echo "python -m dlrm.scripts.main --mode train --dataset ${final_output_dir}" diff --git a/Tensorflow2/Recommendation/DLRM/preproc/run_spark.sh b/Tensorflow2/Recommendation/DLRM/preproc/run_spark.sh new file mode 100755 index 00000000..256bb8aa --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/preproc/run_spark.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +######################################################################### +# File Name: run_spark.sh + + +echo "Input mode option: $1" +if [ "$1" = "CPU" ] +then + echo "Run with CPU."; + shift + ./run_spark_cpu.sh ${@} +elif [ "$1" = "DGX2" ] +then + echo "Run with GPU."; + shift + ./run_spark_gpu.sh ${@} DGX2 +else + echo "Please choose mode (CPU/DGX2)."; +fi diff --git a/Tensorflow2/Recommendation/DLRM/preproc/run_spark_cpu.sh b/Tensorflow2/Recommendation/DLRM/preproc/run_spark_cpu.sh new file mode 100755 index 00000000..c7427ff9 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/preproc/run_spark_cpu.sh @@ -0,0 +1,162 @@ +# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +######################################################################### +# File Name: run_spark_cpu.sh + +#!/bin/bash + +set -e + +# the environment variables to run spark job +# should modify below environment variables + +# the data path including 1TB criteo data, day_0, day_1, ... +export INPUT_PATH=${1:-'/data/dlrm/criteo'} + +# the output path, use for generating the dictionary and the final dataset +# the output folder should have more than 300GB +export OUTPUT_PATH=${2:-'/data/dlrm/spark/output'} + +export FREQUENCY_LIMIT=${3:-'15'} + +# spark local dir should have about 3TB +# the temporary path used for spark shuffle write +export SPARK_LOCAL_DIRS='/data/dlrm/spark/tmp' + +# below numbers should be adjusted according to the resource of your running environment +# set the total number of CPU cores, spark can use +export TOTAL_CORES=80 + +# set the number of executors +export NUM_EXECUTORS=8 + +# the cores for each executor, it'll be calculated +export NUM_EXECUTOR_CORES=$((${TOTAL_CORES}/${NUM_EXECUTORS})) + +# unit: GB, set the max memory you want to use +export TOTAL_MEMORY=800 + +# unit: GB, set the memory for driver +export DRIVER_MEMORY=32 + +# the memory per executor +export EXECUTOR_MEMORY=$(((${TOTAL_MEMORY}-${DRIVER_MEMORY})/${NUM_EXECUTORS})) + +OPTS="--frequency_limit $FREQUENCY_LIMIT" + +export SPARK_HOME=/opt/spark +export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 +export PATH=$SPARK_HOME/bin:$SPARK_HOME/sbin:$PATH + +# we use spark standalone to run the job +export MASTER=spark://$HOSTNAME:7077 + +echo "Starting spark standalone" +start-master.sh +start-slave.sh $MASTER + +echo "Generating the dictionary..." +spark-submit --master $MASTER \ + --driver-memory "${DRIVER_MEMORY}G" \ + --executor-cores $NUM_EXECUTOR_CORES \ + --executor-memory "${EXECUTOR_MEMORY}G" \ + --conf spark.cores.max=$TOTAL_CORES \ + --conf spark.task.cpus=1 \ + --conf spark.sql.files.maxPartitionBytes=1073741824 \ + --conf spark.sql.shuffle.partitions=600 \ + --conf spark.driver.maxResultSize=2G \ + --conf spark.locality.wait=0s \ + --conf spark.network.timeout=1800s \ + spark_data_utils.py --mode generate_models \ + $OPTS \ + --input_folder $INPUT_PATH \ + --days 0-23 \ + --model_folder $OUTPUT_PATH/models \ + --write_mode overwrite --low_mem 2>&1 | tee submit_dict_log.txt + +echo "Transforming the train data from day_0 to day_22..." +spark-submit --master $MASTER \ + --driver-memory "${DRIVER_MEMORY}G" \ + --executor-cores $NUM_EXECUTOR_CORES \ + --executor-memory "${EXECUTOR_MEMORY}G" \ + --conf spark.cores.max=$TOTAL_CORES \ + --conf spark.task.cpus=1 \ + --conf spark.sql.files.maxPartitionBytes=1073741824 \ + --conf spark.sql.shuffle.partitions=600 \ + --conf spark.driver.maxResultSize=2G \ + --conf spark.locality.wait=0s \ + --conf spark.network.timeout=1800s \ + spark_data_utils.py --mode transform \ + --input_folder $INPUT_PATH \ + --days 0-22 \ + --output_folder $OUTPUT_PATH/train \ + --model_size_file $OUTPUT_PATH/model_size.json \ + --model_folder $OUTPUT_PATH/models \ + --write_mode overwrite --low_mem 2>&1 | tee submit_train_log.txt + +echo "Splitting the last day into 2 parts of test and validation..." +last_day=$INPUT_PATH/day_23 +temp_test=$OUTPUT_PATH/temp/test +temp_validation=$OUTPUT_PATH/temp/validation +mkdir -p $temp_test $temp_validation + +lines=`wc -l $last_day | awk '{print $1}'` +former=$((lines / 2)) +latter=$((lines - former)) + +head -n $former $last_day > $temp_test/day_23 +tail -n $latter $last_day > $temp_validation/day_23 + +echo "Transforming the test data in day_23..." +spark-submit --master $MASTER \ + --driver-memory "${DRIVER_MEMORY}G" \ + --executor-cores $NUM_EXECUTOR_CORES \ + --executor-memory "${EXECUTOR_MEMORY}G" \ + --conf spark.cores.max=$TOTAL_CORES \ + --conf spark.task.cpus=1 \ + --conf spark.sql.files.maxPartitionBytes=1073741824 \ + --conf spark.sql.shuffle.partitions=30 \ + --conf spark.driver.maxResultSize=2G \ + --conf spark.locality.wait=0s \ + --conf spark.network.timeout=1800s \ + spark_data_utils.py --mode transform \ + --input_folder $temp_test \ + --days 23-23 \ + --output_folder $OUTPUT_PATH/test \ + --output_ordering input \ + --model_folder $OUTPUT_PATH/models \ + --write_mode overwrite --low_mem 2>&1 | tee submit_test_log.txt + +echo "Transforming the validation data in day_23..." +spark-submit --master $MASTER \ + --driver-memory "${DRIVER_MEMORY}G" \ + --executor-cores $NUM_EXECUTOR_CORES \ + --executor-memory "${EXECUTOR_MEMORY}G" \ + --conf spark.cores.max=$TOTAL_CORES \ + --conf spark.task.cpus=1 \ + --conf spark.sql.files.maxPartitionBytes=1073741824 \ + --conf spark.sql.shuffle.partitions=30 \ + --conf spark.driver.maxResultSize=2G \ + --conf spark.locality.wait=0s \ + --conf spark.network.timeout=1800s \ + spark_data_utils.py --mode transform \ + --input_folder $temp_validation \ + --days 23-23 \ + --output_folder $OUTPUT_PATH/validation \ + --output_ordering input \ + --model_folder $OUTPUT_PATH/models \ + --write_mode overwrite --low_mem 2>&1 | tee submit_validation_log.txt + +rm -r $temp_test $temp_validation diff --git a/Tensorflow2/Recommendation/DLRM/preproc/run_spark_gpu.sh b/Tensorflow2/Recommendation/DLRM/preproc/run_spark_gpu.sh new file mode 100755 index 00000000..a9fd734a --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/preproc/run_spark_gpu.sh @@ -0,0 +1,195 @@ +#!/bin/bash + +# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +######################################################################### +# File Name: run_spark_gpu.sh + +set -e + +# the data path including 1TB criteo data, day_0, day_1, ... +export INPUT_PATH=${1:-'/data/dlrm/criteo'} + +# the output path, use for generating the dictionary and the final dataset +# the output folder should have more than 300GB +export OUTPUT_PATH=${2:-'/data/dlrm/spark/output'} + +export FREQUENCY_LIMIT=${3:-'15'} + +export HARDWARE_PLATFORM=${4:-'DGX2'} + +# spark local dir should have about 3TB +# the temporary path used for spark shuffle write +export SPARK_LOCAL_DIRS='/data/dlrm/spark/tmp' + +if [[ $HARDWARE_PLATFORM == DGX2 ]]; then + source dgx2_config.sh +else + echo "Unknown hardware platform ${HARDWARE_PLATFORM}" + exit 1 +fi + +OPTS="--frequency_limit $FREQUENCY_LIMIT" + +export SPARK_HOME=/opt/spark +export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 +export PATH=$SPARK_HOME/bin:$SPARK_HOME/sbin:$PATH + +# we use spark standalone to run the job +export MASTER=spark://$HOSTNAME:7077 + +echo "Starting spark standalone" +start-master.sh +start-slave.sh $MASTER + +echo "Generating the dictionary..." +spark-submit --master $MASTER \ + --driver-memory "${DRIVER_MEMORY}G" \ + --executor-cores $NUM_EXECUTOR_CORES \ + --executor-memory "${EXECUTOR_MEMORY}G" \ + --conf spark.cores.max=$TOTAL_CORES \ + --conf spark.task.cpus=1 \ + --conf spark.sql.files.maxPartitionBytes=1073741824 \ + --conf spark.sql.shuffle.partitions=600 \ + --conf spark.driver.maxResultSize=2G \ + --conf spark.locality.wait=0s \ + --conf spark.network.timeout=1800s \ + --conf spark.task.resource.gpu.amount=0.01 \ + --conf spark.executor.resource.gpu.amount=1 \ + --conf spark.plugins=com.nvidia.spark.SQLPlugin \ + --conf spark.rapids.sql.concurrentGpuTasks=2 \ + --conf spark.rapids.sql.reader.batchSizeRows=4000000 \ + --conf spark.rapids.memory.pinnedPool.size=16g \ + --conf spark.rapids.sql.explain=ALL \ + --conf spark.sql.autoBroadcastJoinThreshold=1GB \ + --conf spark.rapids.sql.incompatibleOps.enabled=true \ + --conf spark.driver.maxResultSize=2G \ + --conf spark.executor.extraJavaOptions="-Dcom.nvidia.cudf.prefer-pinned=true\ -Djava.io.tmpdir=$SPARK_LOCAL_DIRS" \ + spark_data_utils.py --mode generate_models \ + $OPTS \ + --input_folder $INPUT_PATH \ + --days 0-23 \ + --model_folder $OUTPUT_PATH/models \ + --write_mode overwrite --low_mem 2>&1 | tee submit_dict_log.txt + +echo "Transforming the train data from day_0 to day_22..." +spark-submit --master $MASTER \ + --driver-memory "${DRIVER_MEMORY}G" \ + --executor-cores $NUM_EXECUTOR_CORES \ + --executor-memory "${EXECUTOR_MEMORY}G" \ + --conf spark.cores.max=$TOTAL_CORES \ + --conf spark.task.cpus=3 \ + --conf spark.sql.files.maxPartitionBytes=1073741824 \ + --conf spark.sql.shuffle.partitions=600 \ + --conf spark.driver.maxResultSize=2G \ + --conf spark.locality.wait=0s \ + --conf spark.network.timeout=1800s \ + --conf spark.task.resource.gpu.amount=0.01 \ + --conf spark.executor.resource.gpu.amount=1 \ + --conf spark.plugins=com.nvidia.spark.SQLPlugin \ + --conf spark.rapids.sql.concurrentGpuTasks=2 \ + --conf spark.rapids.sql.reader.batchSizeRows=4000000 \ + --conf spark.rapids.memory.pinnedPool.size=16g \ + --conf spark.rapids.sql.explain=ALL \ + --conf spark.sql.autoBroadcastJoinThreshold=1GB \ + --conf spark.rapids.sql.incompatibleOps.enabled=true \ + --conf spark.driver.maxResultSize=2G \ + --conf spark.executor.extraJavaOptions="-Dcom.nvidia.cudf.prefer-pinned=true\ -Djava.io.tmpdir=$SPARK_LOCAL_DIRS" \ + spark_data_utils.py --mode transform \ + --input_folder $INPUT_PATH \ + --days 0-22 \ + --output_folder $OUTPUT_PATH/train \ + --model_size_file $OUTPUT_PATH/model_size.json \ + --model_folder $OUTPUT_PATH/models \ + --write_mode overwrite --low_mem 2>&1 | tee submit_train_log.txt + +echo "Splitting the last day into 2 parts of test and validation..." +last_day=$INPUT_PATH/day_23 +temp_test=$OUTPUT_PATH/temp/test +temp_validation=$OUTPUT_PATH/temp/validation +mkdir -p $temp_test $temp_validation + +lines=`wc -l $last_day | awk '{print $1}'` +former=$((lines / 2)) +latter=$((lines - former)) + +head -n $former $last_day > $temp_test/day_23 +tail -n $latter $last_day > $temp_validation/day_23 + +echo "Transforming the test data in day_23..." +spark-submit --master $MASTER \ + --driver-memory "${DRIVER_MEMORY}G" \ + --executor-cores $NUM_EXECUTOR_CORES \ + --executor-memory "${EXECUTOR_MEMORY}G" \ + --conf spark.cores.max=$TOTAL_CORES \ + --conf spark.task.cpus=1 \ + --conf spark.sql.files.maxPartitionBytes=1073741824 \ + --conf spark.sql.shuffle.partitions=30 \ + --conf spark.driver.maxResultSize=2G \ + --conf spark.locality.wait=0s \ + --conf spark.network.timeout=1800s \ + --conf spark.task.resource.gpu.amount=0.01 \ + --conf spark.executor.resource.gpu.amount=1 \ + --conf spark.plugins=com.nvidia.spark.SQLPlugin \ + --conf spark.rapids.sql.concurrentGpuTasks=2 \ + --conf spark.rapids.sql.reader.batchSizeRows=4000000 \ + --conf spark.rapids.memory.pinnedPool.size=16g \ + --conf spark.rapids.sql.explain=ALL \ + --conf spark.sql.autoBroadcastJoinThreshold=1GB \ + --conf spark.rapids.sql.incompatibleOps.enabled=true \ + --conf spark.driver.maxResultSize=2G \ + --conf spark.executor.extraJavaOptions="-Dcom.nvidia.cudf.prefer-pinned=true\ -Djava.io.tmpdir=$SPARK_LOCAL_DIRS" \ + spark_data_utils.py --mode transform \ + --input_folder $temp_test \ + --days 23-23 \ + --output_folder $OUTPUT_PATH/test \ + --output_ordering input \ + --model_folder $OUTPUT_PATH/models \ + --write_mode overwrite --low_mem 2>&1 | tee submit_test_log.txt + +echo "Transforming the validation data in day_23..." +spark-submit --master $MASTER \ + --driver-memory "${DRIVER_MEMORY}G" \ + --executor-cores $NUM_EXECUTOR_CORES \ + --executor-memory "${EXECUTOR_MEMORY}G" \ + --conf spark.cores.max=$TOTAL_CORES \ + --conf spark.task.cpus=1 \ + --conf spark.sql.files.maxPartitionBytes=1073741824 \ + --conf spark.sql.shuffle.partitions=30 \ + --conf spark.driver.maxResultSize=2G \ + --conf spark.locality.wait=0s \ + --conf spark.network.timeout=1800s \ + --conf spark.task.resource.gpu.amount=0.01 \ + --conf spark.executor.resource.gpu.amount=1 \ + --conf spark.plugins=com.nvidia.spark.SQLPlugin \ + --conf spark.rapids.sql.concurrentGpuTasks=2 \ + --conf spark.rapids.sql.reader.batchSizeRows=4000000 \ + --conf spark.rapids.memory.pinnedPool.size=16g \ + --conf spark.rapids.sql.explain=ALL \ + --conf spark.sql.autoBroadcastJoinThreshold=1GB \ + --conf spark.rapids.sql.incompatibleOps.enabled=true \ + --conf spark.driver.maxResultSize=2G \ + --conf spark.executor.extraJavaOptions="-Dcom.nvidia.cudf.prefer-pinned=true\ -Djava.io.tmpdir=$SPARK_LOCAL_DIRS" \ + spark_data_utils.py --mode transform \ + --input_folder $temp_validation \ + --days 23-23 \ + --output_folder $OUTPUT_PATH/validation \ + --output_ordering input \ + --model_folder $OUTPUT_PATH/models \ + --write_mode overwrite --low_mem 2>&1 | tee submit_validation_log.txt + +rm -r $temp_test $temp_validation +stop-master.sh +stop-slave.sh diff --git a/Tensorflow2/Recommendation/DLRM/preproc/spark_data_utils.py b/Tensorflow2/Recommendation/DLRM/preproc/spark_data_utils.py new file mode 100644 index 00000000..7f1be2c9 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/preproc/spark_data_utils.py @@ -0,0 +1,507 @@ +# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import os +import sys + +from argparse import ArgumentParser +from collections import OrderedDict +from contextlib import contextmanager +from operator import itemgetter +from time import time + +from pyspark import broadcast +from pyspark.sql import Row, SparkSession, Window +from pyspark.sql.functions import * +from pyspark.sql.types import * + + +LABEL_COL = 0 +INT_COLS = list(range(1, 14)) +CAT_COLS = list(range(14, 40)) + + +def get_column_counts_with_frequency_limit(df, frequency_limit = None): + cols = ['_c%d' % i for i in CAT_COLS] + df = (df + .select(posexplode(array(*cols))) + .withColumnRenamed('pos', 'column_id') + .withColumnRenamed('col', 'data') + .filter('data is not null') + .groupBy('column_id', 'data') + .count()) + + if frequency_limit: + frequency_limit = frequency_limit.split(",") + exclude = [] + default_limit = None + for fl in frequency_limit: + frequency_pair = fl.split(":") + if len(frequency_pair) == 1: + default_limit = int(frequency_pair[0]) + elif len(frequency_pair) == 2: + df = df.filter((col('column_id') != int(frequency_pair[0]) - CAT_COLS[0]) | (col('count') >= int(frequency_pair[1]))) + exclude.append(int(frequency_pair[0])) + if default_limit: + remain = [x - CAT_COLS[0] for x in CAT_COLS if x not in exclude] + df = df.filter((~col('column_id').isin(remain)) | (col('count') >= default_limit)) + # for comparing isin and separate filter + # for i in remain: + # df = df.filter((col('column_id') != i - CAT_COLS[0]) | (col('count') >= default_limit)) + return df + + +def assign_id_with_window(df): + windowed = Window.partitionBy('column_id').orderBy(desc('count')) + return (df + .withColumn('id', row_number().over(windowed)) + .withColumnRenamed('count', 'model_count')) + + +def assign_low_mem_partial_ids(df): + # To avoid some scaling issues with a simple window operation, we use a more complex method + # to compute the same thing, but in a more distributed spark specific way + df = df.orderBy(asc('column_id'), desc('count')) + # The monotonically_increasing_id is the partition id in the top 31 bits and the rest + # is an increasing count of the rows within that partition. So we split it into two parts, + # the partion id part_id and the count mono_id + df = df.withColumn('part_id', spark_partition_id()) + return df.withColumn('mono_id', monotonically_increasing_id() - shiftLeft(col('part_id'), 33)) + + +def assign_low_mem_final_ids(df): + # Now we can find the minimum and maximum mono_ids within a given column/partition pair + sub_model = df.groupBy('column_id', 'part_id').agg(max('mono_id').alias('top'), min('mono_id').alias('bottom')) + sub_model = sub_model.withColumn('diff', col('top') - col('bottom') + 1) + sub_model = sub_model.drop('top') + # This window function is over aggregated column/partition pair table. It will do a running sum of the rows + # within that column + windowed = Window.partitionBy('column_id').orderBy('part_id').rowsBetween(Window.unboundedPreceding, -1) + sub_model = sub_model.withColumn('running_sum', sum('diff').over(windowed)).na.fill(0, ["running_sum"]) + + joined = df.withColumnRenamed('column_id', 'i_column_id') + joined = joined.withColumnRenamed('part_id', 'i_part_id') + joined = joined.withColumnRenamed('count', 'model_count') + + # Then we can join the original input with the pair it is a part of + joined = joined.join(sub_model, (col('i_column_id') == col('column_id')) & (col('part_id') == col('i_part_id'))) + + # So with all that we can subtract bottom from mono_id makeing it start at 0 for each partition + # and then add in the running_sum so the id is contiguous and unique for the entire column. + 1 to make it match the 1 based indexing + # for row_number + ret = joined.select(col('column_id'), + col('data'), + (col('mono_id') - col('bottom') + col('running_sum') + 1).cast(IntegerType()).alias('id'), + col('model_count')) + return ret + + +def get_column_models(combined_model): + for i in CAT_COLS: + model = (combined_model + .filter('column_id == %d' % (i - CAT_COLS[0])) + .drop('column_id')) + yield i, model + + +def col_of_rand_long(): + return (rand() * (1 << 52)).cast(LongType()) + +def skewed_join(df, model, col_name, cutoff): + # Most versions of spark don't have a good way + # to deal with a skewed join out of the box. + # Some do and if you want to replace this with + # one of those that would be great. + + # Because we have statistics about the skewedness + # that we can used we divide the model up into two parts + # one part is the highly skewed part and we do a + # broadcast join for that part, but keep the result in + # a separate column + b_model = broadcast(model.filter(col('model_count') >= cutoff) + .withColumnRenamed('data', col_name) + .drop('model_count')) + + df = (df + .join(b_model, col_name, how='left') + .withColumnRenamed('id', 'id_tmp')) + + # We also need to spread the skewed data that matched + # evenly. We will use a source of randomness for this + # but use a -1 for anything that still needs to be matched + if 'ordinal' in df.columns: + rand_column = col('ordinal') + else: + rand_column = col_of_rand_long() + + df = df.withColumn('join_rand', + # null values are not in the model, they are filtered out + # but can be a source of skewedness so include them in + # the even distribution + when(col('id_tmp').isNotNull() | col(col_name).isNull(), rand_column) + .otherwise(lit(-1))) + + # Null out the string data that already matched to save memory + df = df.withColumn(col_name, + when(col('id_tmp').isNotNull(), None) + .otherwise(col(col_name))) + + # Now do the second join, which will be a non broadcast join. + # Sadly spark is too smart for its own good and will optimize out + # joining on a column it knows will always be a constant value. + # So we have to make a convoluted version of assigning a -1 to the + # randomness column for the model itself to work around that. + nb_model = (model + .withColumn('join_rand', when(col('model_count') < cutoff, lit(-1)).otherwise(lit(-2))) + .filter(col('model_count') < cutoff) + .withColumnRenamed('data', col_name) + .drop('model_count')) + + df = (df + .join(nb_model, ['join_rand', col_name], how='left') + .drop(col_name, 'join_rand') + # Pick either join result as an answer + .withColumn(col_name, coalesce(col('id'), col('id_tmp'))) + .drop('id', 'id_tmp')) + + return df + + +def apply_models(df, models, broadcast_model = False, skew_broadcast_pct = 1.0): + # sort the models so broadcast joins come first. This is + # so we reduce the amount of shuffle data sooner than later + # If we parsed the string hex values to ints early on this would + # not make a difference. + models = sorted(models, key=itemgetter(3), reverse=True) + for i, model, original_rows, would_broadcast in models: + col_name = '_c%d' % i + if not (would_broadcast or broadcast_model): + # The data is highly skewed so we need to offset that + cutoff = int(original_rows * skew_broadcast_pct/100.0) + df = skewed_join(df, model, col_name, cutoff) + else: + # broadcast joins can handle skewed data so no need to + # do anything special + model = (model.drop('model_count') + .withColumnRenamed('data', col_name)) + model = broadcast(model) if broadcast_model else model + df = (df + .join(model, col_name, how='left') + .drop(col_name) + .withColumnRenamed('id', col_name)) + return df.fillna(0, ['_c%d' % i for i in CAT_COLS]) + + +def transform_log(df, transform_log = False): + cols = ['_c%d' % i for i in INT_COLS] + if transform_log: + for col_name in cols: + df = df.withColumn(col_name, log(df[col_name] + 3)) + return df.fillna(0, cols) + + +def would_broadcast(spark, str_path): + sc = spark.sparkContext + config = sc._jsc.hadoopConfiguration() + path = sc._jvm.org.apache.hadoop.fs.Path(str_path) + fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(config) + stat = fs.listFiles(path, True) + sum = 0 + while stat.hasNext(): + sum = sum + stat.next().getLen() + sql_conf = sc._jvm.org.apache.spark.sql.internal.SQLConf() + cutoff = sql_conf.autoBroadcastJoinThreshold() * sql_conf.fileCompressionFactor() + return sum <= cutoff + +def delete_data_source(spark, path): + sc = spark.sparkContext + config = sc._jsc.hadoopConfiguration() + path = sc._jvm.org.apache.hadoop.fs.Path(path) + sc._jvm.org.apache.hadoop.fs.FileSystem.get(config).delete(path, True) + + +def load_raw(spark, folder, day_range): + label_fields = [StructField('_c%d' % LABEL_COL, IntegerType())] + int_fields = [StructField('_c%d' % i, IntegerType()) for i in INT_COLS] + str_fields = [StructField('_c%d' % i, StringType()) for i in CAT_COLS] + + schema = StructType(label_fields + int_fields + str_fields) + paths = [os.path.join(folder, 'day_%d' % i) for i in day_range] + return (spark + .read + .schema(schema) + .option('sep', '\t') + .csv(paths)) + +def rand_ordinal(df): + # create a random long from the double precision float. + # The fraction part of a double is 52 bits, so we try to capture as much + # of that as possible + return df.withColumn('ordinal', col_of_rand_long()) + +def day_from_ordinal(df, num_days): + return df.withColumn('day', (col('ordinal') % num_days).cast(IntegerType())) + +def day_from_input_file(df): + return df.withColumn('day', substring_index(input_file_name(), '_', -1).cast(IntegerType())) + +def psudo_sort_by_day_plus(spark, df, num_days): + # Sort is very expensive because it needs to calculate the partitions + # which in our case may involve rereading all of the data. In some cases + # we can avoid this by repartitioning the data and sorting within a single partition + shuffle_parts = int(spark.conf.get('spark.sql.shuffle.partitions')) + extra_parts = int(shuffle_parts/num_days) + if extra_parts <= 0: + df = df.repartition('day') + else: + #We want to spread out the computation to about the same amount as shuffle_parts + divided = (col('ordinal') / num_days).cast(LongType()) + extra_ident = divided % extra_parts + df = df.repartition(col('day'), extra_ident) + return df.sortWithinPartitions('day', 'ordinal') + + +def load_combined_model(spark, model_folder): + path = os.path.join(model_folder, 'combined.parquet') + return spark.read.parquet(path) + + +def save_combined_model(df, model_folder, mode=None): + path = os.path.join(model_folder, 'combined.parquet') + df.write.parquet(path, mode=mode) + + +def delete_combined_model(spark, model_folder): + path = os.path.join(model_folder, 'combined.parquet') + delete_data_source(spark, path) + + +def load_low_mem_partial_ids(spark, model_folder): + path = os.path.join(model_folder, 'partial_ids.parquet') + return spark.read.parquet(path) + + +def save_low_mem_partial_ids(df, model_folder, mode=None): + path = os.path.join(model_folder, 'partial_ids.parquet') + df.write.parquet(path, mode=mode) + + +def delete_low_mem_partial_ids(spark, model_folder): + path = os.path.join(model_folder, 'partial_ids.parquet') + delete_data_source(spark, path) + + +def load_column_models(spark, model_folder, count_required): + for i in CAT_COLS: + path = os.path.join(model_folder, '%d.parquet' % i) + df = spark.read.parquet(path) + if count_required: + values = df.agg(sum('model_count').alias('sum'), count('*').alias('size')).collect() + else: + values = df.agg(sum('model_count').alias('sum')).collect() + yield i, df, values[0], would_broadcast(spark, path) + +def save_column_models(column_models, model_folder, mode=None): + for i, model in column_models: + path = os.path.join(model_folder, '%d.parquet' % i) + model.write.parquet(path, mode=mode) + + +def save_model_size(model_size, path, write_mode): + if os.path.exists(path) and write_mode == 'errorifexists': + print('Error: model size file %s exists' % path) + sys.exit(1) + + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + with open(path, 'w') as fp: + json.dump(model_size, fp, indent=4) + + +_benchmark = {} + + +@contextmanager +def _timed(step): + start = time() + yield + end = time() + _benchmark[step] = end - start + + +def _parse_args(): + parser = ArgumentParser() + + parser.add_argument( + '--mode', + required=True, + choices=['generate_models', 'transform']) + + parser.add_argument('--days', required=True) + parser.add_argument('--input_folder', required=True) + parser.add_argument('--output_folder') + parser.add_argument('--model_size_file') + parser.add_argument('--model_folder', required=True) + parser.add_argument( + '--write_mode', + choices=['overwrite', 'errorifexists'], + default='errorifexists') + + parser.add_argument('--frequency_limit') + parser.add_argument('--no_numeric_log_col', action='store_true') + #Support for running in a lower memory environment + parser.add_argument('--low_mem', action='store_true') + parser.add_argument( + '--output_ordering', + choices=['total_random', 'day_random', 'any', 'input'], + default='total_random') + + parser.add_argument( + '--output_partitioning', + choices=['day', 'none'], + default='none') + + parser.add_argument('--dict_build_shuffle_parallel_per_day', type=int, default=2) + parser.add_argument('--apply_shuffle_parallel_per_day', type=int, default=25) + parser.add_argument('--skew_broadcast_pct', type=float, default=1.0) + + parser.add_argument('--debug_mode', action='store_true') + + args = parser.parse_args() + + start, end = args.days.split('-') + args.day_range = list(range(int(start), int(end) + 1)) + args.days = len(args.day_range) + + return args + + +def _main(): + args = _parse_args() + spark = SparkSession.builder.getOrCreate() + + df = load_raw(spark, args.input_folder, args.day_range) + + if args.mode == 'generate_models': + spark.conf.set('spark.sql.shuffle.partitions', args.days * args.dict_build_shuffle_parallel_per_day) + with _timed('generate models'): + col_counts = get_column_counts_with_frequency_limit(df, args.frequency_limit) + if args.low_mem: + # in low memory mode we have to save an intermediate result + # because if we try to do it in one query spark ends up assigning the + # partial ids in two different locations that are not guaranteed to line up + # this prevents that from happening by assigning the partial ids + # and then writeing them out. + save_low_mem_partial_ids( + assign_low_mem_partial_ids(col_counts), + args.model_folder, + args.write_mode) + save_combined_model( + assign_low_mem_final_ids(load_low_mem_partial_ids(spark, args.model_folder)), + args.model_folder, + args.write_mode) + if not args.debug_mode: + delete_low_mem_partial_ids(spark, args.model_folder) + + else: + save_combined_model( + assign_id_with_window(col_counts), + args.model_folder, + args.write_mode) + save_column_models( + get_column_models(load_combined_model(spark, args.model_folder)), + args.model_folder, + args.write_mode) + if not args.debug_mode: + delete_combined_model(spark, args.model_folder) + + if args.mode == 'transform': + spark.conf.set('spark.sql.shuffle.partitions', args.days * args.apply_shuffle_parallel_per_day) + with _timed('transform'): + if args.output_ordering == 'total_random': + df = rand_ordinal(df) + if args.output_partitioning == 'day': + df = day_from_ordinal(df, args.days) + elif args.output_ordering == 'day_random': + df = rand_ordinal(df) + df = day_from_input_file(df) + elif args.output_ordering == 'input': + df = df.withColumn('ordinal', monotonically_increasing_id()) + if args.output_partitioning == 'day': + df = day_from_input_file(df) + else: # any ordering + if args.output_partitioning == 'day': + df = day_from_input_file(df) + + models = list(load_column_models(spark, args.model_folder, bool(args.model_size_file))) + if args.model_size_file: + save_model_size( + OrderedDict(('_c%d' % i, agg.size) for i, _, agg, _ in models), + args.model_size_file, + args.write_mode) + models = [(i, df, agg.sum, flag) for i, df, agg, flag in models] + + df = apply_models( + df, + models, + not args.low_mem, + args.skew_broadcast_pct) + df = transform_log(df, not args.no_numeric_log_col) + + + if args.output_partitioning == 'day': + partitionBy = 'day' + else: + partitionBy = None + + if args.output_ordering == 'total_random': + if args.output_partitioning == 'day': + df = psudo_sort_by_day_plus(spark, df, args.days) + else: # none + # Don't do a full sort it is expensive. Order is random so + # just make it random + df = df.repartition('ordinal').sortWithinPartitions('ordinal') + + df = df.drop('ordinal') + elif args.output_ordering == 'day_random': + df = psudo_sort_by_day_plus(spark, df, args.days) + df = df.drop('ordinal') + if args.output_partitioning != 'day': + df = df.drop('day') + elif args.output_ordering == 'input': + if args.low_mem: + # This is the slowest option. We totally messed up the order so we have to put + # it back in the correct order + df = df.orderBy('ordinal') + else: + # Applying the dictionary happened within a single task so we are already really + # close to the correct order, just need to sort within the partition + df = df.sortWithinPartitions('ordinal') + df = df.drop('ordinal') + if args.output_partitioning != 'day': + df = df.drop('day') + # else: any ordering so do nothing the ordering does not matter + + df.write.parquet( + args.output_folder, + mode=args.write_mode, + partitionBy=partitionBy) + + print('=' * 100) + print(_benchmark) + + +if __name__ == '__main__': + _main() diff --git a/Tensorflow2/Recommendation/DLRM/preproc/split_dataset.py b/Tensorflow2/Recommendation/DLRM/preproc/split_dataset.py new file mode 100644 index 00000000..5e82a7c8 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/preproc/split_dataset.py @@ -0,0 +1,127 @@ +# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os +import math +from shutil import copyfile + +from tqdm import tqdm +import numpy as np +from typing import Sequence + + +def get_categorical_feature_type(size: int): + types = (np.int8, np.int16, np.int32) + + for numpy_type in types: + if size < np.iinfo(numpy_type).max: + return numpy_type + + raise RuntimeError(f"Categorical feature of size {size} is too big for defined types") + + +def split_binary_file( + binary_file_path: str, + output_dir: str, + categorical_feature_sizes: Sequence[int], + num_numerical_features: int, + batch_size: int, + source_data_type: str = 'int32', +): + record_width = 1 + num_numerical_features + len(categorical_feature_sizes) # label + numerical + categorical + bytes_per_feature = np.__dict__[source_data_type]().nbytes + bytes_per_entry = record_width * bytes_per_feature + + total_size = os.path.getsize(binary_file_path) + batches_num = int(math.ceil((total_size // bytes_per_entry) / batch_size)) + + cat_feature_types = [get_categorical_feature_type(cat_size) for cat_size in categorical_feature_sizes] + + file_streams = [] + try: + input_data_f = open(binary_file_path, "rb") + file_streams.append(input_data_f) + + numerical_f = open(os.path.join(output_dir, "numerical.bin"), "wb+") + file_streams.append(numerical_f) + + label_f = open(os.path.join(output_dir, 'label.bin'), 'wb+') + file_streams.append(label_f) + + categorical_fs = [] + for i in range(len(categorical_feature_sizes)): + fs = open(os.path.join(output_dir, f'cat_{i}.bin'), 'wb+') + categorical_fs.append(fs) + file_streams.append(fs) + + for _ in tqdm(range(batches_num)): + raw_data = np.frombuffer(input_data_f.read(bytes_per_entry * batch_size), dtype=np.int32) + batch_data = raw_data.reshape(-1, record_width) + + numerical_features = batch_data[:, 1:1 + num_numerical_features].view(dtype=np.float32) + numerical_f.write(numerical_features.astype(np.float16).tobytes()) + + label = batch_data[:, 0] + label_f.write(label.astype(np.bool).tobytes()) + + cat_offset = num_numerical_features + 1 + for cat_idx, cat_feature_type in enumerate(cat_feature_types): + cat_data = batch_data[:, (cat_idx + cat_offset):(cat_idx + cat_offset + 1)].astype(cat_feature_type) + categorical_fs[cat_idx].write(cat_data.tobytes()) + finally: + for stream in file_streams: + stream.close() + + +def split_dataset(dataset_dir: str, output_dir: str, batch_size: int, numerical_features: int): + categorical_sizes_file = os.path.join(dataset_dir, "model_size.json") + with open(categorical_sizes_file) as f: + categorical_sizes = [int(v) for v in json.load(f).values()] + + train_file = os.path.join(dataset_dir, "train_data.bin") + test_file = os.path.join(dataset_dir, "test_data.bin") + val_file = os.path.join(dataset_dir, "val_data.bin") + + target_train = os.path.join(output_dir, "train") + target_test = os.path.join(output_dir, "test") + target_val = os.path.join(output_dir, "val") + + os.makedirs(output_dir, exist_ok=True) + os.makedirs(target_train, exist_ok=True) + os.makedirs(target_test, exist_ok=True) + os.makedirs(target_val, exist_ok=True) + + copyfile(categorical_sizes_file, os.path.join(output_dir, "model_size.json")) + split_binary_file(test_file, target_test, categorical_sizes, numerical_features, batch_size) + split_binary_file(train_file, target_train, categorical_sizes, numerical_features, batch_size) + split_binary_file(val_file, target_val, categorical_sizes, numerical_features, batch_size) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, required=True) + parser.add_argument('--output', type=str, required=True) + parser.add_argument('--batch_size', type=int, default=32768) + parser.add_argument('--numerical_features', type=int, default=13) + args = parser.parse_args() + + split_dataset( + dataset_dir=args.dataset, + output_dir=args.output, + batch_size=args.batch_size, + numerical_features=args.numerical_features + ) + diff --git a/Tensorflow2/Recommendation/DLRM/preproc/verify_criteo_downloaded.sh b/Tensorflow2/Recommendation/DLRM/preproc/verify_criteo_downloaded.sh new file mode 100755 index 00000000..88dc8233 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/preproc/verify_criteo_downloaded.sh @@ -0,0 +1,34 @@ +# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#! /bin/bash + +set -e +set -x + +download_dir=${1:-'/data/dlrm/criteo'} + +cd ${download_dir} +for i in $(seq 0 23); do + filename=day_${i} + if [ -f $filename ]; then + echo "$filename exists, OK" + else + echo "$filename does not exist. Please follow the instructions at: http://labs.criteo.com/2013/12/download-terabyte-click-logs/ to download it" + exit 1 + fi +done +cd - + +echo "Criteo data verified" diff --git a/Tensorflow2/Recommendation/DLRM/requirements.txt b/Tensorflow2/Recommendation/DLRM/requirements.txt new file mode 100644 index 00000000..2ab57ef2 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/requirements.txt @@ -0,0 +1,7 @@ +-e git://github.com/NVIDIA/dllogger#egg=dllogger +absl-py>=0.7.0 +numpy +pyarrow +pandas +joblib +tqdm diff --git a/Tensorflow2/Recommendation/DLRM/slurm_multinode.sh b/Tensorflow2/Recommendation/DLRM/slurm_multinode.sh new file mode 100644 index 00000000..fcb3d225 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/slurm_multinode.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# author: Tomasz Grel (tgrel@nvidia.com) + + +# This is a generic SLURM batch script. It runs $cmd +# command in $cont docker image while mounting $mounts directories. +# You can use the $srun_flags variable to pass additional +# arguments to srun. +# +# It is designed to work with enroot/pyxis, but could be modified +# to run on bare-metal machines as well. +# +# Example usage to train a 1.68TB DLRM variant using 32xA100-80GB GPUs on 4 nodes: +# +# cmd='numactl --interleave=all -- python -u main.py --dataset_path /data/dlrm/full_criteo_data --amp \ +# --tf_gpu_memory_limit_gb 73 --experimental_columnwise_split --data_parallel_bottom_mlp \ +# --embedding_dim 512 --bottom_mlp_dims 512,256,512' \ +# srun_flags='--mpi=pmix' \ +# cont=dlrm_tf_adam \ +# mounts=/data/dlrm:/data/dlrm \ +# sbatch -n 32 -N 4 -t 00:20:00 slurm_multinode.sh +# + +srun --mpi=none ${srun_flags} --ntasks-per-node=1 \ + --container-image="${cont}" --container-mounts=${mounts} /bin/bash -c "$cmd" \ No newline at end of file diff --git a/Tensorflow2/Recommendation/DLRM/split_binary_dataset.py b/Tensorflow2/Recommendation/DLRM/split_binary_dataset.py new file mode 100644 index 00000000..31b5db86 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/split_binary_dataset.py @@ -0,0 +1,215 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import tensorflow as tf + +import concurrent +import math +import os +import queue +import json +from collections import namedtuple + +import numpy as np +from typing import Optional, Sequence, Tuple, Any, Dict + + +DatasetMetadata = namedtuple('DatasetMetadata', ['num_numerical_features', + 'categorical_cardinalities']) + +class DummyDataset: + def __init__(self, batch_size, num_numerical_features, num_categorical_features, num_batches): + self.numerical_features = tf.zeros(shape=[batch_size, num_numerical_features]) + self.categorical_features = [tf.zeros(shape=[batch_size, 1], dtype=tf.int32)] * num_categorical_features + self.labels = tf.ones(shape=[batch_size, 1]) + self.num_batches = num_batches + + def __getitem__(self, idx): + if idx >= self.num_batches: + raise StopIteration + + return (self.numerical_features, self.categorical_features), self.labels + + def __len__(self): + return self.num_batches + + @staticmethod + def get_metadata(FLAGS): + cardinalities = [int(d) for d in FLAGS.synthetic_dataset_cardinalities] + metadata = DatasetMetadata(num_numerical_features=FLAGS.num_numerical_features, + categorical_cardinalities=cardinalities) + return metadata + + +def get_categorical_feature_type(size: int): + types = (np.int8, np.int16, np.int32) + + for numpy_type in types: + if size < np.iinfo(numpy_type).max: + return numpy_type + + raise RuntimeError(f"Categorical feature of size {size} is too big for defined types") + + +class RawBinaryDataset: + """Split version of Criteo dataset + + Args: + data_path (str): Full path to split binary file of dataset. It must contain numerical.bin, label.bin and + cat_0 ~ cat_25.bin + batch_size (int): + numerical_features(boolean): Number of numerical features to load, default=0 (don't load any) + categorical_features (list or None): categorical features used by the rank (IDs of the features) + categorical_feature_sizes (list of integers): max value of each of the categorical features + prefetch_depth (int): How many samples to prefetch. Default 10. + """ + + _model_size_filename = 'model_size.json' + + def __init__( + self, + data_path: str, + batch_size: int = 1, + numerical_features: int = 0, + categorical_features: Optional[Sequence[int]] = None, + categorical_feature_sizes: Optional[Sequence[int]] = None, + prefetch_depth: int = 10, + drop_last_batch: bool = False, + valid : bool = False, + ): + suffix = 'test' if valid else 'train' + data_path = os.path.join(data_path, suffix) + self._label_bytes_per_batch = np.dtype(np.bool).itemsize * batch_size + self._numerical_bytes_per_batch = numerical_features * np.dtype(np.float16).itemsize * batch_size + self._numerical_features = numerical_features + + self._categorical_feature_types = [ + get_categorical_feature_type(size) for size in categorical_feature_sizes + ] if categorical_feature_sizes else [] + self._categorical_bytes_per_batch = [ + np.dtype(cat_type).itemsize * batch_size for cat_type in self._categorical_feature_types + ] + self._categorical_features = categorical_features + self._batch_size = batch_size + self._label_file = os.open(os.path.join(data_path, 'label.bin'), os.O_RDONLY) + self._num_entries = int(math.ceil(os.fstat(self._label_file).st_size + / self._label_bytes_per_batch)) if not drop_last_batch \ + else int(math.floor(os.fstat(self._label_file).st_size / self._label_bytes_per_batch)) + + if numerical_features > 0: + self._numerical_features_file = os.open(os.path.join(data_path, "numerical.bin"), os.O_RDONLY) + number_of_numerical_batches = math.ceil(os.fstat(self._numerical_features_file).st_size + / self._numerical_bytes_per_batch) if not drop_last_batch \ + else math.floor(os.fstat(self._numerical_features_file).st_size + / self._numerical_bytes_per_batch) + if number_of_numerical_batches != self._num_entries: + raise ValueError(f"Size mismatch in data files. Expected: {self._num_entries}, got: {number_of_numerical_batches}") + else: + self._numerical_features_file = None + + if categorical_features: + self._categorical_features_files = [] + for cat_id in categorical_features: + cat_file = os.open(os.path.join(data_path, f"cat_{cat_id}.bin"), os.O_RDONLY) + cat_bytes = self._categorical_bytes_per_batch[cat_id] + number_of_categorical_batches = math.ceil(os.fstat(cat_file).st_size / cat_bytes) if not drop_last_batch \ + else math.floor(os.fstat(cat_file).st_size / cat_bytes) + if number_of_categorical_batches != self._num_entries: + raise ValueError(f"Size mismatch in data files. Expected: {self._num_entries}, got: {number_of_categorical_batches}") + self._categorical_features_files.append(cat_file) + else: + self._categorical_features_files = None + + self._prefetch_depth = min(prefetch_depth, self._num_entries) + self._prefetch_queue = queue.Queue() + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + @classmethod + def get_metadata(cls, path, num_numerical_features): + with open(os.path.join(path, cls._model_size_filename), 'r') as f: + global_table_sizes = json.load(f) + + global_table_sizes = list(global_table_sizes.values()) + global_table_sizes = [s + 1 for s in global_table_sizes] + + metadata = DatasetMetadata(num_numerical_features=num_numerical_features, + categorical_cardinalities=global_table_sizes) + return metadata + + def __len__(self): + return self._num_entries + + def __getitem__(self, idx: int): + if idx >= self._num_entries: + raise IndexError() + + if self._prefetch_depth <= 1: + return self._get_item(idx) + + if idx == 0: + for i in range(self._prefetch_depth): + self._prefetch_queue.put(self._executor.submit(self._get_item, (i))) + if idx < self._num_entries - self._prefetch_depth: + self._prefetch_queue.put(self._executor.submit(self._get_item, (idx + self._prefetch_depth))) + return self._prefetch_queue.get().result() + + def _get_item(self, idx: int) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[tf.Tensor]]: + click = self._get_label(idx) + numerical_features = self._get_numerical_features(idx) + categorical_features = self._get_categorical_features(idx) + return (numerical_features, categorical_features), click + + def _get_label(self, idx: int) -> tf.Tensor: + raw_label_data = os.pread(self._label_file, self._label_bytes_per_batch, + idx * self._label_bytes_per_batch) + array = np.frombuffer(raw_label_data, dtype=np.bool) + array = tf.convert_to_tensor(array, dtype=tf.float32) + array = tf.expand_dims(array, 1) + return array + + def _get_numerical_features(self, idx: int) -> Optional[tf.Tensor]: + if self._numerical_features_file is None: + return -1 + + raw_numerical_data = os.pread(self._numerical_features_file, self._numerical_bytes_per_batch, + idx * self._numerical_bytes_per_batch) + array = np.frombuffer(raw_numerical_data, dtype=np.float16) + array = tf.convert_to_tensor(array) + return tf.reshape(array, shape=[self._batch_size, self._numerical_features]) + + def _get_categorical_features(self, idx: int) -> Optional[tf.Tensor]: + if self._categorical_features_files is None: + return -1 + + categorical_features = [] + for cat_id, cat_file in zip(self._categorical_features, self._categorical_features_files): + cat_bytes = self._categorical_bytes_per_batch[cat_id] + cat_type = self._categorical_feature_types[cat_id] + raw_cat_data = os.pread(cat_file, cat_bytes, idx * cat_bytes) + array = np.frombuffer(raw_cat_data, dtype=cat_type) + tensor = tf.convert_to_tensor(array) + tensor = tf.expand_dims(tensor, axis=1) + categorical_features.append(tensor) + return categorical_features + + def __del__(self): + data_files = [self._label_file, self._numerical_features_file] + if self._categorical_features_files is not None: + data_files += self._categorical_features_files + + for data_file in data_files: + if data_file is not None: + os.close(data_file) diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/LICENSE b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/LICENSE new file mode 100644 index 00000000..7f9708a7 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021 NVIDIA Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/MANIFEST.in b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/MANIFEST.in new file mode 100644 index 00000000..5a88f0c9 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/MANIFEST.in @@ -0,0 +1 @@ +recursive-include tensorflow_dot_based_interact *.so diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/Makefile b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/Makefile new file mode 100644 index 00000000..23d3f9e2 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/Makefile @@ -0,0 +1,57 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +CXX := g++ +NVCC := nvcc +PYTHON_BIN_PATH = python + +TF_CFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') +TF_LFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') + +CFLAGS = ${TF_CFLAGS} -fPIC -O2 -std=c++11 +LDFLAGS = -shared ${TF_LFLAGS} + +.DEFAULT_GOAL := lib + +CC_SRC_DIR = tensorflow_dot_based_interact/cc +CC_PY_DIR = tensorflow_dot_based_interact/python +CC_SRCS = $(CC_SRC_DIR)/kernels/dot_based_interact_kernels.cc \ + $(CC_SRC_DIR)/kernels/dot_based_interact_grad_kernels.cc \ + $(CC_SRC_DIR)/ops/dot_based_interact_ops.cc +VOLTA_TARGET_OBJECT = $(CC_SRC_DIR)/_dot_based_interact_volta.cu.o +AMPERE_TARGET_OBJECT = $(CC_SRC_DIR)/_dot_based_interact_ampere.cu.o +TARGET_LIB = $(CC_PY_DIR)/ops/_dot_based_interact_ops.so + +volta: $(VOLTA_TARGET_OBJECT) +$(VOLTA_TARGET_OBJECT): $(CC_SRC_DIR)/kernels/volta/dot_based_interact_volta.cu.cc + $(NVCC) -std=c++11 -c -o $@ $^ $(TF_CFLAGS) -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -DNDEBUG --expt-relaxed-constexpr -arch=sm_70 + +ampere: $(AMPERE_TARGET_OBJECT) +$(AMPERE_TARGET_OBJECT): $(CC_SRC_DIR)/kernels/ampere/dot_based_interact_ampere.cu.cc + $(NVCC) -std=c++11 -c -o $@ $^ $(TF_CFLAGS) -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -DNDEBUG --expt-relaxed-constexpr -arch=sm_80 + +lib: $(TARGET_LIB) +$(TARGET_LIB): $(CC_SRCS) $(VOLTA_TARGET_OBJECT) $(AMPERE_TARGET_OBJECT) + $(CXX) $(CFLAGS) -o $@ $^ ${LDFLAGS} -D GOOGLE_CUDA=1 -I/usr/local/cuda/targets/x86_64-linux/include -L/usr/local/cuda/targets/x86_64-linux/lib -lcudart + +test: $(CC_PY_DIR)/ops/dot_based_interact_ops_test.py $(CC_PY_DIR)/ops/dot_based_interact_ops.py $(TARGET_LIB) + $(PYTHON_BIN_PATH) $(CC_PY_DIR)/ops/dot_based_interact_ops_test.py + +pkg: $(TARGET_LIB) + ./build_pip_pkg.sh + +clean: + rm -f $(VOLTA_TARGET_OBJECT) $(TARGET_LIB) + rm -rf artifacts diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/README.md b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/README.md new file mode 100644 index 00000000..241cf03d --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/README.md @@ -0,0 +1,101 @@ +# TF 2.x Dot Based Interacti CUDA Op + +## Requirements + +This op needs to run from within a TF2 NGC containter >= 20.12. E.g.: +``` +docker pull gitlab-master.nvidia.com:5005/dl/dgx/tensorflow:21.02-tf2-py3-devel + +docker run -it [...] gitlab-master.nvidia.com:5005/dl/dgx/tensorflow:21.02-tf2-py3-devel +``` + +## Installation + +The package with built binaries will be internally hosted on [my PyPi package registry](https://gitlab-master.nvidia.com/wraveane/pypi). There are two ways to install this: + +- Either: Install directly via PIP using my gitlab token (or replace the URL with your own token): + ``` + pip3 install --extra-index-url https://__token__:TmAosCzLDiFzS7x3J1aN@gitlab-master.nvidia.com/api/v4/projects/38036/packages/pypi/simple tensorflow-dot-based-interact + ``` +- Or: Manually download the wheel package file from the [package's registry page](https://gitlab-master.nvidia.com/wraveane/pypi/-/packages/1376), and install it: + ``` + pip install ./tensorflow_dot_based_interact-*.whl + ``` + +## Build from Source + +Alternatively, it can be built from source as follows: + +- Fix the TF CUDA include directory: + ``` + mkdir -p /usr/local/lib/python3.8/dist-packages/tensorflow/include/third_party/gpus/cuda/ + ln -s /usr/local/cuda/include /usr/local/lib/python3.8/dist-packages/tensorflow/include/third_party/gpus/cuda/ + ``` +- Clone this repository and build it: + ``` + git clone https://gitlab-master.nvidia.com/wraveane/tensorflow-dot-based-interact + cd tensorflow-dot-based-interact + make + ``` +- Run the [unit tests](tensorflow_dot_based_interact/python/ops/dot_based_interact_ops_test.py) to ensure the op is working as intended: + ``` + make test + ``` +- Install the TF Op package in one of two ways: + - Either: Create a wheel and install it with pip: + ``` + make pkg + pip install ./artifacts/tensorflow_dot_based_interact-*.whl + ``` + - Or: Install the repository directory locally: + ``` + pip install -e . + ``` + +## Usage + +The architecture to be used is as follows: + +![Dot Based Interact](https://docs.google.com/drawings/d/e/2PACX-1vT-RW1_SsvfENGogMxiqM8_pwDR6m8WXklWzX5kICDOJLK_0XPfO2oLyo_G9apVDXsc9LYE2XP7_e9I/pub?w=368&h=489) + +Where the TF CUDA op implemented by this package takes two inputs: +- **input**: The concatenation (done in TensorFlow) of the Bottom MLP output and the embeddings. +- **bottom_mlp_output**: A copy of the Bottom MLP output tensor. + +The result of the operation will already have the Bottom MLP output tensor concatenated, ready to be given to the next stage of the architecture. + +To use it, follow the installation or building instructions for the package above. Then: + +- Make sure the op is properly installed: + ``` + pip show tensorflow-dot-based-interact + ``` +- Use it like this: + ``` + from tensorflow_dot_based_interact.python.ops import dot_based_interact_ops + + bottom_mlp_output = ... # The bottom MLP output tensor + embeddings = ... # The sparse features embeddings tensor + + input = tf.concat([bottom_mlp_output, embeddings], axis=1) # Bottom Concat + + result = dot_based_interact_ops.dot_based_interact(input, bottom_mlp_output) + ``` + +## Support + +The TF DBI custom op will dynamically switch kernel versions according to: + +- GPU Architecture: + - GPU Major Version >= 8: Ampere Kernels + - GPU Major Version == 7: Volta Kernels + - GPU Major Version <= 6: Not Supported / Error Thrown +- Data Alignment +- Data Type: + - Ampere: + - TF32 (on aligned inputs) + - FP32 (fallback on non-aligned inputs) + - FP16 + - Volta: + - FP32 + - FP16 diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/build_pip_pkg.sh b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/build_pip_pkg.sh new file mode 100755 index 00000000..b17ede09 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/build_pip_pkg.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +DEST=$(readlink -f "artifacts") +mkdir -p "${DEST}" +TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX) + +cp setup.py "${TMPDIR}" +cp MANIFEST.in "${TMPDIR}" +cp LICENSE "${TMPDIR}" +rsync -avm -L --exclude='*_test.py' --exclude='*/cc/*' --exclude='*/__pycache__/*' ${PIP_FILE_PREFIX}tensorflow_dot_based_interact "${TMPDIR}" +pushd ${TMPDIR} +python3 setup.py bdist_wheel > /dev/null +cp dist/*.whl "${DEST}" +popd +rm -rf ${TMPDIR} diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/setup.py b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/setup.py new file mode 100644 index 00000000..0e76d509 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/setup.py @@ -0,0 +1,78 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Setup for pip package.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from setuptools import Extension +from setuptools import find_packages +from setuptools import setup +from setuptools.dist import Distribution + + +__version__ = '0.0.1' +REQUIRED_PACKAGES = [ + 'tensorflow >= 2.3.1', +] +project_name = 'tensorflow-dot-based-interact' + + +from setuptools.command.install import install +class InstallPlatlib(install): + def finalize_options(self): + install.finalize_options(self) + self.install_lib = self.install_platlib + + +class BinaryDistribution(Distribution): + """This class is needed in order to create OS specific wheels.""" + + def has_ext_modules(self): + return True + + def is_pure(self): + return False + +setup( + name=project_name, + version=__version__, + description=('tensorflow-dot-based-interact is a CUDA Dot Based Interact custom op for TensorFlow'), + author='NVIDIA Corporation', + author_email='info@nvidia.com', + # Contained modules and scripts. + packages=find_packages(), + install_requires=REQUIRED_PACKAGES, + # Add in any packaged data. + include_package_data=True, + zip_safe=False, + distclass=BinaryDistribution, + cmdclass={'install': InstallPlatlib}, + # PyPI package information. + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 3.8', + 'Topic :: Scientific/Engineering :: Mathematics', + 'Topic :: Software Development :: Libraries :: Python Modules', + 'Topic :: Software Development :: Libraries', + ], + license='Apache 2.0', + keywords='tensorflow custom op machine learning', +) diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/__init__.py b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/__init__.py new file mode 100644 index 00000000..ecb872fe --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import absolute_import + +from tensorflow_dot_based_interact.python.ops.dot_based_interact_ops import dot_based_interact diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere.cu.cc b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere.cu.cc new file mode 100644 index 00000000..85750732 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere.cu.cc @@ -0,0 +1,528 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "dot_based_interact_ampere.h" +#include "dot_based_interact_ampere_fp32.cu.inl" +#include "dot_based_interact_ampere_tf32.cu.inl" +#include "dot_based_interact_ampere_half.cu.inl" + +inline void dotBasedInteractAmpereF32Fwd(const void *input, + const void *bottom_mlp_output, + void *output, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream) { + const uint kPaddingSize = 1; + const uint kNumThreads = 128; + uint num_blocks = batch_size; + + // Output + uint interaction_output_size = (num_rows * (num_rows - 1)) >> 1; + uint output_size = num_cols + interaction_output_size + kPaddingSize; + + // Input + uint input_size = num_rows * num_cols; + + uint shared_mem_size_elems = input_size; + uint shared_mem_size_bytes = shared_mem_size_elems << 2; // F32 Kernel + + bool float4_predicate = !((num_cols & 3) || (output_size & 3)); + + if (float4_predicate) { + dotBasedInteractF32FwdKernel + <<>>((const float *)input, + (float *)output, + batch_size, + num_rows, + num_cols, + input_size, + output_size, + interaction_output_size); + } else { + dotBasedInteractF32FwdKernelNonAligned + <<>>((const float *)input, + (float *)output, + batch_size, + num_rows, + num_cols, + input_size, + output_size, + interaction_output_size); + } +} + +inline void dotBasedInteractAmpereF32Bwd(const void *input, + const void *upstream_grad, + void *grad, + void *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream) { + const uint kPaddingSize = 1; + const uint kNumThreads = 128; + + uint num_blocks = batch_size; + + uint input_size = num_rows * num_cols; + + // 1D ugrad size + uint interaction_ugrad_size = num_rows * (num_rows - 1) >> 1; + uint interaction_ugrad_size_with_padding = interaction_ugrad_size + kPaddingSize; + uint ugrad_size = num_cols + interaction_ugrad_size_with_padding; + + // input space + upstream grad space + uint smem_size_elems = input_size + interaction_ugrad_size; + uint smem_size_bytes = smem_size_elems << 2; // F32 Kernel + + bool float4_predicate = !((interaction_ugrad_size_with_padding & 3) || (num_cols & 3)); + if (float4_predicate) { + dotBasedInteractF32BwdKernel + <<>>((const float *)input, + (const float *)upstream_grad, + (float *)grad, + (float *)bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + input_size, + ugrad_size, + interaction_ugrad_size); + } else { + dotBasedInteractF32BwdKernelNonAligned + <<>>((const float *)input, + (const float *)upstream_grad, + (float *)grad, + (float *)bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + input_size, + ugrad_size, + interaction_ugrad_size); + } +} + +void dotBasedInteractAmpereF16Fwd(const void *input, + const void *bottom_mlp_output, + void *output, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream) { + const uint kWarpSize = 32; + const uint kWarpSizeLog2 = Log2::value; + const uint kTileDim = 16; + const uint kTileDimLog2 = Log2::value; + const uint warps_per_threadblock = 4; + const uint threadblock_size = warps_per_threadblock * 32; + const uint kPaddingSize = 1; + const uint kRowTilesPerStep = 2; + const uint kColTilesPerStep = 1; + + // num tiles + uint num_row_tiles = (num_rows + kTileDim - 1) >> kTileDimLog2; + uint num_col_tiles = (num_cols + kTileDim - 1) >> kTileDimLog2; + + // number of rows and columns after padding + uint num_rows_after_padding = kTileDim << 1; + uint num_cols_after_padding = num_col_tiles << kTileDimLog2; + + uint num_row_steps = num_row_tiles / kRowTilesPerStep; + uint num_col_steps = num_col_tiles / kColTilesPerStep; + + const uint K_BLOCKS = 8; + const uint M_BLOCKS = 2; + const uint SKEW_HALF = ((K_BLOCKS % 2) == 0) ? 8 : 0; + const uint SMEM_STRIDE = (K_BLOCKS * 16 + SKEW_HALF); + // multiple of 2 to guarantee 256-bit alignment for start of the row, at least 16 to safeload a tile + const uint smem_rows_per_warp = M_BLOCKS << 4; + const uint smem_elems_per_warp_mat = smem_rows_per_warp * SMEM_STRIDE; + const uint SKEW_HALF_ACC = ((M_BLOCKS % 2) == 0) ? 8 : 0; + const uint SMEM_STRIDE_ACC = (M_BLOCKS * 16 + SKEW_HALF_ACC); + const uint smem_elems_per_warp_acc = M_BLOCKS * 16 * SMEM_STRIDE_ACC * 2; // output in FP32 + const uint smem_elems_per_warp = + (smem_elems_per_warp_mat > smem_elems_per_warp_acc) ? smem_elems_per_warp_mat : smem_elems_per_warp_acc; + uint output_size = num_cols + ((num_rows * (num_rows - 1)) >> 1) + kPaddingSize; + + bool float4_predicate = !((num_cols & 7) || (output_size & 7)); + + if (float4_predicate) { + dotBasedInteractFwdKernel + <<<(batch_size + warps_per_threadblock - 1) / warps_per_threadblock, + threadblock_size, + warps_per_threadblock * smem_elems_per_warp * sizeof(__half), stream>>>((const __half *)input, + (half *)output, + batch_size, + num_rows, + num_cols, + num_rows_after_padding, + num_cols_after_padding, + smem_elems_per_warp, + smem_rows_per_warp, + output_size, + num_row_steps, + num_col_steps); + } else { + dotBasedInteractFwdKernelNonAligned + <<<(batch_size + warps_per_threadblock - 1) / warps_per_threadblock, + threadblock_size, + warps_per_threadblock * smem_elems_per_warp * sizeof(__half), stream>>>((const __half *)input, + (half *)output, + batch_size, + num_rows, + num_cols, + num_rows_after_padding, + num_cols_after_padding, + smem_elems_per_warp, + smem_rows_per_warp, + output_size, + num_row_steps, + num_col_steps); + } +} + +void dotBasedInteractAmpereF16Bwd(const void *input, + const void *upstream_grad, + void *grad, + void *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream) { + const uint kWarpSize = 32; + const uint kWarpSizeLog2 = Log2::value; + const uint kTileDim = 16; + const uint kTileDimLog2 = Log2::value; + const uint mem_skew_size = 8; + const uint kPaddingSize = 1; + const uint kWarpsPerBlock = 4; + const uint kWarpsPerBlockLog2 = Log2::value; + const uint kNumThreads = kWarpsPerBlock * kWarpSize; + const uint kRowTilesPerStep = 2; + const uint kColTilesPerStep = 1; + + uint row_tiles_per_step = num_rows > kTileDim ? kRowTilesPerStep : 1; + + // num tiles + uint num_row_tiles = (num_rows + kTileDim - 1) >> kTileDimLog2; + uint num_col_tiles = (num_cols + kTileDim - 1) >> kTileDimLog2; + + // number of rows and columns after padding + uint num_rows_after_padding = kTileDim << 1; + uint num_cols_after_padding = num_col_tiles << kTileDimLog2; + + // 2D ugrad size and stride + uint interaction_ugrad_2D_stride = num_rows_after_padding + mem_skew_size; + uint interaction_ugrad_2D_size_elems = num_rows_after_padding * interaction_ugrad_2D_stride; + uint interaction_ugrad_2D_size_bytes = interaction_ugrad_2D_size_elems * sizeof(half); + + // 1D ugrad size + uint interaction_ugrad_size = num_rows * (num_rows - 1) >> 1; + uint interaction_ugrad_size_with_padding = interaction_ugrad_size + kPaddingSize; + + // in_out place size and stride + uint input_stride = num_cols_after_padding + mem_skew_size; + uint input_size_elems = num_rows_after_padding * input_stride; + uint input_size_bytes = input_size_elems * sizeof(half); + + // sample size + uint sample_size = num_rows * num_cols; + + // output size + uint output_size_elems = kTileDim * kTileDim * kRowTilesPerStep * kColTilesPerStep; + uint output_size_bytes = output_size_elems * sizeof(float); + + // staging area size + uint staging_area_size_bytes = + output_size_bytes > interaction_ugrad_2D_size_bytes ? output_size_bytes : interaction_ugrad_2D_size_bytes; + + // Shared memory size + uint shared_mem_per_warp_size_byte = input_size_bytes + staging_area_size_bytes; + uint shared_mem_size_bytes = kWarpsPerBlock * shared_mem_per_warp_size_byte; + + uint num_blocks = (batch_size + kWarpsPerBlock - 1) >> kWarpsPerBlockLog2; + uint num_row_steps = num_row_tiles / row_tiles_per_step; + uint num_col_steps = num_col_tiles / kColTilesPerStep; + + bool float4_predicate = !((interaction_ugrad_size_with_padding & 7) || (num_cols & 7)); + if (float4_predicate) { + dotBasedInteractBwdKernel + <<>>((const half *)input, + (const half *)upstream_grad, + (half *)grad, + (half *)bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + num_rows_after_padding, + num_cols_after_padding, + sample_size, + interaction_ugrad_size, + interaction_ugrad_size_with_padding, + interaction_ugrad_2D_size_elems, + interaction_ugrad_2D_stride, + input_size_elems, + input_stride, + num_row_steps, + num_col_steps, + row_tiles_per_step, + shared_mem_per_warp_size_byte); + } else { + dotBasedInteractBwdKernelNonAligned + <<>>((const half *)input, + (const half *)upstream_grad, + (half *)grad, + (half *)bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + num_rows_after_padding, + num_cols_after_padding, + sample_size, + interaction_ugrad_size, + interaction_ugrad_size_with_padding, + interaction_ugrad_2D_size_elems, + interaction_ugrad_2D_stride, + input_size_elems, + input_stride, + num_row_steps, + num_col_steps, + row_tiles_per_step, + shared_mem_per_warp_size_byte); + } +} + +void dotBasedInteractAmpereTF32Fwd(const void *input, + const void *bottom_mlp_output, + void *output, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream) { + const uint kWarpSize = 32; + const uint kWarpSizeLog2 = Log2::value; + const uint kTileLength = 16; + const uint kTileLengthLog2 = Log2::value; + const uint kTileWidth = 8; + const uint kTileWidthLog2 = Log2::value; + const uint kWarpsPerBlock = 2; + const uint kThreadBlockSize = kWarpsPerBlock * kWarpSize; + const uint kPaddingSize = 1; + const uint kRowTilesPerStep = 2; + const uint kColTilesPerStep = 1; + const uint kSkewFloat = 4; // Ensures we are 16 byte align as required by nvcuda::wmma::load_matrix_sync + + // num tiles + uint mat_a_num_row_tiles = (num_rows + kTileLength - 1) >> kTileLengthLog2; + uint mat_a_num_col_tiles = (num_cols + kTileWidth - 1) >> kTileWidthLog2; + + // const uint &mat_b_num_row_tiles = mat_a_num_col_tiles; + // const uint &mat_b_num_col_tiles = mat_a_num_row_tiles; + + // number of rows and columns after padding + uint num_rows_after_padding = mat_a_num_row_tiles << kTileLengthLog2; + uint num_cols_after_padding = mat_a_num_col_tiles << kTileWidthLog2; + + uint num_row_steps = mat_a_num_row_tiles / kRowTilesPerStep; + uint num_col_steps = mat_a_num_col_tiles / kColTilesPerStep; + + const uint smem_stride = num_cols_after_padding + kSkewFloat; + const uint smem_elems_per_warp_mat = num_rows_after_padding * smem_stride; + + const uint smem_stride_acc = num_rows_after_padding + kSkewFloat; + const uint smem_elems_per_warp_acc = num_rows_after_padding * smem_stride_acc; + + const uint smem_elems_per_warp = + smem_elems_per_warp_mat > smem_elems_per_warp_acc ? smem_elems_per_warp_mat : smem_elems_per_warp_acc; + + uint output_size = num_cols + ((num_rows * (num_rows - 1)) >> 1) + kPaddingSize; + bool float4_predicate = !((num_cols & 7) || (output_size & 7)); + + if (float4_predicate) { + dotBasedInteractTF32FwdKernel + <<<(batch_size + kWarpsPerBlock - 1) / kWarpsPerBlock, + kThreadBlockSize, + kWarpsPerBlock * smem_elems_per_warp * sizeof(float), stream>>>((const float *)input, + (float *)output, + batch_size, + num_rows, + num_cols, + num_rows_after_padding, + num_cols_after_padding, + smem_elems_per_warp, + output_size, + num_row_steps, + num_col_steps, + smem_stride, + smem_stride_acc); + } else { + // GENERIC VERSION IS UNFINISHED: Use FP32 instead for now + dotBasedInteractAmpereF32Fwd(input, + bottom_mlp_output, + output, + batch_size, + num_rows, + num_cols, + stream); + } +} + +void dotBasedInteractAmpereTF32Bwd(const void *input, + const void *upstream_grad, + void *grad, + void *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream) { + // Fragment Settings + const uint kFragARows = 2; + const uint kFragBCols = 2; + const uint kTileLength = 16; + const uint kTileLengthLog2 = Log2::value; + const uint kTileWidth = 8; + const uint kTileWidthLog2 = Log2::value; + + const uint kWarpSize = 32; + const uint kWarpSizeLog2 = Log2::value; + const uint kSkewFloat = 4; + const uint kPaddingSize = 1; + const uint kWarpsPerBlock = 1; + const uint kWarpsPerBlockLog2 = Log2::value; + const uint kNumThreads = kWarpsPerBlock * kWarpSize; + + // num tiles + uint mat_a_num_row_tiles = (num_rows + kTileLength - 1) >> kTileLengthLog2; + uint mat_a_num_col_tiles = (num_rows + kTileWidth - 1) >> kTileWidthLog2; + + // const uint &mat_b_num_row_tiles = mat_a_num_col_tiles; + uint mat_b_num_col_tiles = (num_cols + kTileLength - 1) >> kTileLengthLog2; + + // number of rows and columns after padding + uint num_rows_after_padding = mat_a_num_row_tiles << kTileLengthLog2; + uint num_cols_after_padding = mat_b_num_col_tiles << kTileLengthLog2; + + // 2D ugrad size and stride + uint interaction_ugrad_2D_stride = num_rows_after_padding + kSkewFloat; + uint interaction_ugrad_2D_size_elems = num_rows_after_padding * interaction_ugrad_2D_stride; + + // 1D ugrad size + uint interaction_ugrad_size = num_rows * (num_rows - 1) >> 1; + uint interaction_ugrad_size_with_padding = interaction_ugrad_size + kPaddingSize; + + // in_out place size and stride + uint input_stride = num_cols_after_padding + kSkewFloat; + uint input_size_elems = num_rows_after_padding * input_stride; + + // sample size + uint sample_size = num_rows * num_cols; + + // output size + uint output_size_elems = kTileLength * kTileLength * kFragARows * kFragBCols; + + // Shared memory size + uint shared_mem_per_warp_size_elems = interaction_ugrad_2D_size_elems + input_size_elems + output_size_elems; + uint shared_mem_size_elems = kWarpsPerBlock * shared_mem_per_warp_size_elems; + uint shared_mem_size_bytes = shared_mem_size_elems * sizeof(float); + + uint num_blocks = (batch_size + kWarpsPerBlock - 1) >> kWarpsPerBlockLog2; + uint num_k_steps = mat_a_num_col_tiles; + uint num_n_steps = mat_b_num_col_tiles / kFragBCols; + + bool float4_predicate = !((interaction_ugrad_size_with_padding & 7) || (num_cols & 7)); + if (float4_predicate) { + dotBasedInteractTF32BwdKernel + <<>>((const float *)input, + (const float *)upstream_grad, + (float *)grad, + (float *)bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + num_rows_after_padding, + num_cols_after_padding, + sample_size, + interaction_ugrad_size, + interaction_ugrad_size_with_padding, + interaction_ugrad_2D_size_elems, + interaction_ugrad_2D_stride, + input_size_elems, + input_stride, + shared_mem_per_warp_size_elems, + num_k_steps, + num_n_steps); + } else { + // GENERIC VERSION IS UNFINISHED: Use FP32 instead for now + dotBasedInteractAmpereF32Bwd(input, + upstream_grad, + grad, + bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + stream); + } +} diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere.h b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere.h new file mode 100644 index 00000000..c159fd58 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere.h @@ -0,0 +1,53 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KERNEL_DOT_BASED_INTERACT_AMPERE_H_ +#define KERNEL_DOT_BASED_INTERACT_AMPERE_H_ + +void dotBasedInteractAmpereF16Fwd(const void *input, + const void *bottom_mlp_output, + void *output, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream); + +void dotBasedInteractAmpereF16Bwd(const void *input, + const void *upstream_grad, + void *grad, + void *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream); + +void dotBasedInteractAmpereTF32Fwd(const void *input, + const void *bottom_mlp_output, + void *output, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream); + +void dotBasedInteractAmpereTF32Bwd(const void *input, + const void *upstream_grad, + void *grad, + void *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream); + +#endif //KERNEL_DOT_BASED_INTERACT_AMPERE_H_ diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere_fp32.cu.inl b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere_fp32.cu.inl new file mode 100644 index 00000000..885f6fd6 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere_fp32.cu.inl @@ -0,0 +1,280 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../dot_based_interact_shared_utils.cu.h" + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ + void dotBasedInteractF32FwdKernelNonAligned(const float *__restrict input, + float *__restrict output, + uint batch_size, + uint num_rows, + uint num_cols, + uint input_size, + uint output_size, + uint interaction_output_size) { + extern __shared__ float smem_f32_fwd[]; + float *smem_in = &smem_f32_fwd[0]; + + uint input_batch_offset = blockIdx.x * input_size; + const float *gmem_in = &input[input_batch_offset]; + + uint output_batch_offset = blockIdx.x * output_size; + float *gmem_out_bottom_mlp = &output[output_batch_offset]; + float *gmem_out_interaction = &output[output_batch_offset + num_cols]; + + // Load the input - one sample per block + for (uint idx = threadIdx.x; idx < input_size; idx += blockDim.x) { + smem_in[idx] = gmem_in[idx]; + } + __syncthreads(); + + // Copy bottom MLP output to output + for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) { + gmem_out_bottom_mlp[idx] = smem_in[idx]; + } + + for (uint idx = threadIdx.x; idx < (interaction_output_size); idx += blockDim.x) { + uint elems_per_row = 1; + uint index = idx; + while (index >= elems_per_row) { + index -= elems_per_row; + elems_per_row++; + } + uint target_row = elems_per_row; + uint target_col = index; + + float sum = 0; + for (uint i = 0; i < num_cols; i++) { + float tmp1 = smem_in[target_row * num_cols + i]; + float tmp2 = smem_in[target_col * num_cols + i]; + sum = fmaf(tmp1, tmp2, sum); + } + + gmem_out_interaction[idx] = sum; + } + + gmem_out_interaction[interaction_output_size] = 0; +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractF32FwdKernel(const float *__restrict input, + float *__restrict output, + uint batch_size, + uint num_rows, + uint num_cols, + uint input_size, + uint output_size, + uint interaction_output_size) { + extern __shared__ float smem_f32_fwd[]; + float *smem_in = &smem_f32_fwd[0]; + + uint input_batch_offset = blockIdx.x * input_size; + const float *gmem_in = &input[input_batch_offset]; + + uint output_batch_offset = blockIdx.x * output_size; + float *gmem_out_bottom_mlp = &output[output_batch_offset]; + float *gmem_out_interaction = &output[output_batch_offset + num_cols]; + + // Load the input - one sample per block + uint input_size_float4 = input_size >> 2; + for (uint idx = threadIdx.x; idx < input_size_float4; idx += blockDim.x) { + ((float4 *)smem_in)[idx] = ((float4 *)gmem_in)[idx]; + } + __syncthreads(); + + // Copy bottom MLP output to output + uint btm_mlp_out_size_float4 = num_cols >> 2; + for (uint idx = threadIdx.x; idx < btm_mlp_out_size_float4; idx += blockDim.x) { + ((float4 *)gmem_out_bottom_mlp)[idx] = ((float4 *)smem_in)[idx]; + } + + for (uint idx = threadIdx.x; idx < (interaction_output_size); idx += blockDim.x) { + uint elems_per_row = 1; + uint index = idx; + while (index >= elems_per_row) { + index -= elems_per_row; + elems_per_row++; + } + uint target_row = elems_per_row; + uint target_col = index; + + float4 sum; + sum.x = 0; + sum.y = 0; + sum.z = 0; + sum.w = 0; + uint num_cols_float4 = num_cols >> 2; + for (uint i = 0; i < num_cols_float4; i++) { + float4 tmp1 = ((float4 *)smem_in)[target_row * num_cols_float4 + i]; + float4 tmp2 = ((float4 *)smem_in)[target_col * num_cols_float4 + i]; + sum.x = fmaf(tmp1.x, tmp2.x, sum.x); + sum.y = fmaf(tmp1.y, tmp2.y, sum.y); + sum.z = fmaf(tmp1.z, tmp2.z, sum.z); + sum.w = fmaf(tmp1.w, tmp2.w, sum.w); + } + + gmem_out_interaction[idx] = sum.x + sum.y + sum.z + sum.w; + } + + gmem_out_interaction[interaction_output_size] = 0; +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ + void dotBasedInteractF32BwdKernelNonAligned(const float *__restrict input, + const float *__restrict upstream_grad, + float *__restrict grad, + float *__restrict bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + uint input_size, + uint ugrad_size, + uint interaction_ugrad_size) { + extern __shared__ float smem_f32_bwd[]; + float *smem_in = &smem_f32_bwd[0]; + float *smem_interaction_ugrad = &smem_f32_bwd[input_size]; + + // Input + uint input_batch_offset = blockIdx.x * input_size; + const float *gmem_in = &input[input_batch_offset]; + + // Gradient + const uint &grad_batch_offset = input_batch_offset; + float *gmem_mlp_grad = &bottom_mlp_grad[blockIdx.x * num_cols]; + float *gmem_interaction_grad = &grad[grad_batch_offset]; + + // Upstream Gradient + uint upstream_grad_batch_offset = blockIdx.x * ugrad_size; + const float *gmem_mlp_ugrad = &upstream_grad[upstream_grad_batch_offset]; + const float *gmem_interaction_ugrad = &upstream_grad[upstream_grad_batch_offset + num_cols]; + + // input -> shared memory + for (uint idx = threadIdx.x; idx < input_size; idx += blockDim.x) { + smem_in[idx] = gmem_in[idx]; + } + + // Interaction Upstream Grad -> Shared Memory + for (uint idx = threadIdx.x; idx < interaction_ugrad_size; idx += blockDim.x) { + smem_interaction_ugrad[idx] = gmem_interaction_ugrad[idx]; + } + __syncthreads(); + + // Copy the upstream gradient w.r.t to mlp to it's corresponding memory location. + for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) { + gmem_mlp_grad[idx] = gmem_mlp_ugrad[idx]; + } + + for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) { + size_t grad_idx = idx; + for (uint row_idx = 0; row_idx < num_rows; row_idx++) { + float sum = 0; + size_t upstream_grad_offset = (row_idx * (row_idx - 1)) >> 1; + for (int k = 0; k < row_idx; k++) { + sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + k], sum); + } + for (int k = row_idx + 1; k < num_rows; k++) { + upstream_grad_offset = (k * (k - 1)) >> 1; // TODO: this can become a sum + sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + row_idx], sum); + } + gmem_interaction_grad[grad_idx] = sum; + grad_idx += num_cols; + } + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractF32BwdKernel(const float *__restrict input, + const float *__restrict upstream_grad, + float *__restrict grad, + float *__restrict bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + uint input_size, + uint ugrad_size, + uint interaction_ugrad_size) { + extern __shared__ float smem_f32_bwd[]; + float *smem_in = &smem_f32_bwd[0]; + float *smem_interaction_ugrad = &smem_f32_bwd[input_size]; + + // Input + uint input_batch_offset = blockIdx.x * input_size; + const float *gmem_in = &input[input_batch_offset]; + + // Gradient + const uint &grad_batch_offset = input_batch_offset; + float *gmem_mlp_grad = &bottom_mlp_grad[blockIdx.x * num_cols]; + float *gmem_interaction_grad = &grad[grad_batch_offset]; + + // Upstream Gradient + uint upstream_grad_batch_offset = blockIdx.x * ugrad_size; + const float *gmem_mlp_ugrad = &upstream_grad[upstream_grad_batch_offset]; + const float *gmem_interaction_ugrad = &upstream_grad[upstream_grad_batch_offset + num_cols]; + + // input -> shared memory + uint input_size_float4 = input_size >> 2; + for (uint idx = threadIdx.x; idx < input_size_float4; idx += blockDim.x) { + ((float4 *)smem_in)[idx] = ((float4 *)gmem_in)[idx]; + } + + // Interaction Upstream Grad -> Shared Memory + uint upstream_grad_size_float4 = interaction_ugrad_size >> 2; + for (uint idx = threadIdx.x; idx < upstream_grad_size_float4; idx += blockDim.x) { + ((float4 *)smem_interaction_ugrad)[idx] = ((float4 *)gmem_interaction_ugrad)[idx]; + } + + uint vectorized_load_offset = (upstream_grad_size_float4 << 2); + for (uint idx = vectorized_load_offset + threadIdx.x; idx < interaction_ugrad_size; idx += blockDim.x) { + smem_interaction_ugrad[idx] = gmem_interaction_ugrad[idx]; + } + __syncthreads(); + + // Copy the upstream gradient w.r.t to mlp to it's corresponding memory location. + for (uint idx = threadIdx.x; idx < (num_cols >> 2); idx += blockDim.x) { + ((float4 *)gmem_mlp_grad)[idx] = ((float4 *)gmem_mlp_ugrad)[idx]; + } + + for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) { + size_t grad_idx = idx; + for (uint row_idx = 0; row_idx < num_rows; row_idx++) { + float sum = 0; + size_t upstream_grad_offset = (row_idx * (row_idx - 1)) >> 1; + for (int k = 0; k < row_idx; k++) { + sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + k], sum); + } + for (int k = row_idx + 1; k < num_rows; k++) { + upstream_grad_offset = (k * (k - 1)) >> 1; // TODO: this can become a sum + sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + row_idx], sum); + } + gmem_interaction_grad[grad_idx] = sum; + grad_idx += num_cols; + } + } +} diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere_half.cu.inl b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere_half.cu.inl new file mode 100644 index 00000000..e93d7f52 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere_half.cu.inl @@ -0,0 +1,570 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../dot_based_interact_shared_utils.cu.h" + +struct __align__(8) half4 { + half2 vals[2]; +}; + +using namespace nvcuda; + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractFwdKernelNonAligned(const __half *__restrict input, + __half *__restrict output, + uint batch_size, + uint num_rows, + uint num_cols, + uint num_rows_after_padding, + uint num_cols_after_padding, + uint smem_elems_per_warp, + uint smem_rows_per_warp, + uint output_size, + uint num_row_steps, + uint num_col_steps) { + uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2); + int sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id; + if (sample_id >= batch_size) { + return; + } + int lane_id = threadIdx.x & (WARP_SIZE - 1); + + extern __shared__ half shmem_dynamic[]; + half *shmem = shmem_dynamic + (warp_id * smem_elems_per_warp); + + const half *sample_input = input + num_rows * num_cols * sample_id; + for (uint i = 0; i < num_rows; ++i, sample_input += num_cols) { + for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) { + (shmem + i * SMEM_STRIDE)[idx] = sample_input[idx]; + } + } + + uint idx = lane_id + num_cols; + if (idx < num_cols_after_padding) { + for (int i = 0; i < num_rows; ++i) { + (shmem + i * SMEM_STRIDE)[idx] = __float2half(0); + } + } + + half4 zeros; + zeros.vals[0].x = __float2half(0); + zeros.vals[0].y = __float2half(0); + zeros.vals[1].x = __float2half(0); + zeros.vals[1].y = __float2half(0); + if (lane_id < (num_cols_after_padding >> 2)) { + for (int i = num_rows; i < num_rows_after_padding; i++) { + ((half4 *)(shmem + i * SMEM_STRIDE))[lane_id] = zeros; + } + } + __syncwarp(); + half *gmem_output = output + output_size * sample_id; + + for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) { + gmem_output[idx] = shmem[idx]; + } + + wmma::fragment acc[M_BLOCKS][M_BLOCKS]; + + for (int i = 0; i < M_BLOCKS; i++) { + for (int j = 0; j < M_BLOCKS; j++) { + wmma::fill_fragment(acc[i][j], 0); + } + } + + for (int k_step = 0; k_step < num_col_steps; k_step++) { + wmma::fragment a[M_BLOCKS]; + wmma::fragment b[M_BLOCKS]; + for (int j = 0; j < M_BLOCKS; j++) { + int base_row = (j < M_BLOCKS - 1) ? j * 16 : smem_rows_per_warp - 16; + const half *tile_ptr = shmem + (base_row * SMEM_STRIDE + k_step * 16); + wmma::load_matrix_sync(a[j], tile_ptr, SMEM_STRIDE); + wmma::load_matrix_sync(b[j], tile_ptr, SMEM_STRIDE); + } + for (int i = 0; i < M_BLOCKS; i++) { + for (int j = 0; j < M_BLOCKS; j++) { + wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]); + } + } + } + float *shmem_store = reinterpret_cast(shmem); + for (int i = 0; i < M_BLOCKS; i++) { + for (int j = 0; j < M_BLOCKS; j++) { + float *tile_ptr = shmem_store + (i * 16 * SMEM_STRIDE_ACC + j * 16); + wmma::store_matrix_sync(tile_ptr, acc[i][j], SMEM_STRIDE_ACC, wmma::mem_row_major); + } + } + + half *gmem_interact_output = gmem_output + num_cols; + int lastRowBlockOffset = M_BLOCKS * 16 - smem_rows_per_warp; + int srcLine = 0; + for (int i = 0; i < num_rows; ++i, ++srcLine) { + if (i == ((M_BLOCKS - 1) * 16)) { + srcLine += lastRowBlockOffset; + } + if (lane_id < i) { + uint offset = (i * (i - 1)) >> 1; + gmem_interact_output[offset + lane_id] = __float2half(shmem_store[srcLine * SMEM_STRIDE_ACC + lane_id]); + } + } + // Padding + if (lane_id == 0) { + gmem_output[output_size - 1] = __float2half(0); + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractFwdKernel(const __half *__restrict input, + __half *__restrict output, + uint batch_size, + uint num_rows, + uint num_cols, + uint num_rows_after_padding, + uint num_cols_after_padding, + uint smem_elems_per_warp, + uint smem_rows_per_warp, + uint output_size, + uint num_row_steps, + uint num_col_steps) { + uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2); + int sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id; + if (sample_id >= batch_size) { + return; + } + int lane_id = threadIdx.x & (WARP_SIZE - 1); + + extern __shared__ half shmem_dynamic[]; + half *shmem = shmem_dynamic + (warp_id * smem_elems_per_warp); + + const half *sample_input = input + num_rows * num_cols * sample_id; + if (lane_id < (num_cols >> 2)) { + for (int i = 0; i < num_rows; ++i, sample_input += num_cols) { + ((float2 *)(shmem + i * SMEM_STRIDE))[lane_id] = ((float2 *)sample_input)[lane_id]; + } + } + + uint idx = lane_id + num_cols; + if (idx < num_cols_after_padding) { + for (int i = 0; i < num_rows; ++i) { + (shmem + i * SMEM_STRIDE)[idx] = __float2half(0); + } + } + + half4 zeros; + zeros.vals[0].x = __float2half(0); + zeros.vals[0].y = __float2half(0); + zeros.vals[1].x = __float2half(0); + zeros.vals[1].y = __float2half(0); + if (lane_id < (num_cols_after_padding >> 2)) { + for (int i = num_rows; i < num_rows_after_padding; i++) { + ((half4 *)(shmem + i * SMEM_STRIDE))[lane_id] = zeros; + } + } + __syncwarp(); + half *gmem_output = output + output_size * sample_id; + if (lane_id < (num_cols >> 2)) { + ((float2 *)gmem_output)[lane_id] = ((float2 *)shmem)[lane_id]; + } + + wmma::fragment acc[M_BLOCKS][M_BLOCKS]; + + for (int i = 0; i < M_BLOCKS; i++) { + for (int j = 0; j < M_BLOCKS; j++) { + wmma::fill_fragment(acc[i][j], 0); + } + } + + for (int k_step = 0; k_step < num_col_steps; k_step++) { + wmma::fragment a[M_BLOCKS]; + wmma::fragment b[M_BLOCKS]; + for (int j = 0; j < M_BLOCKS; j++) { + int base_row = (j < M_BLOCKS - 1) ? j * 16 : smem_rows_per_warp - 16; + const half *tile_ptr = shmem + (base_row * SMEM_STRIDE + k_step * 16); + wmma::load_matrix_sync(a[j], tile_ptr, SMEM_STRIDE); + wmma::load_matrix_sync(b[j], tile_ptr, SMEM_STRIDE); + } + for (int i = 0; i < M_BLOCKS; i++) { + for (int j = 0; j < M_BLOCKS; j++) { + wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]); + } + } + } + float *shmem_store = reinterpret_cast(shmem); + for (int i = 0; i < M_BLOCKS; i++) { + for (int j = 0; j < M_BLOCKS; j++) { + float *tile_ptr = shmem_store + (i * 16 * SMEM_STRIDE_ACC + j * 16); + wmma::store_matrix_sync(tile_ptr, acc[i][j], SMEM_STRIDE_ACC, wmma::mem_row_major); + } + } + + half *gmem_interact_output = gmem_output + num_cols; + int lastRowBlockOffset = M_BLOCKS * 16 - smem_rows_per_warp; + int srcLine = 0; + for (int i = 0; i < num_rows; ++i, ++srcLine) { + if (i == ((M_BLOCKS - 1) * 16)) { + srcLine += lastRowBlockOffset; + } + if (lane_id < i) { + uint offset = (i * (i - 1)) >> 1; + gmem_interact_output[offset + lane_id] = __float2half(shmem_store[srcLine * SMEM_STRIDE_ACC + lane_id]); + } + } + // Padding + if (lane_id == 0) { + gmem_output[output_size - 1] = __float2half(0); + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ + void dotBasedInteractBwdKernelNonAligned(const __half *__restrict input, + const __half *__restrict upstream_grad, + half __restrict *grad, + half __restrict *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + uint num_rows_after_padding, + uint num_cols_after_padding, + uint sample_size, + uint interaction_ugrad_size, + uint interaction_ugrad_size_with_padding, + uint interaction_ugrad_2D_size_elems, + uint interaction_ugrad_2D_stride, + uint input_size_elems, + uint input_stride, + uint num_row_steps, + uint num_col_steps, + uint row_tiles_per_step, + uint shared_mem_per_warp_size_byte) { + extern __shared__ half shared_mem[]; + uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2); + uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id; + if (sample_id >= batch_size) { + return; + } + uint lane_id = threadIdx.x & (WARP_SIZE - 1); + // ">> 1" to convert to half pointer + uint smem_warp_offset = warp_id * (shared_mem_per_warp_size_byte >> 1); + + half *smem_in = &shared_mem[smem_warp_offset]; + half *smem_temp = &shared_mem[smem_warp_offset + input_size_elems]; + float *smem_out = reinterpret_cast(smem_temp); + + // Global memory pointers for the current sample + // Input + uint gmem_input_sample_offset = sample_id * sample_size; + const half *gmem_input = &input[gmem_input_sample_offset]; + + // Interaction Gradient + const uint &gmem_grad_sample_offset = gmem_input_sample_offset; + half *gmem_grad = &grad[gmem_grad_sample_offset]; + + // Bottom MLP gradient + half *gmem_mlp_grad = &bottom_mlp_grad[sample_id * num_cols]; + + // Upstream gradient vector + uint gmem_ugrad_sample_offset = sample_id * (num_cols + interaction_ugrad_size_with_padding); + const half *gmem_ugrad = &upstream_grad[gmem_ugrad_sample_offset]; + + // Upstream gradient vector for interactions + const half *gmem_ugrad_interactions = &gmem_ugrad[num_cols]; + + // upstream grad -> shared memory (place in input section temporarily) +#pragma unroll + for (uint idx = lane_id; idx < interaction_ugrad_size; idx += WARP_SIZE) { + smem_in[idx] = gmem_ugrad_interactions[idx]; + } + __syncwarp(); + // Form the 2D ugrad matrix. + if (lane_id < num_rows_after_padding) { + uint ugrad_flat_index = ((lane_id * (lane_id - 1)) >> 1); + uint ugrad_offset_1 = lane_id * interaction_ugrad_2D_stride; + for (uint row = 0; row < num_rows; row++) { + half ugrad_val = __float2half(0.0f); + if (row < lane_id && lane_id < num_rows) { + ugrad_val = smem_in[ugrad_flat_index + row]; + smem_temp[ugrad_offset_1 + row] = ugrad_val; + } + if (row <= lane_id && lane_id < num_rows_after_padding) { + smem_temp[row * interaction_ugrad_2D_stride + lane_id] = ugrad_val; + } + } + for (uint row = num_rows; row < num_rows_after_padding; row++) { + smem_temp[row * interaction_ugrad_2D_stride + lane_id] = __float2half(0.0f); + } + } + __syncwarp(); + + // Input -> Shared Memory + + for (uint row = 0; row < num_rows; row++) { + half *smem_row_ptr = &smem_in[row * input_stride]; + const half *gmem_row_ptr = &gmem_input[row * num_cols]; + for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) { + smem_row_ptr[idx] = gmem_row_ptr[idx]; + } + uint idx = lane_id + num_cols; + if (idx < num_cols_after_padding) { + smem_row_ptr[idx] = __float2half(0); + } + } + +#pragma unroll 2 + for (uint row = num_rows; row < num_rows_after_padding; row++) { + half *smem_row_ptr = &smem_in[row * input_stride]; + for (uint idx = lane_id; idx < num_cols_after_padding; idx += WARP_SIZE) { + smem_row_ptr[idx] = __float2half(0); + } + } + __syncwarp(); + + wmma::fragment a[ROW_TILES_PER_STEP] + [ROW_TILES_PER_STEP]; + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + for (uint j = 0; j < ROW_TILES_PER_STEP; j++) { + const half *tile_ptr = smem_temp + ((i * interaction_ugrad_2D_stride + j) << TILE_DIM_LOG_2); + wmma::load_matrix_sync(a[i][j], tile_ptr, interaction_ugrad_2D_stride); + } + } + + wmma::fragment acc[ROW_TILES_PER_STEP]; + wmma::fragment b[ROW_TILES_PER_STEP]; + for (int col_step = 0; col_step < num_col_steps; col_step++) { + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + const half *tile_ptr = smem_in + ((i * input_stride + col_step) << TILE_DIM_LOG_2); + wmma::fill_fragment(acc[i], 0); + wmma::load_matrix_sync(b[i], tile_ptr, input_stride); + } + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + for (uint j = 0; j < ROW_TILES_PER_STEP; j++) { + wmma::mma_sync(acc[i], a[i][j], b[j], acc[i]); + } + } + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + float *tile_ptr = smem_out + i * TILE_DIM * TILE_DIM; + wmma::store_matrix_sync(tile_ptr, acc[i], TILE_DIM, wmma::mem_row_major); + } + __syncwarp(); + uint gmem_grad_col = (col_step << TILE_DIM_LOG_2) + lane_id; + if (gmem_grad_col < num_cols) { + for (uint i = 0; i < num_rows; i++) { + gmem_grad[i * num_cols + gmem_grad_col] = __float2half(smem_out[(i << TILE_DIM_LOG_2) + lane_id]); + } + } + } + + for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) { + gmem_mlp_grad[idx] = gmem_ugrad[idx]; + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractBwdKernel(const __half *__restrict input, + const __half *__restrict upstream_grad, + half __restrict *grad, + half __restrict *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + uint num_rows_after_padding, + uint num_cols_after_padding, + uint sample_size, + uint interaction_ugrad_size, + uint interaction_ugrad_size_with_padding, + uint interaction_ugrad_2D_size_elems, + uint interaction_ugrad_2D_stride, + uint input_size_elems, + uint input_stride, + uint num_row_steps, + uint num_col_steps, + uint row_tiles_per_step, + uint shared_mem_per_warp_size_byte) { + extern __shared__ half shared_mem[]; + uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2); + uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id; + if (sample_id >= batch_size) { + return; + } + uint lane_id = threadIdx.x & (WARP_SIZE - 1); + // ">> 1" to convert to half pointer + uint smem_warp_offset = warp_id * (shared_mem_per_warp_size_byte >> 1); + + half *smem_in = &shared_mem[smem_warp_offset]; + half *smem_temp = &shared_mem[smem_warp_offset + input_size_elems]; + float *smem_out = reinterpret_cast(smem_temp); + + // Global memory pointers for the current sample + // Input + uint gmem_input_sample_offset = sample_id * sample_size; + const half *gmem_input = &input[gmem_input_sample_offset]; + + // Interaction Gradient + const uint &gmem_grad_sample_offset = gmem_input_sample_offset; + half *gmem_grad = &grad[gmem_grad_sample_offset]; + + // Bottom MLP gradient + half *gmem_mlp_grad = &bottom_mlp_grad[sample_id * num_cols]; + + // Upstream gradient vector + uint gmem_ugrad_sample_offset = sample_id * (num_cols + interaction_ugrad_size_with_padding); + const half *gmem_ugrad = &upstream_grad[gmem_ugrad_sample_offset]; + + // Upstream gradient vector for interactions + const half *gmem_ugrad_interactions = &gmem_ugrad[num_cols]; + + // upstream grad -> shared memory (place in input section temporarily) +#pragma unroll + for (uint idx = lane_id; idx < (interaction_ugrad_size >> 3); idx += WARP_SIZE) { + ((float4 *)smem_in)[idx] = ((float4 *)gmem_ugrad_interactions)[idx]; + } + uint offset = (interaction_ugrad_size >> 3) << 3; + for (uint idx = lane_id + offset; idx < interaction_ugrad_size; idx += WARP_SIZE) { + smem_in[idx] = gmem_ugrad_interactions[idx]; + } + __syncwarp(); + // Form the 2D ugrad matrix. + if (lane_id < num_rows_after_padding) { + uint ugrad_flat_index = ((lane_id * (lane_id - 1)) >> 1); + uint ugrad_offset_1 = lane_id * interaction_ugrad_2D_stride; + for (uint row = 0; row < num_rows; row++) { + half ugrad_val = __float2half(0.0f); + if (row < lane_id && lane_id < num_rows) { + ugrad_val = smem_in[ugrad_flat_index + row]; + smem_temp[ugrad_offset_1 + row] = ugrad_val; + } + if (row <= lane_id && lane_id < num_rows_after_padding) { + smem_temp[row * interaction_ugrad_2D_stride + lane_id] = ugrad_val; + } + } + for (uint row = num_rows; row < num_rows_after_padding; row++) { + smem_temp[row * interaction_ugrad_2D_stride + lane_id] = __float2half(0.0f); + } + } + __syncwarp(); + + // Input -> Shared Memory + + if (lane_id < (num_cols >> 2)) { + for (uint row = 0; row < num_rows; row++) { + half *smem_row_ptr = &smem_in[row * input_stride]; + const half *gmem_row_ptr = &gmem_input[row * num_cols]; + ((float2 *)smem_row_ptr)[lane_id] = ((float2 *)gmem_row_ptr)[lane_id]; + } + } + + uint idx = lane_id + num_cols; + if (idx < num_cols_after_padding) { + for (uint row = 0; row < num_rows; row++) { + half *smem_row_ptr = &smem_in[row * input_stride]; + smem_row_ptr[idx] = __float2half(0); + } + } + + half4 zeros; + zeros.vals[0].x = __float2half(0); + zeros.vals[0].y = __float2half(0); + zeros.vals[1].x = __float2half(0); + zeros.vals[1].y = __float2half(0); + if (lane_id < (num_cols_after_padding >> 2)) { +#pragma unroll 2 + for (uint row = num_rows; row < num_rows_after_padding; row++) { + half *smem_row_ptr = &smem_in[row * input_stride]; + ((half4 *)smem_row_ptr)[lane_id] = zeros; + } + } + __syncwarp(); + + wmma::fragment a[ROW_TILES_PER_STEP] + [ROW_TILES_PER_STEP]; + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + for (uint j = 0; j < ROW_TILES_PER_STEP; j++) { + const half *tile_ptr = smem_temp + ((i * interaction_ugrad_2D_stride + j) << TILE_DIM_LOG_2); + wmma::load_matrix_sync(a[i][j], tile_ptr, interaction_ugrad_2D_stride); + } + } + + wmma::fragment acc[ROW_TILES_PER_STEP]; + wmma::fragment b[ROW_TILES_PER_STEP]; + for (int col_step = 0; col_step < num_col_steps; col_step++) { + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + const half *tile_ptr = smem_in + ((i * input_stride + col_step) << TILE_DIM_LOG_2); + wmma::fill_fragment(acc[i], 0); + wmma::load_matrix_sync(b[i], tile_ptr, input_stride); + } + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + for (uint j = 0; j < ROW_TILES_PER_STEP; j++) { + wmma::mma_sync(acc[i], a[i][j], b[j], acc[i]); + } + } + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + float *tile_ptr = smem_out + i * TILE_DIM * TILE_DIM; + wmma::store_matrix_sync(tile_ptr, acc[i], TILE_DIM, wmma::mem_row_major); + } + __syncwarp(); + uint gmem_grad_col = (col_step << TILE_DIM_LOG_2) + lane_id; + if (gmem_grad_col < num_cols) { + for (uint i = 0; i < num_rows; i++) { + gmem_grad[i * num_cols + gmem_grad_col] = __float2half(smem_out[(i << TILE_DIM_LOG_2) + lane_id]); + } + } + } + if (lane_id < (num_cols >> 2)) { + ((float2 *)gmem_mlp_grad)[lane_id] = ((float2 *)gmem_ugrad)[lane_id]; + } +} diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere_tf32.cu.inl b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere_tf32.cu.inl new file mode 100644 index 00000000..8c80fe51 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/ampere/dot_based_interact_ampere_tf32.cu.inl @@ -0,0 +1,346 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../dot_based_interact_shared_utils.cu.h" + +using namespace nvcuda; + +using namespace nvcuda; + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractTF32FwdKernel(const float *__restrict input, + float *__restrict output, + uint batch_size, + uint num_rows, + uint num_cols, + uint num_rows_after_padding, + uint num_cols_after_padding, + uint smem_elems_per_warp, + uint output_size, + uint num_row_steps, + uint num_col_steps, + uint smem_stride, + uint smem_stride_acc) { + // The only support sizes for TF32. + const uint kWmmaM = 16; + const uint kWmmaN = 16; + const uint kWmmaK = 8; + + uint warp_id = threadIdx.x >> WARP_SIZE_LOG_2; + uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id; + if (sample_id >= batch_size) { + return; + } + int lane_id = threadIdx.x & (WARP_SIZE - 1); + + extern __shared__ float shmem_dynamic_float[]; + float *shmem = shmem_dynamic_float + (warp_id * smem_elems_per_warp); + + const float *gmem_input = input + num_rows * num_cols * sample_id; + if (lane_id < (num_cols >> 2)) { + for (int i = 0; i < num_rows; ++i, gmem_input += num_cols) { + float4 tmp = ((float4 *)gmem_input)[lane_id]; + tmp.x = wmma::__float_to_tf32(tmp.x); + tmp.y = wmma::__float_to_tf32(tmp.y); + tmp.z = wmma::__float_to_tf32(tmp.z); + tmp.w = wmma::__float_to_tf32(tmp.w); + ((float4 *)(shmem + i * smem_stride))[lane_id] = tmp; + } + } + + float zero = wmma::__float_to_tf32(0.0f); + float4 zero4; + zero4.x = zero; + zero4.y = zero; + zero4.z = zero; + zero4.w = zero; + + uint idx = lane_id + num_cols; + if (idx < num_cols_after_padding) { + for (uint i = 0; i < num_rows; ++i) { + (shmem + i * smem_stride)[idx] = zero; + } + } + + if (lane_id < (num_cols_after_padding >> 2)) { + for (int i = num_rows; i < num_rows_after_padding; i++) { + ((float4 *)(shmem + i * smem_stride))[lane_id] = zero4; + } + } + __syncwarp(); + // TODO: MTMD - Copy directly without using shared memory + float *gmem_output = output + output_size * sample_id; + if (lane_id < (num_cols >> 2)) { + ((float4 *)gmem_output)[lane_id] = ((float4 *)shmem)[lane_id]; + } + + wmma::fragment acc[ROW_TILES_PER_STEP][ROW_TILES_PER_STEP]; + + for (int i = 0; i < ROW_TILES_PER_STEP; i++) { + for (int j = 0; j < ROW_TILES_PER_STEP; j++) { + wmma::fill_fragment(acc[i][j], zero); + } + } + + // TODO: MTMD - Loop promotion + for (int k_step = 0; k_step < num_col_steps; k_step++) { + wmma::fragment + a[ROW_TILES_PER_STEP]; + wmma::fragment + b[ROW_TILES_PER_STEP]; + for (int j = 0; j < ROW_TILES_PER_STEP; j++) { + int base_row = (j < ROW_TILES_PER_STEP - 1) ? j * 16 : num_rows_after_padding - 16; + const float *tile_ptr = shmem + (base_row * smem_stride + k_step * kWmmaK); + wmma::load_matrix_sync(a[j], tile_ptr, smem_stride); + wmma::load_matrix_sync(b[j], tile_ptr, smem_stride); + } + for (int i = 0; i < ROW_TILES_PER_STEP; i++) { + for (int j = 0; j < ROW_TILES_PER_STEP; j++) { + wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]); + } + } + } + + for (int i = 0; i < ROW_TILES_PER_STEP; i++) { + for (int j = 0; j < ROW_TILES_PER_STEP; j++) { + float *tile_ptr = shmem + (i * kWmmaM * smem_stride_acc + j * kWmmaN); + wmma::store_matrix_sync(tile_ptr, acc[i][j], smem_stride_acc, wmma::mem_row_major); + } + } + + float *gmem_interact_output = gmem_output + num_cols; + int lastRowBlockOffset = ROW_TILES_PER_STEP * 16 - num_rows_after_padding; + int src_line = 0; + for (int i = 0; i < num_rows; ++i, ++src_line) { + if (i == ((ROW_TILES_PER_STEP - 1) * 16)) { + src_line += lastRowBlockOffset; + } + if (lane_id < i) { + uint offset = (i * (i - 1)) >> 1; + gmem_interact_output[offset + lane_id] = shmem[src_line * smem_stride_acc + lane_id]; + } + } + // Padding + if (lane_id == 0) { + gmem_output[output_size - 1] = 0; + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ + void dotBasedInteractTF32BwdKernel(const float *__restrict input, + const float *__restrict upstream_grad, + float *__restrict grad, + float *__restrict bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + uint num_rows_after_padding, + uint num_cols_after_padding, + uint sample_size, + uint interaction_ugrad_size, + uint interaction_ugrad_size_with_padding, + uint interaction_ugrad_2D_size_elems, + uint interaction_ugrad_2D_stride, + uint input_size_elems, + uint input_stride, + uint shared_mem_per_warp_size_elems, + uint num_k_steps, + uint num_n_steps) { + // The only support sizes for TF32. + const uint kWmmaM = 16; + const uint kWmmaN = 16; + const uint kWmmaK = 8; + + extern __shared__ float shared_mem_float[]; + uint warp_id = threadIdx.x >> WARP_SIZE_LOG_2; + uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id; + if (sample_id >= batch_size) { + return; + } + uint lane_id = threadIdx.x & (WARP_SIZE - 1); + uint smem_warp_offset = warp_id * shared_mem_per_warp_size_elems; + + float *smem_in = &shared_mem_float[smem_warp_offset]; + float *smem_ugrad = &shared_mem_float[smem_warp_offset + input_size_elems]; + float *smem_out = &shared_mem_float[smem_warp_offset + input_size_elems + interaction_ugrad_2D_size_elems]; + + // Global memory pointers for the current sample + // Input + uint gmem_input_sample_offset = sample_id * sample_size; + const float *gmem_input = &input[gmem_input_sample_offset]; + + // Interaction Gradient + const uint &gmem_grad_sample_offset = gmem_input_sample_offset; + float *gmem_grad = &grad[gmem_grad_sample_offset]; + + // Bottom MLP gradient + float *gmem_mlp_grad = &bottom_mlp_grad[sample_id * num_cols]; + + // Upstream gradient vector + uint gmem_ugrad_sample_offset = sample_id * (num_cols + interaction_ugrad_size_with_padding); + const float *gmem_ugrad = &upstream_grad[gmem_ugrad_sample_offset]; + + // Upstream gradient vector for interactions + const float *gmem_ugrad_interactions = &gmem_ugrad[num_cols]; + + // upstream grad -> shared memory (place in input section temporarily) +#pragma unroll + for (uint idx = lane_id; idx < (interaction_ugrad_size >> 2); idx += WARP_SIZE) { + float4 tmp = ((float4 *)gmem_ugrad_interactions)[idx]; + tmp.x = wmma::__float_to_tf32(tmp.x); + tmp.y = wmma::__float_to_tf32(tmp.y); + tmp.z = wmma::__float_to_tf32(tmp.z); + tmp.w = wmma::__float_to_tf32(tmp.w); + ((float4 *)smem_in)[idx] = tmp; + } + uint offset = (interaction_ugrad_size >> 2) << 2; + for (uint idx = lane_id + offset; idx < interaction_ugrad_size; idx += WARP_SIZE) { + smem_in[idx] = wmma::__float_to_tf32(gmem_ugrad_interactions[idx]); + } + __syncwarp(); + + float zero = wmma::__float_to_tf32(0.0f); + float4 zero4; + zero4.x = zero; + zero4.y = zero; + zero4.z = zero; + zero4.w = zero; + // Form the 2D ugrad matrix. + if (lane_id < num_rows_after_padding) { + uint ugrad_flat_index = ((lane_id * (lane_id - 1)) >> 1); + uint ugrad_offset_1 = lane_id * interaction_ugrad_2D_stride; + for (uint row = 0; row < num_rows; row++) { + float ugrad_val = zero; + if (row < lane_id && lane_id < num_rows) { + ugrad_val = smem_in[ugrad_flat_index + row]; + smem_ugrad[ugrad_offset_1 + row] = ugrad_val; + } + if (row <= lane_id && lane_id < num_rows_after_padding) { + smem_ugrad[row * interaction_ugrad_2D_stride + lane_id] = ugrad_val; + } + } + for (uint row = num_rows; row < num_rows_after_padding; row++) { + smem_ugrad[row * interaction_ugrad_2D_stride + lane_id] = zero; + } + } + __syncwarp(); + + // Input -> Shared Memory + + if (lane_id < (num_cols >> 2)) { + for (uint row = 0; row < num_rows; row++) { + float *smem_row_ptr = &smem_in[row * input_stride]; + const float *gmem_row_ptr = &gmem_input[row * num_cols]; + float4 tmp = ((float4 *)gmem_row_ptr)[lane_id]; + tmp.x = wmma::__float_to_tf32(tmp.x); + tmp.y = wmma::__float_to_tf32(tmp.y); + tmp.z = wmma::__float_to_tf32(tmp.z); + tmp.w = wmma::__float_to_tf32(tmp.w); + ((float4 *)smem_row_ptr)[lane_id] = tmp; + } + } + + uint idx = lane_id + num_cols; + if (idx < num_cols_after_padding) { + for (uint row = 0; row < num_rows; row++) { + float *smem_row_ptr = &smem_in[row * input_stride]; + smem_row_ptr[idx] = zero; + } + } + + if (lane_id < (num_cols_after_padding >> 2)) { +#pragma unroll 2 + for (uint row = num_rows; row < num_rows_after_padding; row++) { + float *smem_row_ptr = &smem_in[row * input_stride]; + ((float4 *)smem_row_ptr)[lane_id] = zero4; + } + } + __syncwarp(); + + wmma::fragment a[FRAG_A_ROWS]; + wmma::fragment b[FRAG_B_COLS]; + wmma::fragment acc[FRAG_A_ROWS][FRAG_B_COLS]; + for (uint n = 0; n < num_n_steps; n++) { + for (uint i = 0; i < FRAG_A_ROWS; i++) { + for (uint j = 0; j < FRAG_B_COLS; j++) { + wmma::fill_fragment(acc[i][j], zero); + } + } + for (uint k = 0; k < num_k_steps; k++) { + for (uint i = 0; i < FRAG_A_ROWS; i++) { + const float *mat_a_tile_ptr = + smem_ugrad + (i << TILE_LENGTH_LOG_2) * interaction_ugrad_2D_stride + (k << TILE_WIDTH_LOG_2); + wmma::load_matrix_sync(a[i], mat_a_tile_ptr, interaction_ugrad_2D_stride); + } + for (uint j = 0; j < FRAG_B_COLS; j++) { + const float *mat_b_tile_ptr = + smem_in + (k << TILE_WIDTH_LOG_2) * input_stride + ((2 * n + j) << TILE_LENGTH_LOG_2); + wmma::load_matrix_sync(b[j], mat_b_tile_ptr, input_stride); + } + for (uint i = 0; i < FRAG_A_ROWS; i++) { + for (uint j = 0; j < FRAG_B_COLS; j++) { + wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]); + } + } + } + // __syncwarp(); ? + uint out_stride = FRAG_B_COLS << TILE_LENGTH_LOG_2; + for (uint i = 0; i < FRAG_A_ROWS; i++) { + for (uint j = 0; j < FRAG_B_COLS; j++) { + float *out_tile_ptr = smem_out + (i << TILE_LENGTH_LOG_2) * out_stride + (j << TILE_LENGTH_LOG_2); + wmma::store_matrix_sync(out_tile_ptr, acc[i][j], out_stride, wmma::mem_row_major); + } + } + uint gmem_grad_col = n * (FRAG_B_COLS << TILE_LENGTH_LOG_2) + lane_id; + for (uint i = 0; i < num_rows; i++) { + gmem_grad[i * num_cols + gmem_grad_col] = smem_out[i * out_stride + lane_id]; + } + } + + if (lane_id < (num_cols >> 2)) { + ((float4 *)gmem_mlp_grad)[lane_id] = ((float4 *)gmem_ugrad)[lane_id]; + } +} diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/dot_based_interact_grad_kernels.cc b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/dot_based_interact_grad_kernels.cc new file mode 100644 index 00000000..b598a62e --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/dot_based_interact_grad_kernels.cc @@ -0,0 +1,155 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow/core/framework/op_kernel.h" + +#include "volta/dot_based_interact_volta.h" +#include "ampere/dot_based_interact_ampere.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; +typedef Eigen::half half; + +namespace functor { + +template +struct DotBasedInteractGradFunctor { + void operator()(const Device& d, const T* input, const T* upstream_grad, T* grad, + T* bottom_mlp_grad, int64 batch_size, int64 num_rows, int64 num_cols); +}; + +template <> +struct DotBasedInteractGradFunctor { + void operator()(const GPUDevice& d, const float* input, const float* upstream_grad, float* grad, + float* bottom_mlp_grad, int64 batch_size, int64 num_rows, int64 num_cols) { + int major = d.majorDeviceVersion(); + if (major >= 8) { + dotBasedInteractAmpereTF32Bwd(input, + upstream_grad, + grad, + bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + d.stream()); + } else if (major == 7) { + dotBasedInteractVoltaF32Bwd(input, + upstream_grad, + grad, + bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + d.stream()); + } + } +}; + +template <> +struct DotBasedInteractGradFunctor { + void operator()(const GPUDevice& d, const half* input, const half* upstream_grad, half* grad, + half* bottom_mlp_grad, int64 batch_size, int64 num_rows, int64 num_cols) { + int major = d.majorDeviceVersion(); + if (major >= 8) { + dotBasedInteractAmpereF16Bwd(input, + upstream_grad, + grad, + bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + d.stream()); + } else if (major == 7) { + dotBasedInteractVoltaF16Bwd(input, + upstream_grad, + grad, + bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + d.stream()); + } + } +}; + + +// OpKernel definition. +// template parameter is the datatype of the tensors. +template +class DotBasedInteractGradOp : public OpKernel { + public: + explicit DotBasedInteractGradOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor = context->input(0); + + // Grab the bottom_mlp_output tensor + const Tensor& upstream_grad_tensor = context->input(1); + + // Calculate the output tensor shape + TensorShape input_shape = input_tensor.shape(); + int64 batch_size = input_shape.dim_size(0); + int64 num_rows = input_shape.dim_size(1); + int64 num_cols = input_shape.dim_size(2); + TensorShape bottom_mlp_grad_shape({batch_size, num_cols}); + + // Create the grad output tensor + Tensor* grad_tensor = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, + &grad_tensor)); + + // Create the bottom mlp grad output tensor + Tensor* bottom_mlp_grad_tensor = NULL; + OP_REQUIRES_OK(context, context->allocate_output(1, bottom_mlp_grad_shape, + &bottom_mlp_grad_tensor)); + + // Do the computation. + OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max, + errors::InvalidArgument("Too many elements in tensor")); + + // GPU Architecture + GPUDevice device = ((GPUDevice) context->eigen_device()); + OP_REQUIRES(context, device.majorDeviceVersion() >= 7, + errors::InvalidArgument("GPU not supported (need Volta or higher)")); + + DotBasedInteractGradFunctor()( + device, + input_tensor.flat().data(), + upstream_grad_tensor.flat().data(), + grad_tensor->flat().data(), + bottom_mlp_grad_tensor->flat().data(), + batch_size, + num_rows, + num_cols); + } +}; + +// Register the GPU kernels. +#ifdef GOOGLE_CUDA +#define REGISTER_GPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("DotBasedInteractGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + DotBasedInteractGradOp); +REGISTER_GPU(float); +REGISTER_GPU(half); +#endif // GOOGLE_CUDA +} +} // namespace tensorflow diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/dot_based_interact_kernels.cc b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/dot_based_interact_kernels.cc new file mode 100644 index 00000000..416088cf --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/dot_based_interact_kernels.cc @@ -0,0 +1,147 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow/core/framework/op_kernel.h" + +#include "volta/dot_based_interact_volta.h" +#include "ampere/dot_based_interact_ampere.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; +typedef Eigen::half half; + +namespace functor { + +template +struct DotBasedInteractFunctor { + void operator()(const Device& d, const T* input, const T* bottom_mlp_output, + T* output, int64 batch_size, int64 num_rows, int64 num_cols); +}; + +template <> +struct DotBasedInteractFunctor { + void operator()(const GPUDevice& d, const float* input, const float* bottom_mlp_output, + float* output, int64 batch_size, int64 num_rows, int64 num_cols) { + int major = d.majorDeviceVersion(); + if (major >= 8) { + dotBasedInteractAmpereTF32Fwd(input, + bottom_mlp_output, + output, + batch_size, + num_rows, + num_cols, + d.stream()); + } else if (major == 7) { + dotBasedInteractVoltaF32Fwd(input, + bottom_mlp_output, + output, + batch_size, + num_rows, + num_cols, + d.stream()); + } + } +}; + +template <> +struct DotBasedInteractFunctor { + void operator()(const GPUDevice& d, const half* input, const half* bottom_mlp_output, + half* output, int64 batch_size, int64 num_rows, int64 num_cols) { + int major = d.majorDeviceVersion(); + if (major >= 8) { + dotBasedInteractAmpereF16Fwd(input, + bottom_mlp_output, + output, + batch_size, + num_rows, + num_cols, + d.stream()); + } else if (major == 7) { + dotBasedInteractVoltaF16Fwd(input, + bottom_mlp_output, + output, + batch_size, + num_rows, + num_cols, + d.stream()); + } + } +}; + + +// OpKernel definition. +// template parameter is the datatype of the tensors. +template +class DotBasedInteractOp : public OpKernel { + public: + explicit DotBasedInteractOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor = context->input(0); + + // Grab the bottom_mlp_output tensor + const Tensor& bottom_mlp_output_tensor = context->input(1); + + // Calculate the output tensor shape + TensorShape input_shape = input_tensor.shape(); + const int64 pad = 1; + int64 batch_size = input_shape.dim_size(0); + int64 num_rows = input_shape.dim_size(1); + int64 num_cols = input_shape.dim_size(2); + int64 output_size = ((num_rows * (num_rows - 1)) >> 1) + num_cols + pad; + TensorShape output_shape({batch_size, output_size}); + + // Create an output tensor + Tensor* output_tensor = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, + &output_tensor)); + + // Do the computation. + OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max, + errors::InvalidArgument("Too many elements in tensor")); + + // GPU Architecture + GPUDevice device = ((GPUDevice) context->eigen_device()); + OP_REQUIRES(context, device.majorDeviceVersion() >= 7, + errors::InvalidArgument("GPU not supported (need Volta or higher)")); + + DotBasedInteractFunctor()( + device, + input_tensor.flat().data(), + bottom_mlp_output_tensor.flat().data(), + output_tensor->flat().data(), + batch_size, + num_rows, + num_cols); + } +}; + +// Register the GPU kernels. +#ifdef GOOGLE_CUDA +#define REGISTER_GPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("DotBasedInteract").Device(DEVICE_GPU).TypeConstraint("T"), \ + DotBasedInteractOp); +REGISTER_GPU(float); +REGISTER_GPU(half); +#endif // GOOGLE_CUDA +} +} // namespace tensorflow diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/dot_based_interact_shared_utils.cu.h b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/dot_based_interact_shared_utils.cu.h new file mode 100644 index 00000000..6b32836e --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/dot_based_interact_shared_utils.cu.h @@ -0,0 +1,41 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KERNEL_DOT_BASED_INTERACT_SHARED_UTILS_H_ +#define KERNEL_DOT_BASED_INTERACT_SHARED_UTILS_H_ + +#include + +#define CHK_CUDA(expression) \ +{ \ + cudaError_t status = (expression); \ + if (status != cudaSuccess) { \ + std::cerr << "Error in file: " << __FILE__ << ", on line: " << __LINE__ << ": " << cudaGetErrorString(status) \ + << std::endl; \ + std::exit(EXIT_FAILURE); \ + } \ +} + +template +struct Log2 { + static constexpr uint value = 1 + Log2::value; +}; + +template <> +struct Log2<1> { + static constexpr uint value = 0; +}; + +#endif //KERNEL_DOT_BASED_INTERACT_SHARED_UTILS_H_ diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/volta/dot_based_interact_volta.cu.cc b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/volta/dot_based_interact_volta.cu.cc new file mode 100644 index 00000000..04fd2403 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/volta/dot_based_interact_volta.cu.cc @@ -0,0 +1,337 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "dot_based_interact_volta.h" +#include "dot_based_interact_volta.cu.inl" + +void dotBasedInteractVoltaF16Fwd(const void *input, + const void *bottom_mlp_output, + void *output, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream) { + const uint kWarpSize = 32; + const uint kWarpSizeLog2 = Log2::value; + const uint kTileDim = 16; + const uint kTileDimLog2 = Log2::value; + const uint warps_per_threadblock = 4; + const uint threadblock_size = warps_per_threadblock * 32; + const uint kPaddingSize = 1; + const uint kRowTilesPerStep = 2; + const uint kColTilesPerStep = 1; + + // num tiles + uint num_row_tiles = (num_rows + kTileDim - 1) >> kTileDimLog2; + uint num_col_tiles = (num_cols + kTileDim - 1) >> kTileDimLog2; + + // number of rows and columns after padding + uint num_rows_after_padding = kTileDim << 1; + uint num_cols_after_padding = num_col_tiles << kTileDimLog2; + + uint num_row_steps = num_row_tiles / kRowTilesPerStep; + uint num_col_steps = num_col_tiles / kColTilesPerStep; + + const uint K_BLOCKS = 8; + const uint M_BLOCKS = 2; + const uint SKEW_HALF = ((K_BLOCKS % 2) == 0) ? 8 : 0; + const uint SMEM_STRIDE = (K_BLOCKS * 16 + SKEW_HALF); + // multiple of 2 to guarantee 256-bit alignment for start of the row, at least 16 to safeload a tile + const uint smem_rows_per_warp = M_BLOCKS << 4; + const uint smem_elems_per_warp_mat = smem_rows_per_warp * SMEM_STRIDE; + const uint SKEW_HALF_ACC = ((M_BLOCKS % 2) == 0) ? 8 : 0; + const uint SMEM_STRIDE_ACC = (M_BLOCKS * 16 + SKEW_HALF_ACC); + const uint smem_elems_per_warp_acc = M_BLOCKS * 16 * SMEM_STRIDE_ACC * 2; // output in FP32 + const uint smem_elems_per_warp = + (smem_elems_per_warp_mat > smem_elems_per_warp_acc) ? smem_elems_per_warp_mat : smem_elems_per_warp_acc; + uint output_size = num_cols + ((num_rows * (num_rows - 1)) >> 1) + kPaddingSize; + + bool float4_predicate = !((num_cols & 7) || (output_size & 7)); + + if (float4_predicate) { + dotBasedInteractFwdKernel + <<<(batch_size + warps_per_threadblock - 1) / warps_per_threadblock, + threadblock_size, + warps_per_threadblock * smem_elems_per_warp * sizeof(__half), stream>>>((const __half *)input, + (half *)output, + batch_size, + num_rows, + num_cols, + num_rows_after_padding, + num_cols_after_padding, + smem_elems_per_warp, + smem_rows_per_warp, + output_size, + num_row_steps, + num_col_steps); + } else { + dotBasedInteractFwdKernelNonAligned + <<<(batch_size + warps_per_threadblock - 1) / warps_per_threadblock, + threadblock_size, + warps_per_threadblock * smem_elems_per_warp * sizeof(__half), stream>>>((const __half *)input, + (half *)output, + batch_size, + num_rows, + num_cols, + num_rows_after_padding, + num_cols_after_padding, + smem_elems_per_warp, + smem_rows_per_warp, + output_size, + num_row_steps, + num_col_steps); + } +} + +void dotBasedInteractVoltaF16Bwd(const void *input, + const void *upstream_grad, + void *grad, + void *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream) { + const uint kWarpSize = 32; + const uint kWarpSizeLog2 = Log2::value; + const uint kTileDim = 16; + const uint kTileDimLog2 = Log2::value; + const uint mem_skew_size = 8; + const uint kPaddingSize = 1; + const uint kWarpsPerBlock = 4; + const uint kWarpsPerBlockLog2 = Log2::value; + const uint kNumThreads = kWarpsPerBlock * kWarpSize; + const uint kRowTilesPerStep = 2; + const uint kColTilesPerStep = 1; + + uint row_tiles_per_step = num_rows > kTileDim ? kRowTilesPerStep : 1; + + // num tiles + uint num_row_tiles = (num_rows + kTileDim - 1) >> kTileDimLog2; + uint num_col_tiles = (num_cols + kTileDim - 1) >> kTileDimLog2; + + // number of rows and columns after padding + uint num_rows_after_padding = kTileDim << 1; + uint num_cols_after_padding = num_col_tiles << kTileDimLog2; + + // 2D ugrad size and stride + uint interaction_ugrad_2D_stride = num_rows_after_padding + mem_skew_size; + uint interaction_ugrad_2D_size_elems = num_rows_after_padding * interaction_ugrad_2D_stride; + uint interaction_ugrad_2D_size_bytes = interaction_ugrad_2D_size_elems * sizeof(half); + + // 1D ugrad size + uint interaction_ugrad_size = num_rows * (num_rows - 1) >> 1; + uint interaction_ugrad_size_with_padding = interaction_ugrad_size + kPaddingSize; + + // in_out place size and stride + uint input_stride = num_cols_after_padding + mem_skew_size; + uint input_size_elems = num_rows_after_padding * input_stride; + uint input_size_bytes = input_size_elems * sizeof(half); + + // sample size + uint sample_size = num_rows * num_cols; + + // output size + uint output_size_elems = kTileDim * kTileDim * kRowTilesPerStep * kColTilesPerStep; + uint output_size_bytes = output_size_elems * sizeof(float); + + // staging area size + uint staging_area_size_bytes = + output_size_bytes > interaction_ugrad_2D_size_bytes ? output_size_bytes : interaction_ugrad_2D_size_bytes; + + // Shared memory size + uint shared_mem_per_warp_size_byte = input_size_bytes + staging_area_size_bytes; + uint shared_mem_size_bytes = kWarpsPerBlock * shared_mem_per_warp_size_byte; + + uint num_blocks = (batch_size + kWarpsPerBlock - 1) >> kWarpsPerBlockLog2; + uint num_row_steps = num_row_tiles / row_tiles_per_step; + uint num_col_steps = num_col_tiles / kColTilesPerStep; + + bool float4_predicate = !((interaction_ugrad_size_with_padding & 7) || (num_cols & 7)); + if (float4_predicate) { + dotBasedInteractBwdKernel + <<>>((const half *)input, + (const half *)upstream_grad, + (half *)grad, + (half *)bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + num_rows_after_padding, + num_cols_after_padding, + sample_size, + interaction_ugrad_size, + interaction_ugrad_size_with_padding, + interaction_ugrad_2D_size_elems, + interaction_ugrad_2D_stride, + input_size_elems, + input_stride, + num_row_steps, + num_col_steps, + row_tiles_per_step, + shared_mem_per_warp_size_byte); + } else { + dotBasedInteractBwdKernelNonAligned + <<>>((const half *)input, + (const half *)upstream_grad, + (half *)grad, + (half *)bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + num_rows_after_padding, + num_cols_after_padding, + sample_size, + interaction_ugrad_size, + interaction_ugrad_size_with_padding, + interaction_ugrad_2D_size_elems, + interaction_ugrad_2D_stride, + input_size_elems, + input_stride, + num_row_steps, + num_col_steps, + row_tiles_per_step, + shared_mem_per_warp_size_byte); + } +} + +void dotBasedInteractVoltaF32Fwd(const void *input, + const void *bottom_mlp_output, + void *output, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream) { + const uint kPaddingSize = 1; + const uint kNumThreads = 128; + uint num_blocks = batch_size; + + // Output + uint interaction_output_size = (num_rows * (num_rows - 1)) >> 1; + uint output_size = num_cols + interaction_output_size + kPaddingSize; + + // Input + uint input_size = num_rows * num_cols; + + uint shared_mem_size_elems = input_size; + uint shared_mem_size_bytes = shared_mem_size_elems << 2; // F32 Kernel + + bool float4_predicate = !((num_cols & 3) || (output_size & 3)); + + if (float4_predicate) { + dotBasedInteractF32FwdKernel + <<>>((const float *)input, + (float *)output, + batch_size, + num_rows, + num_cols, + input_size, + output_size, + interaction_output_size); + } else { + dotBasedInteractF32FwdKernelNonAligned + <<>>((const float *)input, + (float *)output, + batch_size, + num_rows, + num_cols, + input_size, + output_size, + interaction_output_size); + } +} + +void dotBasedInteractVoltaF32Bwd(const void *input, + const void *upstream_grad, + void *grad, + void *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream) { + const uint kPaddingSize = 1; + const uint kNumThreads = 128; + + uint num_blocks = batch_size; + + uint input_size = num_rows * num_cols; + + // 1D ugrad size + uint interaction_ugrad_size = num_rows * (num_rows - 1) >> 1; + uint interaction_ugrad_size_with_padding = interaction_ugrad_size + kPaddingSize; + uint ugrad_size = num_cols + interaction_ugrad_size_with_padding; + + // input space + upstream grad space + uint smem_size_elems = input_size + interaction_ugrad_size; + uint smem_size_bytes = smem_size_elems << 2; // F32 Kernel + + bool float4_predicate = !((interaction_ugrad_size_with_padding & 3) || (num_cols & 3)); + if (float4_predicate) { + dotBasedInteractF32BwdKernel + <<>>((const float *)input, + (const float *)upstream_grad, + (float *)grad, + (float *)bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + input_size, + ugrad_size, + interaction_ugrad_size); + } else { + dotBasedInteractF32BwdKernelNonAligned + <<>>((const float *)input, + (const float *)upstream_grad, + (float *)grad, + (float *)bottom_mlp_grad, + batch_size, + num_rows, + num_cols, + input_size, + ugrad_size, + interaction_ugrad_size); + } +} diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/volta/dot_based_interact_volta.cu.inl b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/volta/dot_based_interact_volta.cu.inl new file mode 100644 index 00000000..871ed645 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/volta/dot_based_interact_volta.cu.inl @@ -0,0 +1,822 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +using namespace nvcuda; + +#include "../dot_based_interact_shared_utils.cu.h" + +struct __align__(8) half4 { + half2 vals[2]; +}; + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractFwdKernelNonAligned(const __half *__restrict input, + __half *__restrict output, + uint batch_size, + uint num_rows, + uint num_cols, + uint num_rows_after_padding, + uint num_cols_after_padding, + uint smem_elems_per_warp, + uint smem_rows_per_warp, + uint output_size, + uint num_row_steps, + uint num_col_steps) { + uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2); + int sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id; + if (sample_id >= batch_size) { + return; + } + int lane_id = threadIdx.x & (WARP_SIZE - 1); + + extern __shared__ half shmem_dynamic[]; + half *shmem = shmem_dynamic + (warp_id * smem_elems_per_warp); + + const half *sample_input = input + num_rows * num_cols * sample_id; + for (uint i = 0; i < num_rows; ++i, sample_input += num_cols) { + for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) { + (shmem + i * SMEM_STRIDE)[idx] = sample_input[idx]; + } + } + + uint idx = lane_id + num_cols; + if (idx < num_cols_after_padding) { + for (int i = 0; i < num_rows; ++i) { + (shmem + i * SMEM_STRIDE)[idx] = __float2half(0); + } + } + + half4 zeros; + zeros.vals[0].x = __float2half(0); + zeros.vals[0].y = __float2half(0); + zeros.vals[1].x = __float2half(0); + zeros.vals[1].y = __float2half(0); + if (lane_id < (num_cols_after_padding >> 2)) { + for (int i = num_rows; i < num_rows_after_padding; i++) { + ((half4 *)(shmem + i * SMEM_STRIDE))[lane_id] = zeros; + } + } + __syncwarp(); + half *gmem_output = output + output_size * sample_id; + + for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) { + gmem_output[idx] = shmem[idx]; + } + + wmma::fragment acc[M_BLOCKS][M_BLOCKS]; + + for (int i = 0; i < M_BLOCKS; i++) { + for (int j = 0; j < M_BLOCKS; j++) { + wmma::fill_fragment(acc[i][j], 0); + } + } + + for (int k_step = 0; k_step < num_col_steps; k_step++) { + wmma::fragment a[M_BLOCKS]; + wmma::fragment b[M_BLOCKS]; + for (int j = 0; j < M_BLOCKS; j++) { + int base_row = (j < M_BLOCKS - 1) ? j * 16 : smem_rows_per_warp - 16; + const half *tile_ptr = shmem + (base_row * SMEM_STRIDE + k_step * 16); + wmma::load_matrix_sync(a[j], tile_ptr, SMEM_STRIDE); + wmma::load_matrix_sync(b[j], tile_ptr, SMEM_STRIDE); + } + for (int i = 0; i < M_BLOCKS; i++) { + for (int j = 0; j < M_BLOCKS; j++) { + wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]); + } + } + } + float *shmem_store = reinterpret_cast(shmem); + for (int i = 0; i < M_BLOCKS; i++) { + for (int j = 0; j < M_BLOCKS; j++) { + float *tile_ptr = shmem_store + (i * 16 * SMEM_STRIDE_ACC + j * 16); + wmma::store_matrix_sync(tile_ptr, acc[i][j], SMEM_STRIDE_ACC, wmma::mem_row_major); + } + } + + half *gmem_interact_output = gmem_output + num_cols; + int lastRowBlockOffset = M_BLOCKS * 16 - smem_rows_per_warp; + int srcLine = 0; + for (int i = 0; i < num_rows; ++i, ++srcLine) { + if (i == ((M_BLOCKS - 1) * 16)) { + srcLine += lastRowBlockOffset; + } + if (lane_id < i) { + uint offset = (i * (i - 1)) >> 1; + gmem_interact_output[offset + lane_id] = __float2half(shmem_store[srcLine * SMEM_STRIDE_ACC + lane_id]); + } + } + // Padding + if (lane_id == 0) { + gmem_output[output_size - 1] = __float2half(0); + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractFwdKernel(const __half *__restrict input, + __half *__restrict output, + uint batch_size, + uint num_rows, + uint num_cols, + uint num_rows_after_padding, + uint num_cols_after_padding, + uint smem_elems_per_warp, + uint smem_rows_per_warp, + uint output_size, + uint num_row_steps, + uint num_col_steps) { + uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2); + int sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id; + if (sample_id >= batch_size) { + return; + } + int lane_id = threadIdx.x & (WARP_SIZE - 1); + + extern __shared__ half shmem_dynamic[]; + half *shmem = shmem_dynamic + (warp_id * smem_elems_per_warp); + + const half *sample_input = input + num_rows * num_cols * sample_id; + if (lane_id < (num_cols >> 2)) { + for (int i = 0; i < num_rows; ++i, sample_input += num_cols) { + ((float2 *)(shmem + i * SMEM_STRIDE))[lane_id] = ((float2 *)sample_input)[lane_id]; + } + } + + uint idx = lane_id + num_cols; + if (idx < num_cols_after_padding) { + for (int i = 0; i < num_rows; ++i) { + (shmem + i * SMEM_STRIDE)[idx] = __float2half(0); + } + } + + half4 zeros; + zeros.vals[0].x = __float2half(0); + zeros.vals[0].y = __float2half(0); + zeros.vals[1].x = __float2half(0); + zeros.vals[1].y = __float2half(0); + if (lane_id < (num_cols_after_padding >> 2)) { + for (int i = num_rows; i < num_rows_after_padding; i++) { + ((half4 *)(shmem + i * SMEM_STRIDE))[lane_id] = zeros; + } + } + __syncwarp(); + half *gmem_output = output + output_size * sample_id; + if (lane_id < (num_cols >> 2)) { + ((float2 *)gmem_output)[lane_id] = ((float2 *)shmem)[lane_id]; + } + + wmma::fragment acc[M_BLOCKS][M_BLOCKS]; + + for (int i = 0; i < M_BLOCKS; i++) { + for (int j = 0; j < M_BLOCKS; j++) { + wmma::fill_fragment(acc[i][j], 0); + } + } + + for (int k_step = 0; k_step < num_col_steps; k_step++) { + wmma::fragment a[M_BLOCKS]; + wmma::fragment b[M_BLOCKS]; + for (int j = 0; j < M_BLOCKS; j++) { + int base_row = (j < M_BLOCKS - 1) ? j * 16 : smem_rows_per_warp - 16; + const half *tile_ptr = shmem + (base_row * SMEM_STRIDE + k_step * 16); + wmma::load_matrix_sync(a[j], tile_ptr, SMEM_STRIDE); + wmma::load_matrix_sync(b[j], tile_ptr, SMEM_STRIDE); + } + for (int i = 0; i < M_BLOCKS; i++) { + for (int j = 0; j < M_BLOCKS; j++) { + wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]); + } + } + } + float *shmem_store = reinterpret_cast(shmem); + for (int i = 0; i < M_BLOCKS; i++) { + for (int j = 0; j < M_BLOCKS; j++) { + float *tile_ptr = shmem_store + (i * 16 * SMEM_STRIDE_ACC + j * 16); + wmma::store_matrix_sync(tile_ptr, acc[i][j], SMEM_STRIDE_ACC, wmma::mem_row_major); + } + } + + half *gmem_interact_output = gmem_output + num_cols; + int lastRowBlockOffset = M_BLOCKS * 16 - smem_rows_per_warp; + int srcLine = 0; + for (int i = 0; i < num_rows; ++i, ++srcLine) { + if (i == ((M_BLOCKS - 1) * 16)) { + srcLine += lastRowBlockOffset; + } + if (lane_id < i) { + uint offset = (i * (i - 1)) >> 1; + gmem_interact_output[offset + lane_id] = __float2half(shmem_store[srcLine * SMEM_STRIDE_ACC + lane_id]); + } + } + // Padding + if (lane_id == 0) { + gmem_output[output_size - 1] = __float2half(0); + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ + void dotBasedInteractBwdKernelNonAligned(const __half *__restrict input, + const __half *__restrict upstream_grad, + half __restrict *grad, + half __restrict *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + uint num_rows_after_padding, + uint num_cols_after_padding, + uint sample_size, + uint interaction_ugrad_size, + uint interaction_ugrad_size_with_padding, + uint interaction_ugrad_2D_size_elems, + uint interaction_ugrad_2D_stride, + uint input_size_elems, + uint input_stride, + uint num_row_steps, + uint num_col_steps, + uint row_tiles_per_step, + uint shared_mem_per_warp_size_byte) { + extern __shared__ half shared_mem[]; + uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2); + uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id; + if (sample_id >= batch_size) { + return; + } + uint lane_id = threadIdx.x & (WARP_SIZE - 1); + // ">> 1" to convert to half pointer + uint smem_warp_offset = warp_id * (shared_mem_per_warp_size_byte >> 1); + + half *smem_in = &shared_mem[smem_warp_offset]; + half *smem_temp = &shared_mem[smem_warp_offset + input_size_elems]; + float *smem_out = reinterpret_cast(smem_temp); + + // Global memory pointers for the current sample + // Input + uint gmem_input_sample_offset = sample_id * sample_size; + const half *gmem_input = &input[gmem_input_sample_offset]; + + // Interaction Gradient + const uint &gmem_grad_sample_offset = gmem_input_sample_offset; + half *gmem_grad = &grad[gmem_grad_sample_offset]; + + // Bottom MLP gradient + half *gmem_mlp_grad = &bottom_mlp_grad[sample_id * num_cols]; + + // Upstream gradient vector + uint gmem_ugrad_sample_offset = sample_id * (num_cols + interaction_ugrad_size_with_padding); + const half *gmem_ugrad = &upstream_grad[gmem_ugrad_sample_offset]; + + // Upstream gradient vector for interactions + const half *gmem_ugrad_interactions = &gmem_ugrad[num_cols]; + + // upstream grad -> shared memory (place in input section temporarily) +#pragma unroll + for (uint idx = lane_id; idx < interaction_ugrad_size; idx += WARP_SIZE) { + smem_in[idx] = gmem_ugrad_interactions[idx]; + } + __syncwarp(); + // Form the 2D ugrad matrix. + if (lane_id < num_rows_after_padding) { + uint ugrad_flat_index = ((lane_id * (lane_id - 1)) >> 1); + uint ugrad_offset_1 = lane_id * interaction_ugrad_2D_stride; + for (uint row = 0; row < num_rows; row++) { + half ugrad_val = __float2half(0.0f); + if (row < lane_id && lane_id < num_rows) { + ugrad_val = smem_in[ugrad_flat_index + row]; + smem_temp[ugrad_offset_1 + row] = ugrad_val; + } + if (row <= lane_id && lane_id < num_rows_after_padding) { + smem_temp[row * interaction_ugrad_2D_stride + lane_id] = ugrad_val; + } + } + for (uint row = num_rows; row < num_rows_after_padding; row++) { + smem_temp[row * interaction_ugrad_2D_stride + lane_id] = __float2half(0.0f); + } + } + __syncwarp(); + + // Input -> Shared Memory + + for (uint row = 0; row < num_rows; row++) { + half *smem_row_ptr = &smem_in[row * input_stride]; + const half *gmem_row_ptr = &gmem_input[row * num_cols]; + for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) { + smem_row_ptr[idx] = gmem_row_ptr[idx]; + } + uint idx = lane_id + num_cols; + if (idx < num_cols_after_padding) { + smem_row_ptr[idx] = __float2half(0); + } + } + +#pragma unroll 2 + for (uint row = num_rows; row < num_rows_after_padding; row++) { + half *smem_row_ptr = &smem_in[row * input_stride]; + for (uint idx = lane_id; idx < num_cols_after_padding; idx += WARP_SIZE) { + smem_row_ptr[idx] = __float2half(0); + } + } + __syncwarp(); + + wmma::fragment a[ROW_TILES_PER_STEP] + [ROW_TILES_PER_STEP]; + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + for (uint j = 0; j < ROW_TILES_PER_STEP; j++) { + const half *tile_ptr = smem_temp + ((i * interaction_ugrad_2D_stride + j) << TILE_DIM_LOG_2); + wmma::load_matrix_sync(a[i][j], tile_ptr, interaction_ugrad_2D_stride); + } + } + + wmma::fragment acc[ROW_TILES_PER_STEP]; + wmma::fragment b[ROW_TILES_PER_STEP]; + for (int col_step = 0; col_step < num_col_steps; col_step++) { + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + const half *tile_ptr = smem_in + ((i * input_stride + col_step) << TILE_DIM_LOG_2); + wmma::fill_fragment(acc[i], 0); + wmma::load_matrix_sync(b[i], tile_ptr, input_stride); + } + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + for (uint j = 0; j < ROW_TILES_PER_STEP; j++) { + wmma::mma_sync(acc[i], a[i][j], b[j], acc[i]); + } + } + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + float *tile_ptr = smem_out + i * TILE_DIM * TILE_DIM; + wmma::store_matrix_sync(tile_ptr, acc[i], TILE_DIM, wmma::mem_row_major); + } + __syncwarp(); + uint gmem_grad_col = (col_step << TILE_DIM_LOG_2) + lane_id; + if (gmem_grad_col < num_cols) { + for (uint i = 0; i < num_rows; i++) { + gmem_grad[i * num_cols + gmem_grad_col] = __float2half(smem_out[(i << TILE_DIM_LOG_2) + lane_id]); + } + } + } + + for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) { + gmem_mlp_grad[idx] = gmem_ugrad[idx]; + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractBwdKernel(const __half *__restrict input, + const __half *__restrict upstream_grad, + half __restrict *grad, + half __restrict *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + uint num_rows_after_padding, + uint num_cols_after_padding, + uint sample_size, + uint interaction_ugrad_size, + uint interaction_ugrad_size_with_padding, + uint interaction_ugrad_2D_size_elems, + uint interaction_ugrad_2D_stride, + uint input_size_elems, + uint input_stride, + uint num_row_steps, + uint num_col_steps, + uint row_tiles_per_step, + uint shared_mem_per_warp_size_byte) { + extern __shared__ half shared_mem[]; + uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2); + uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id; + if (sample_id >= batch_size) { + return; + } + uint lane_id = threadIdx.x & (WARP_SIZE - 1); + // ">> 1" to convert to half pointer + uint smem_warp_offset = warp_id * (shared_mem_per_warp_size_byte >> 1); + + half *smem_in = &shared_mem[smem_warp_offset]; + half *smem_temp = &shared_mem[smem_warp_offset + input_size_elems]; + float *smem_out = reinterpret_cast(smem_temp); + + // Global memory pointers for the current sample + // Input + uint gmem_input_sample_offset = sample_id * sample_size; + const half *gmem_input = &input[gmem_input_sample_offset]; + + // Interaction Gradient + const uint &gmem_grad_sample_offset = gmem_input_sample_offset; + half *gmem_grad = &grad[gmem_grad_sample_offset]; + + // Bottom MLP gradient + half *gmem_mlp_grad = &bottom_mlp_grad[sample_id * num_cols]; + + // Upstream gradient vector + uint gmem_ugrad_sample_offset = sample_id * (num_cols + interaction_ugrad_size_with_padding); + const half *gmem_ugrad = &upstream_grad[gmem_ugrad_sample_offset]; + + // Upstream gradient vector for interactions + const half *gmem_ugrad_interactions = &gmem_ugrad[num_cols]; + + // upstream grad -> shared memory (place in input section temporarily) +#pragma unroll + for (uint idx = lane_id; idx < (interaction_ugrad_size >> 3); idx += WARP_SIZE) { + ((float4 *)smem_in)[idx] = ((float4 *)gmem_ugrad_interactions)[idx]; + } + uint offset = (interaction_ugrad_size >> 3) << 3; + for (uint idx = lane_id + offset; idx < interaction_ugrad_size; idx += WARP_SIZE) { + smem_in[idx] = gmem_ugrad_interactions[idx]; + } + __syncwarp(); + // Form the 2D ugrad matrix. + if (lane_id < num_rows_after_padding) { + uint ugrad_flat_index = ((lane_id * (lane_id - 1)) >> 1); + uint ugrad_offset_1 = lane_id * interaction_ugrad_2D_stride; + for (uint row = 0; row < num_rows; row++) { + half ugrad_val = __float2half(0.0f); + if (row < lane_id && lane_id < num_rows) { + ugrad_val = smem_in[ugrad_flat_index + row]; + smem_temp[ugrad_offset_1 + row] = ugrad_val; + } + if (row <= lane_id && lane_id < num_rows_after_padding) { + smem_temp[row * interaction_ugrad_2D_stride + lane_id] = ugrad_val; + } + } + for (uint row = num_rows; row < num_rows_after_padding; row++) { + smem_temp[row * interaction_ugrad_2D_stride + lane_id] = __float2half(0.0f); + } + } + __syncwarp(); + + // Input -> Shared Memory + + if (lane_id < (num_cols >> 2)) { + for (uint row = 0; row < num_rows; row++) { + half *smem_row_ptr = &smem_in[row * input_stride]; + const half *gmem_row_ptr = &gmem_input[row * num_cols]; + ((float2 *)smem_row_ptr)[lane_id] = ((float2 *)gmem_row_ptr)[lane_id]; + } + } + + uint idx = lane_id + num_cols; + if (idx < num_cols_after_padding) { + for (uint row = 0; row < num_rows; row++) { + half *smem_row_ptr = &smem_in[row * input_stride]; + smem_row_ptr[idx] = __float2half(0); + } + } + + half4 zeros; + zeros.vals[0].x = __float2half(0); + zeros.vals[0].y = __float2half(0); + zeros.vals[1].x = __float2half(0); + zeros.vals[1].y = __float2half(0); + if (lane_id < (num_cols_after_padding >> 2)) { +#pragma unroll 2 + for (uint row = num_rows; row < num_rows_after_padding; row++) { + half *smem_row_ptr = &smem_in[row * input_stride]; + ((half4 *)smem_row_ptr)[lane_id] = zeros; + } + } + __syncwarp(); + + wmma::fragment a[ROW_TILES_PER_STEP] + [ROW_TILES_PER_STEP]; + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + for (uint j = 0; j < ROW_TILES_PER_STEP; j++) { + const half *tile_ptr = smem_temp + ((i * interaction_ugrad_2D_stride + j) << TILE_DIM_LOG_2); + wmma::load_matrix_sync(a[i][j], tile_ptr, interaction_ugrad_2D_stride); + } + } + + wmma::fragment acc[ROW_TILES_PER_STEP]; + wmma::fragment b[ROW_TILES_PER_STEP]; + for (int col_step = 0; col_step < num_col_steps; col_step++) { + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + const half *tile_ptr = smem_in + ((i * input_stride + col_step) << TILE_DIM_LOG_2); + wmma::fill_fragment(acc[i], 0); + wmma::load_matrix_sync(b[i], tile_ptr, input_stride); + } + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + for (uint j = 0; j < ROW_TILES_PER_STEP; j++) { + wmma::mma_sync(acc[i], a[i][j], b[j], acc[i]); + } + } + for (uint i = 0; i < ROW_TILES_PER_STEP; i++) { + float *tile_ptr = smem_out + i * TILE_DIM * TILE_DIM; + wmma::store_matrix_sync(tile_ptr, acc[i], TILE_DIM, wmma::mem_row_major); + } + __syncwarp(); + uint gmem_grad_col = (col_step << TILE_DIM_LOG_2) + lane_id; + if (gmem_grad_col < num_cols) { + for (uint i = 0; i < num_rows; i++) { + gmem_grad[i * num_cols + gmem_grad_col] = __float2half(smem_out[(i << TILE_DIM_LOG_2) + lane_id]); + } + } + } + if (lane_id < (num_cols >> 2)) { + ((float2 *)gmem_mlp_grad)[lane_id] = ((float2 *)gmem_ugrad)[lane_id]; + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ + void dotBasedInteractF32FwdKernelNonAligned(const float *__restrict input, + float *__restrict output, + uint batch_size, + uint num_rows, + uint num_cols, + uint input_size, + uint output_size, + uint interaction_output_size) { + extern __shared__ float smem_f32_fwd[]; + float *smem_in = &smem_f32_fwd[0]; + + uint input_batch_offset = blockIdx.x * input_size; + const float *gmem_in = &input[input_batch_offset]; + + uint output_batch_offset = blockIdx.x * output_size; + float *gmem_out_bottom_mlp = &output[output_batch_offset]; + float *gmem_out_interaction = &output[output_batch_offset + num_cols]; + + // Load the input - one sample per block + for (uint idx = threadIdx.x; idx < input_size; idx += blockDim.x) { + smem_in[idx] = gmem_in[idx]; + } + __syncthreads(); + + // Copy bottom MLP output to output + for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) { + gmem_out_bottom_mlp[idx] = smem_in[idx]; + } + + for (uint idx = threadIdx.x; idx < (interaction_output_size); idx += blockDim.x) { + uint elems_per_row = 1; + uint index = idx; + while (index >= elems_per_row) { + index -= elems_per_row; + elems_per_row++; + } + uint target_row = elems_per_row; + uint target_col = index; + + float sum = 0; + for (uint i = 0; i < num_cols; i++) { + float tmp1 = smem_in[target_row * num_cols + i]; + float tmp2 = smem_in[target_col * num_cols + i]; + sum = fmaf(tmp1, tmp2, sum); + } + + gmem_out_interaction[idx] = sum; + } + + gmem_out_interaction[interaction_output_size] = 0; +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractF32FwdKernel(const float *__restrict input, + float *__restrict output, + uint batch_size, + uint num_rows, + uint num_cols, + uint input_size, + uint output_size, + uint interaction_output_size) { + extern __shared__ float smem_f32_fwd[]; + float *smem_in = &smem_f32_fwd[0]; + + uint input_batch_offset = blockIdx.x * input_size; + const float *gmem_in = &input[input_batch_offset]; + + uint output_batch_offset = blockIdx.x * output_size; + float *gmem_out_bottom_mlp = &output[output_batch_offset]; + float *gmem_out_interaction = &output[output_batch_offset + num_cols]; + + // Load the input - one sample per block + uint input_size_float4 = input_size >> 2; + for (uint idx = threadIdx.x; idx < input_size_float4; idx += blockDim.x) { + ((float4 *)smem_in)[idx] = ((float4 *)gmem_in)[idx]; + } + __syncthreads(); + + // Copy bottom MLP output to output + uint btm_mlp_out_size_float4 = num_cols >> 2; + for (uint idx = threadIdx.x; idx < btm_mlp_out_size_float4; idx += blockDim.x) { + ((float4 *)gmem_out_bottom_mlp)[idx] = ((float4 *)smem_in)[idx]; + } + + for (uint idx = threadIdx.x; idx < (interaction_output_size); idx += blockDim.x) { + uint elems_per_row = 1; + uint index = idx; + while (index >= elems_per_row) { + index -= elems_per_row; + elems_per_row++; + } + uint target_row = elems_per_row; + uint target_col = index; + + float4 sum; + sum.x = 0; + sum.y = 0; + sum.z = 0; + sum.w = 0; + uint num_cols_float4 = num_cols >> 2; + for (uint i = 0; i < num_cols_float4; i++) { + float4 tmp1 = ((float4 *)smem_in)[target_row * num_cols_float4 + i]; + float4 tmp2 = ((float4 *)smem_in)[target_col * num_cols_float4 + i]; + sum.x = fmaf(tmp1.x, tmp2.x, sum.x); + sum.y = fmaf(tmp1.y, tmp2.y, sum.y); + sum.z = fmaf(tmp1.z, tmp2.z, sum.z); + sum.w = fmaf(tmp1.w, tmp2.w, sum.w); + } + + gmem_out_interaction[idx] = sum.x + sum.y + sum.z + sum.w; + } + + gmem_out_interaction[interaction_output_size] = 0; +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ + void dotBasedInteractF32BwdKernelNonAligned(const float *__restrict input, + const float *__restrict upstream_grad, + float *__restrict grad, + float *__restrict bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + uint input_size, + uint ugrad_size, + uint interaction_ugrad_size) { + extern __shared__ float smem_f32_bwd[]; + float *smem_in = &smem_f32_bwd[0]; + float *smem_interaction_ugrad = &smem_f32_bwd[input_size]; + + // Input + uint input_batch_offset = blockIdx.x * input_size; + const float *gmem_in = &input[input_batch_offset]; + + // Gradient + const uint &grad_batch_offset = input_batch_offset; + float *gmem_mlp_grad = &bottom_mlp_grad[blockIdx.x * num_cols]; + float *gmem_interaction_grad = &grad[grad_batch_offset]; + + // Upstream Gradient + uint upstream_grad_batch_offset = blockIdx.x * ugrad_size; + const float *gmem_mlp_ugrad = &upstream_grad[upstream_grad_batch_offset]; + const float *gmem_interaction_ugrad = &upstream_grad[upstream_grad_batch_offset + num_cols]; + + // input -> shared memory + for (uint idx = threadIdx.x; idx < input_size; idx += blockDim.x) { + smem_in[idx] = gmem_in[idx]; + } + + // Interaction Upstream Grad -> Shared Memory + for (uint idx = threadIdx.x; idx < interaction_ugrad_size; idx += blockDim.x) { + smem_interaction_ugrad[idx] = gmem_interaction_ugrad[idx]; + } + __syncthreads(); + + // Copy the upstream gradient w.r.t to mlp to it's corresponding memory location. + for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) { + gmem_mlp_grad[idx] = gmem_mlp_ugrad[idx]; + } + + for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) { + size_t grad_idx = idx; + for (uint row_idx = 0; row_idx < num_rows; row_idx++) { + float sum = 0; + size_t upstream_grad_offset = (row_idx * (row_idx - 1)) >> 1; + for (int k = 0; k < row_idx; k++) { + sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + k], sum); + } + for (int k = row_idx + 1; k < num_rows; k++) { + upstream_grad_offset = (k * (k - 1)) >> 1; // TODO: this can become a sum + sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + row_idx], sum); + } + gmem_interaction_grad[grad_idx] = sum; + grad_idx += num_cols; + } + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractF32BwdKernel(const float *__restrict input, + const float *__restrict upstream_grad, + float *__restrict grad, + float *__restrict bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + uint input_size, + uint ugrad_size, + uint interaction_ugrad_size) { + extern __shared__ float smem_f32_bwd[]; + float *smem_in = &smem_f32_bwd[0]; + float *smem_interaction_ugrad = &smem_f32_bwd[input_size]; + + // Input + uint input_batch_offset = blockIdx.x * input_size; + const float *gmem_in = &input[input_batch_offset]; + + // Gradient + const uint &grad_batch_offset = input_batch_offset; + float *gmem_mlp_grad = &bottom_mlp_grad[blockIdx.x * num_cols]; + float *gmem_interaction_grad = &grad[grad_batch_offset]; + + // Upstream Gradient + uint upstream_grad_batch_offset = blockIdx.x * ugrad_size; + const float *gmem_mlp_ugrad = &upstream_grad[upstream_grad_batch_offset]; + const float *gmem_interaction_ugrad = &upstream_grad[upstream_grad_batch_offset + num_cols]; + + // input -> shared memory + uint input_size_float4 = input_size >> 2; + for (uint idx = threadIdx.x; idx < input_size_float4; idx += blockDim.x) { + ((float4 *)smem_in)[idx] = ((float4 *)gmem_in)[idx]; + } + + // Interaction Upstream Grad -> Shared Memory + uint upstream_grad_size_float4 = interaction_ugrad_size >> 2; + for (uint idx = threadIdx.x; idx < upstream_grad_size_float4; idx += blockDim.x) { + ((float4 *)smem_interaction_ugrad)[idx] = ((float4 *)gmem_interaction_ugrad)[idx]; + } + + uint vectorized_load_offset = (upstream_grad_size_float4 << 2); + for (uint idx = vectorized_load_offset + threadIdx.x; idx < interaction_ugrad_size; idx += blockDim.x) { + smem_interaction_ugrad[idx] = gmem_interaction_ugrad[idx]; + } + __syncthreads(); + + // Copy the upstream gradient w.r.t to mlp to it's corresponding memory location. + for (uint idx = threadIdx.x; idx < (num_cols >> 2); idx += blockDim.x) { + ((float4 *)gmem_mlp_grad)[idx] = ((float4 *)gmem_mlp_ugrad)[idx]; + } + + for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) { + size_t grad_idx = idx; + for (uint row_idx = 0; row_idx < num_rows; row_idx++) { + float sum = 0; + size_t upstream_grad_offset = (row_idx * (row_idx - 1)) >> 1; + for (int k = 0; k < row_idx; k++) { + sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + k], sum); + } + for (int k = row_idx + 1; k < num_rows; k++) { + upstream_grad_offset = (k * (k - 1)) >> 1; // TODO: this can become a sum + sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + row_idx], sum); + } + gmem_interaction_grad[grad_idx] = sum; + grad_idx += num_cols; + } + } +} diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/volta/dot_based_interact_volta.h b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/volta/dot_based_interact_volta.h new file mode 100644 index 00000000..51a05804 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/kernels/volta/dot_based_interact_volta.h @@ -0,0 +1,53 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KERNEL_DOT_BASED_INTERACT_VOLTA_H_ +#define KERNEL_DOT_BASED_INTERACT_VOLTA_H_ + +void dotBasedInteractVoltaF16Fwd(const void *input, + const void *bottom_mlp_output, + void *output, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream); + +void dotBasedInteractVoltaF16Bwd(const void *input, + const void *upstream_grad, + void *grad, + void *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream); + +void dotBasedInteractVoltaF32Fwd(const void *input, + const void *bottom_mlp_output, + void *output, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream); + +void dotBasedInteractVoltaF32Bwd(const void *input, + const void *upstream_grad, + void *grad, + void *bottom_mlp_grad, + uint batch_size, + uint num_rows, + uint num_cols, + cudaStream_t stream); + +#endif //KERNEL_DOT_BASED_INTERACT_VOLTA_H_ diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/ops/dot_based_interact_ops.cc b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/ops/dot_based_interact_ops.cc new file mode 100644 index 00000000..bfc019b9 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/cc/ops/dot_based_interact_ops.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +using namespace tensorflow; + +REGISTER_OP("DotBasedInteract") + .Attr("T: {float, half}") + .Input("input: T") + .Input("bottom_mlp_output: T") + .Output("output: T") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + const int64 pad = 1; + auto input = c->input(0); + auto batch_size_dim = c->Dim(input, 0); + int64 num_rows = c->Value(c->Dim(input, 1)); + int64 num_cols = c->Value(c->Dim(input, 2)); + auto output_size_dim = c->MakeDim(((num_rows * (num_rows - 1)) >> 1) + num_cols + pad); + c->set_output(0, c->MakeShape({batch_size_dim, output_size_dim})); + return Status::OK(); + }); + +REGISTER_OP("DotBasedInteractGrad") + .Attr("T: {float, half}") + .Input("input: T") + .Input("upstream_grad: T") + .Output("grad: T") + .Output("bottom_mlp_grad: T") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + auto input = c->input(0); + auto batch_size_dim = c->Dim(input, 0); + auto num_cols_dim = c->Dim(input, 2); + c->set_output(0, input); + c->set_output(1, c->MakeShape({batch_size_dim, num_cols_dim})); + return Status::OK(); + }); diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/__init__.py b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/ops/__init__.py b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/ops/dot_based_interact_ops.py b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/ops/dot_based_interact_ops.py new file mode 100644 index 00000000..4bea3b32 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/ops/dot_based_interact_ops.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader + +dot_based_interact_ops = load_library.load_op_library( + resource_loader.get_path_to_datafile('_dot_based_interact_ops.so')) +dot_based_interact = dot_based_interact_ops.dot_based_interact + +@ops.RegisterGradient("DotBasedInteract") +def dot_based_interact_grad(op, grad): + input = op.inputs[0] + return dot_based_interact_ops.dot_based_interact_grad(input, grad) diff --git a/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/ops/dot_based_interact_ops_test.py b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/ops/dot_based_interact_ops_test.py new file mode 100644 index 00000000..427a2d9e --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/tensorflow-dot-based-interact/tensorflow_dot_based_interact/python/ops/dot_based_interact_ops_test.py @@ -0,0 +1,115 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.framework import test_util +try: + from tensorflow_dot_based_interact.python.ops import dot_based_interact_ops +except ImportError: + import dot_based_interact_ops + +def dot_based_interact_native(input, bottom_mlp_output): + # Dot Based Interact of the "input" tensor + concat_features = tf.cast(input, tf.float32) + interactions = tf.matmul(concat_features, concat_features, transpose_b=True) + ones = tf.ones_like(interactions, dtype=concat_features.dtype) + upper_tri_mask = tf.linalg.band_part(ones, 0, -1) + feature_dim = tf.shape(interactions)[-1] + lower_tri_mask = ones - upper_tri_mask + activations = tf.boolean_mask(interactions, lower_tri_mask) + out_dim = feature_dim * (feature_dim - 1) // 2 + activations = tf.reshape(activations, shape=[-1, out_dim]) + + # Top Concatenation of the bottom_mlp_output with the interactions + bottom_mlp_output = tf.cast(tf.squeeze(bottom_mlp_output, axis=1), tf.float32) + top_concat = tf.concat([bottom_mlp_output, activations], axis=1) + + # Zero Padding for performance in upstream ops + padding = tf.zeros([concat_features.shape[0], 1]) + zero_padded = tf.concat([top_concat, padding], axis=1) + + return zero_padded + + +class DotBasedInteractTest(test.TestCase): + + def input(self, batch_size, num_rows, num_cols, dtype): + # Creates two random tensors to use as sample inputs to test with: + # - input: With shape [batch_size, num_rows, num_cols] + # - bottom_mlp_output: With shape [batch_size, 1, num_cols] + # Where the first row of input is a copy of bottom_mlp_output + mlp_rows = 1 + emb_rows = num_rows - mlp_rows + bottom_mlp_output = tf.random.uniform(shape=[batch_size, mlp_rows, num_cols], dtype=dtype) + embeddings = tf.random.uniform(shape=[batch_size, emb_rows, num_cols], dtype=dtype) + input = tf.concat([bottom_mlp_output, embeddings], axis=1) + return tf.Variable(input), tf.Variable(bottom_mlp_output) + + def forward(self, batch_size, num_rows, num_cols, dtype): + with self.test_session() as sess: + with ops.device("/gpu:0"): + input, bottom_mlp_output = self.input(batch_size, num_rows, num_cols, dtype) + expected = dot_based_interact_native(input, bottom_mlp_output) + result = dot_based_interact_ops.dot_based_interact(input, bottom_mlp_output) + return result, expected + + def backward(self, batch_size, num_rows, num_cols, dtype): + with self.test_session() as sess: + with ops.device("/gpu:0"): + input, bottom_mlp_output = self.input(batch_size, num_rows, num_cols, dtype) + with tf.GradientTape() as tape: + output = dot_based_interact_native(input, bottom_mlp_output) + expected = tape.gradient(output, [input, bottom_mlp_output]) + with tf.GradientTape() as tape: + output = dot_based_interact_ops.dot_based_interact(input, bottom_mlp_output) + result = tape.gradient(output, [input, bottom_mlp_output]) + return result[0], expected[0] + + def test_fp32(self): + # Higher than normal tolerance on FP32 due to TF32 on Ampere + self.assertAllClose(*self.forward(16, 32, 32, tf.float32), rtol=1e-03) + + def test_fp32_not_aligned(self): + self.assertAllClose(*self.forward(17, 31, 37, tf.float32), rtol=1e-03) + + def test_grad_fp32(self): + self.assertAllClose(*self.backward(16, 32, 32, tf.float32), rtol=1e-03) + + def test_grad_fp32_not_aligned(self): + self.assertAllClose(*self.backward(17, 31, 37, tf.float32), rtol=1e-03) + + def test_fp16(self): + self.assertAllCloseAccordingToType(*self.forward(16, 32, 32, tf.float16)) + + def test_fp16_not_aligned(self): + self.assertAllCloseAccordingToType(*self.forward(15, 31, 37, tf.float16)) + + def test_grad_fp16(self): + self.assertAllCloseAccordingToType(*self.backward(16, 32, 32, tf.float16)) + + def test_grad_fp16_not_aligned(self): + self.assertAllCloseAccordingToType(*self.backward(17, 31, 37, tf.float16)) + +if __name__ == '__main__': + test.main() diff --git a/Tensorflow2/Recommendation/DLRM/utils.py b/Tensorflow2/Recommendation/DLRM/utils.py new file mode 100644 index 00000000..f91805c3 --- /dev/null +++ b/Tensorflow2/Recommendation/DLRM/utils.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# author: Tomasz Grel (tgrel@nvidia.com) + + +import time +import dllogger +import horovod.tensorflow as hvd +import json + + +def print_model_summary(model): + variables_placement = { + v.name: (v.device, v.shape.as_list(), + str(v.is_initialized()), str(v.dtype), str(v.trainable), str(v.synchronization)) for v in model.trainable_variables + } + print('============ VARIABLES PLACEMENT =====================') + print(json.dumps(variables_placement, indent=4)) + print('============ VARIABLES PLACEMENT END =================') + + +def dist_print(*args, force=False, **kwargs): + if hvd.rank() == 0 or force: + print(*args, **kwargs) + + +def init_logging(log_path, FLAGS): + json_backend = dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE, + filename=log_path) + stdout_backend = dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE) + + stdout_backend._metadata['auc'].update({'format': '0:.5f'}) + stdout_backend._metadata['throughput'].update({'format': ':.2e'}) + stdout_backend._metadata['mean_step_time_ms'].update({'format': '0:.3f'}) + stdout_backend._metadata['mean_inference_throughput'].update({'format': ':.2e'}) + stdout_backend._metadata['mean_inference_latency'].update({'format': '0:.5f'}) + for percentile in [90, 95, 99]: + stdout_backend._metadata[f'p{percentile}_inference_latency'].update({'format': '0:.5f'}) + + dllogger.init(backends=[json_backend, stdout_backend]) + + if hvd.rank() == 0: + dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER') + print("Command line flags:") + print(json.dumps(FLAGS.flag_values_dict(), indent=4)) + + +class IterTimer: + def __init__(self, train_batch_size, test_batch_size, optimizer, print_freq=50, + enabled=True, benchmark_warmup_steps=100): + self.previous_tick = None + self.train_idx = 0 + self.test_idx = 0 + self.train_batch_size = train_batch_size + self.test_batch_size = test_batch_size + self.print_freq = print_freq + self.optimizer = optimizer + self.enabled = enabled + self.training_steps_time = 0 + self.benchmark_warmup_steps = benchmark_warmup_steps + self.steps_measured = 0 + + def step_train(self, loss=None): + if not self.enabled: + return + + if self.train_idx < self.benchmark_warmup_steps: + self.train_idx += 1 + return + + if self.train_idx % self.print_freq == 0 and self.train_idx > 0: + if self.previous_tick is None: + self.previous_tick = time.time() + self.train_idx += 1 + return + + current_time = time.time() + elapsed = current_time - self.previous_tick + throughput = (self.train_batch_size * self.print_freq) / elapsed + throughput_in_millions = throughput / 1e6 + step_time_ms = elapsed / self.print_freq * 1000 + lr = f'{self.optimizer.lr.numpy().item():.4f}' + + print(f'step={self.train_idx}, throughput={throughput_in_millions:.3f}M, step_time={step_time_ms:.3f} ms, learning_rate={lr}, loss={loss:.8f},') + + self.previous_tick = current_time + self.training_steps_time += elapsed + self.steps_measured += self.print_freq + + self.train_idx += 1 + + def mean_train_time(self): + return self.training_steps_time / self.steps_measured + + def step_test(self): + if not self.enabled: + return + + if self.previous_tick is None: + self.previous_tick = time.time() + self.test_idx += 1 + return + + if self.test_idx % self.print_freq == self.print_freq - 1: + current_time = time.time() + elapsed = current_time - self.previous_tick + throughput = (self.test_batch_size * self.print_freq) / elapsed + throughput_in_millions = throughput / 1e6 + step_time_ms = elapsed / self.print_freq * 1000 + + print(f'validation_step={self.test_idx}, validation_throughput={throughput_in_millions:.3f}M, step_time={step_time_ms:.3f} ms') + + self.previous_tick = current_time + self.test_idx += 1 +