[internal/CI] Add SE3T A100 tests
This commit is contained in:
parent
8c98f155e0
commit
67c7afa053
|
@ -328,7 +328,7 @@ The complete list of the available parameters for the `training.py` script conta
|
|||
- `--gradient_clip`: Clipping of the gradient norms (default: `None`)
|
||||
- `--accumulate_grad_batches`: Gradient accumulation (default: `1`)
|
||||
- `--ckpt_interval`: Save a checkpoint every N epochs (default: `-1`)
|
||||
- `--eval_interval`: Do an evaluation round every N epochs (default: `1`)
|
||||
- `--eval_interval`: Do an evaluation round every N epochs (default: `20`)
|
||||
- `--silent`: Minimize stdout output (default: `false`)
|
||||
|
||||
**Paths**
|
||||
|
@ -485,6 +485,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
|
|||
| 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
|
||||
|
||||
|
||||
|
||||
#### Training performance results
|
||||
|
||||
##### Training performance: NVIDIA DGX A100 (8x A100 80GB)
|
||||
|
@ -495,8 +496,8 @@ Our results were obtained by running the `scripts/benchmark_train.sh` and `scrip
|
|||
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
||||
| 1 | 240 | 2.21 | 2.92 | 1.32x | | |
|
||||
| 1 | 120 | 1.81 | 2.04 | 1.13x | | |
|
||||
| 8 | 240 | 17.15 | 22.95 | 1.34x | 7.76 | 7.86 |
|
||||
| 8 | 120 | 13.89 | 15.62 | 1.12x | 7.67 | 7.66 |
|
||||
| 8 | 240 | 15.88 | 21.02 | 1.32x | 7.18 | 7.20 |
|
||||
| 8 | 120 | 12.68 | 13.99 | 1.10x | 7.00 | 6.86 |
|
||||
|
||||
|
||||
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
||||
|
@ -510,8 +511,8 @@ Our results were obtained by running the `scripts/benchmark_train.sh` and `scrip
|
|||
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
||||
| 1 | 240 | 1.25 | 1.88 | 1.50x | | |
|
||||
| 1 | 120 | 1.03 | 1.41 | 1.37x | | |
|
||||
| 8 | 240 | 9.33 | 14.02 | 1.50x | 7.46 | 7.46 |
|
||||
| 8 | 120 | 7.39 | 9.41 | 1.27x | 7.17 | 6.67 |
|
||||
| 8 | 240 | 8.68 | 12.75 | 1.47x | 6.94 | 6.78 |
|
||||
| 8 | 120 | 6.64 | 8.58 | 1.29x | 6.44 | 6.08 |
|
||||
|
||||
|
||||
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
||||
|
@ -572,6 +573,15 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
|
|||
|
||||
### Changelog
|
||||
|
||||
October 2021:
|
||||
- Updated README performance tables
|
||||
- Fixed shape mismatch when using partially fused TFNs per output degree
|
||||
- Fixed shape mismatch when using partially fused TFNs per input degree with edge degrees > 0
|
||||
|
||||
September 2021:
|
||||
- Moved to new location (from `PyTorch/DrugDiscovery` to `DGLPyTorch/DrugDiscovery`)
|
||||
- Fixed multi-GPU training script
|
||||
|
||||
August 2021:
|
||||
- Initial release
|
||||
|
||||
|
|
|
@ -54,9 +54,8 @@ def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
|
|||
|
||||
def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
|
||||
all_degrees = list(range(2 * max_degree + 1))
|
||||
with nvtx_range('spherical harmonics'):
|
||||
sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
|
||||
return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
|
||||
sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
|
||||
return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
|
|
|
@ -178,5 +178,3 @@ class AttentionBlockSE3(nn.Module):
|
|||
value[degree] = feat
|
||||
|
||||
return key, value
|
||||
|
||||
|
||||
|
|
|
@ -152,12 +152,9 @@ class VersatileConvSE3(nn.Module):
|
|||
|
||||
if basis is not None:
|
||||
# This block performs the einsum n i l, n o i f, n l f k -> n o k
|
||||
out_dim = basis.shape[-1]
|
||||
if self.fuse_level != ConvSE3FuseLevel.FULL:
|
||||
out_dim += out_dim % 2 - 1 # Account for padded basis
|
||||
basis_view = basis.view(num_edges, in_dim, -1)
|
||||
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
|
||||
return (radial_weights @ tmp)[:, :, :out_dim]
|
||||
return radial_weights @ tmp
|
||||
else:
|
||||
# k = l = 0 non-fused case
|
||||
return radial_weights @ features
|
||||
|
@ -248,7 +245,8 @@ class ConvSE3(nn.Module):
|
|||
self.conv_in = nn.ModuleDict()
|
||||
for d_in, c_in in fiber_in:
|
||||
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
|
||||
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0],
|
||||
channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0)
|
||||
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0],
|
||||
fuse_level=self.used_fuse_level, **common_args)
|
||||
else:
|
||||
# Use pairwise TFN convolutions
|
||||
|
@ -268,6 +266,15 @@ class ConvSE3(nn.Module):
|
|||
self.to_kernel_self[str(degree_out)] = nn.Parameter(
|
||||
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
|
||||
|
||||
def _try_unpad(self, feature, basis):
|
||||
# Account for padded basis
|
||||
if basis is not None:
|
||||
out_dim = basis.shape[-1]
|
||||
out_dim += out_dim % 2 - 1
|
||||
return feature[..., :out_dim]
|
||||
else:
|
||||
return feature
|
||||
|
||||
def forward(
|
||||
self,
|
||||
node_feats: Dict[str, Tensor],
|
||||
|
@ -299,7 +306,10 @@ class ConvSE3(nn.Module):
|
|||
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
|
||||
in_features_fused = torch.cat(in_features, dim=-1)
|
||||
for degree_out in self.fiber_out.degrees:
|
||||
out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis[f'out{degree_out}_fused'])
|
||||
basis_used = basis[f'out{degree_out}_fused']
|
||||
out[str(degree_out)] = self._try_unpad(
|
||||
self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis_used),
|
||||
basis_used)
|
||||
|
||||
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
|
||||
out = 0
|
||||
|
@ -313,7 +323,10 @@ class ConvSE3(nn.Module):
|
|||
out_feature = 0
|
||||
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
||||
dict_key = f'{degree_in},{degree_out}'
|
||||
out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats, basis.get(dict_key, None))
|
||||
basis_used = basis.get(dict_key, None)
|
||||
out_feature = out_feature + self._try_unpad(
|
||||
self.conv[dict_key](feature, invariant_edge_feats, basis_used),
|
||||
basis_used)
|
||||
out[str(degree_out)] = out_feature
|
||||
|
||||
for degree_out in self.fiber_out.degrees:
|
||||
|
@ -330,5 +343,3 @@ class ConvSE3(nn.Module):
|
|||
else:
|
||||
out = dgl.ops.copy_e_sum(graph, out)
|
||||
return out
|
||||
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False
|
|||
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
|
||||
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
|
||||
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
|
||||
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1,
|
||||
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=20,
|
||||
help='Do an evaluation round every N epochs')
|
||||
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
|
||||
help='Minimize stdout output')
|
||||
|
|
Loading…
Reference in a new issue