Commit eb38b6c8 authored by Kira Selby's avatar Kira Selby

Set batch norm to always be in eval mode, regardless of whether .eval() is called

parent 439d2222
......@@ -233,7 +233,7 @@ class BatchNormFlow(nn.Module):
def forward(self, inputs, mode='direct'):
if mode == 'direct':
if self.training:
if True:#self.training:
self.batch_mean = inputs.mean(0)
self.batch_var = (
inputs - self.batch_mean).pow(2).mean(0) + self.eps
......@@ -257,7 +257,7 @@ class BatchNormFlow(nn.Module):
return y, (self.log_gamma - 0.5 * torch.log(var)).sum(
-1, keepdim=True)
else:
if self.training:
if True:#self.training:
mean = self.batch_mean
var = self.batch_var
else:
......
......@@ -189,7 +189,7 @@ def train_epoch(model, optim, train_loader, epoch, device, log_interval):
def validate(model, loader, device, prefix='Validation'):
#model.eval()
model.eval()
val_loss = 0
for data in loader:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment