[internal/CI] Add SE3T A100 tests

Alexandre Milesi 2021-10-21 00:11:30 -07:00 committed by Andrei Shumak
parent 8c98f155e0
commit 67c7afa053
15 changed files with 1581 additions and 1563 deletions

@@ -328,7 +328,7 @@ The complete list of the available parameters for the `training.py` script contains:
- `--gradient_clip`: Clipping of the gradient norms (default: `None`)
- `--accumulate_grad_batches`: Gradient accumulation (default: `1`)
- `--ckpt_interval`: Save a checkpoint every N epochs (default: `-1`)
-- `--eval_interval`: Do an evaluation round every N epochs (default: `1`)
+- `--eval_interval`: Do an evaluation round every N epochs (default: `20`)
- `--silent`: Minimize stdout output (default: `false`)
**Paths**
@@ -485,6 +485,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
| 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
#### Training performance results
##### Training performance: NVIDIA DGX A100 (8x A100 80GB)
@@ -495,8 +496,8 @@ Our results were obtained by running the `scripts/benchmark_train.sh` and `scrip
| GPUs | Batch size / GPU | Throughput - TF32 [molecules/ms] | Throughput - mixed precision [molecules/ms] | Throughput speedup (mixed precision to TF32) | Weak scaling - TF32 | Weak scaling - mixed precision |
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
| 1 | 240 | 2.21 | 2.92 | 1.32x | | |
| 1 | 120 | 1.81 | 2.04 | 1.13x | | |
-| 8 | 240 | 17.15 | 22.95 | 1.34x | 7.76 | 7.86 |
-| 8 | 120 | 13.89 | 15.62 | 1.12x | 7.67 | 7.66 |
+| 8 | 240 | 15.88 | 21.02 | 1.32x | 7.18 | 7.20 |
+| 8 | 120 | 12.68 | 13.99 | 1.10x | 7.00 | 6.86 |
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
@@ -510,8 +511,8 @@ Our results were obtained by running the `scripts/benchmark_train.sh` and `scrip
| GPUs | Batch size / GPU | Throughput - FP32 [molecules/ms] | Throughput - mixed precision [molecules/ms] | Throughput speedup (mixed precision to FP32) | Weak scaling - FP32 | Weak scaling - mixed precision |
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
| 1 | 240 | 1.25 | 1.88 | 1.50x | | |
| 1 | 120 | 1.03 | 1.41 | 1.37x | | |
-| 8 | 240 | 9.33 | 14.02 | 1.50x | 7.46 | 7.46 |
-| 8 | 120 | 7.39 | 9.41 | 1.27x | 7.17 | 6.67 |
+| 8 | 240 | 8.68 | 12.75 | 1.47x | 6.94 | 6.78 |
+| 8 | 120 | 6.64 | 8.58 | 1.29x | 6.44 | 6.08 |
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
@@ -572,6 +573,15 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
### Changelog

+October 2021:
+- Updated README performance tables
+- Fixed shape mismatch when using partially fused TFNs per output degree
+- Fixed shape mismatch when using partially fused TFNs per input degree with edge degrees > 0

September 2021:
- Moved to new location (from `PyTorch/DrugDiscovery` to `DGLPyTorch/DrugDiscovery`)
- Fixed multi-GPU training script

August 2021:
- Initial release
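
For reference, the derived columns in the updated performance tables above can be recomputed from the raw throughput numbers. A minimal sketch, using values copied from the A100, batch-240 rows (variable names are illustrative, not from the source):

```python
# Recompute the derived columns of the A100 training-performance table.
tf32_1gpu, amp_1gpu = 2.21, 2.92    # throughput [molecules/ms], 1 GPU
tf32_8gpu, amp_8gpu = 15.88, 21.02  # throughput [molecules/ms], 8 GPUs

print(f'{amp_8gpu / tf32_8gpu:.2f}x')  # speedup: 1.32x, matches the table
print(f'{amp_8gpu / amp_1gpu:.2f}')    # mixed-precision weak scaling: 7.20, matches
print(f'{tf32_8gpu / tf32_1gpu:.2f}')  # TF32 weak scaling: 7.19 vs the table's 7.18,
                                       # presumably computed from unrounded throughputs
```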

@@ -54,7 +54,6 @@ def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
    all_degrees = list(range(2 * max_degree + 1))
    with nvtx_range('spherical harmonics'):
        sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
        return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
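
The split sizes here come from the fact that the degree-d spherical harmonics span 2d + 1 components. A minimal, self-contained sketch of that bookkeeping, using a random stand-in for the `o3.spherical_harmonics` output and assuming `degree_to_dim(d) == 2 * d + 1` (consistent with the padding arithmetic later in this commit):

```python
import torch

def degree_to_dim(degree: int) -> int:
    # Dimension of the degree-d spherical-harmonic block.
    return 2 * degree + 1

max_degree = 2
all_degrees = list(range(2 * max_degree + 1))           # [0, 1, 2, 3, 4]
total_dim = sum(degree_to_dim(d) for d in all_degrees)  # 1 + 3 + 5 + 7 + 9 = 25
sh = torch.randn(10, total_dim)                         # stand-in for the o3 output
per_degree = torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
assert [t.shape[1] for t in per_degree] == [1, 3, 5, 7, 9]
```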

@@ -178,5 +178,3 @@ class AttentionBlockSE3(nn.Module):
                    value[degree] = feat
        return key, value

@@ -152,12 +152,9 @@ class VersatileConvSE3(nn.Module):
            if basis is not None:
                # This block performs the einsum n i l, n o i f, n l f k -> n o k
-                out_dim = basis.shape[-1]
-                if self.fuse_level != ConvSE3FuseLevel.FULL:
-                    out_dim += out_dim % 2 - 1  # Account for padded basis
                basis_view = basis.view(num_edges, in_dim, -1)
                tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
-                return (radial_weights @ tmp)[:, :, :out_dim]
+                return radial_weights @ tmp
            else:
                # k = l = 0 non-fused case
                return radial_weights @ features
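
The comment in that block names the contraction; the two batched matmuls implement it in stages. A small numerical check that the staged form equals the direct einsum (all shapes and names here are illustrative, not from the source):

```python
import torch

num_edges, c_in, c_out, l_dim, f_dim, k_dim = 6, 4, 5, 3, 2, 7

features = torch.randn(num_edges, c_in, l_dim)                # n i l
radial_weights = torch.randn(num_edges, c_out, c_in * f_dim)  # n o (i f)
basis = torch.randn(num_edges, l_dim, f_dim, k_dim)           # n l f k

direct = torch.einsum('nil,noif,nlfk->nok',
                      features,
                      radial_weights.view(num_edges, c_out, c_in, f_dim),
                      basis)

basis_view = basis.view(num_edges, l_dim, -1)                       # n l (f k)
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])  # n (i f) k
two_matmuls = radial_weights @ tmp                                  # n o k

assert torch.allclose(direct, two_matmuls, atol=1e-5)
```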
@@ -248,7 +245,8 @@ class ConvSE3(nn.Module):
            self.conv_in = nn.ModuleDict()
            for d_in, c_in in fiber_in:
                sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
-                self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0],
+                channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0)
+                self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0],
                                                           fuse_level=self.used_fuse_level, **common_args)
        else:
            # Use pairwise TFN convolutions
@@ -268,6 +266,15 @@ class ConvSE3(nn.Module):
                    self.to_kernel_self[str(degree_out)] = nn.Parameter(
                        torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))

+    def _try_unpad(self, feature, basis):
+        # Account for padded basis
+        if basis is not None:
+            out_dim = basis.shape[-1]
+            out_dim += out_dim % 2 - 1
+            return feature[..., :out_dim]
+        else:
+            return feature
+
    def forward(
            self,
            node_feats: Dict[str, Tensor],
@@ -299,7 +306,10 @@ class ConvSE3(nn.Module):
            elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
                in_features_fused = torch.cat(in_features, dim=-1)
                for degree_out in self.fiber_out.degrees:
-                    out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis[f'out{degree_out}_fused'])
+                    basis_used = basis[f'out{degree_out}_fused']
+                    out[str(degree_out)] = self._try_unpad(
+                        self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis_used),
+                        basis_used)

            elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
                out = 0
@@ -313,7 +323,10 @@ class ConvSE3(nn.Module):
                    out_feature = 0
                    for degree_in, feature in zip(self.fiber_in.degrees, in_features):
                        dict_key = f'{degree_in},{degree_out}'
-                        out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats, basis.get(dict_key, None))
+                        basis_used = basis.get(dict_key, None)
+                        out_feature = out_feature + self._try_unpad(
+                            self.conv[dict_key](feature, invariant_edge_feats, basis_used),
+                            basis_used)
                    out[str(degree_out)] = out_feature

            for degree_out in self.fiber_out.degrees:
@@ -330,5 +343,3 @@ class ConvSE3(nn.Module):
                else:
                    out = dgl.ops.copy_e_sum(graph, out)
            return out
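
The unpadding arithmetic introduced above relies on every true output width being odd (degree d spans 2d + 1 components) while the fused bases are padded to an even width, so `out_dim % 2 - 1` drops exactly the one pad column when present. A minimal standalone sketch; `try_unpad` is a free-function stand-in for `ConvSE3._try_unpad`, and the shapes are illustrative:

```python
from typing import Optional

import torch


def try_unpad(feature: torch.Tensor, basis: Optional[torch.Tensor]) -> torch.Tensor:
    # Drop the single pad column that made the odd (2d + 1) basis width even;
    # an odd width means no padding was applied, so the feature is returned whole.
    if basis is None:
        return feature
    out_dim = basis.shape[-1]
    out_dim += out_dim % 2 - 1  # even width w -> w - 1; odd width unchanged
    return feature[..., :out_dim]


basis = torch.randn(16, 9, 6)     # e.g. degree-2 output: true width 5, padded to 6
feature = torch.randn(16, 32, 6)
assert try_unpad(feature, basis).shape[-1] == 5
assert try_unpad(feature, None).shape[-1] == 6
```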

@@ -58,7 +58,7 @@ PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
-PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1,
+PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=20,
                    help='Do an evaluation round every N epochs')
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
                    help='Minimize stdout output')
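
The `--silent` flag above uses the `nargs='?'` / `const=True` argparse pattern, so it works both as a bare flag and with an explicit value. A minimal sketch of that behavior; the body of `str2bool` is not shown in this diff, so the one below is a plausible stand-in:

```python
import argparse


def str2bool(v):
    # Plausible stand-in for the repo's str2bool helper (not shown in this diff).
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected')


PARSER = argparse.ArgumentParser()
PARSER.add_argument('--eval_interval', type=int, default=20,
                    help='Do an evaluation round every N epochs')
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False)

args = PARSER.parse_args(['--silent'])           # bare flag -> const=True
assert args.silent is True and args.eval_interval == 20
args = PARSER.parse_args(['--silent', 'false'])  # explicit value goes through str2bool
assert args.silent is False
```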