Skip to content

Commit

Permalink
fix use_cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
truthless11 committed Jun 20, 2019
1 parent 11146f7 commit 505e728
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion convlab/modules/nlg/multiwoz/sc_lstm/loader/dataset_woz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch
from torch.autograd import Variable
from convlab.modules.nlg.multiwoz.sc_lstm.nlg_sc_lstm import USE_CUDA
from convlab.modules.nlg.multiwoz.utils import USE_CUDA


class DatasetWoz(object):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from convlab.modules.nlg.multiwoz.sc_lstm.nlg_sc_lstm import USE_CUDA
from convlab.modules.nlg.multiwoz.utils import USE_CUDA


class DecoderDeep(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion convlab/modules/nlg/multiwoz/sc_lstm/model/lm_deep.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from convlab.modules.nlg.multiwoz.sc_lstm.model.layers.decoder_deep import DecoderDeep
from convlab.modules.nlg.multiwoz.sc_lstm.model.masked_cross_entropy import masked_cross_entropy
from convlab.modules.nlg.multiwoz.sc_lstm.nlg_sc_lstm import USE_CUDA
from convlab.modules.nlg.multiwoz.utils import USE_CUDA


class LMDeep(nn.Module):
Expand Down
3 changes: 2 additions & 1 deletion convlab/modules/nlg/multiwoz/sc_lstm/nlg_sc_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from convlab.modules.nlg.multiwoz.sc_lstm.loader.dataset_woz import SimpleDatasetWoz
from convlab.modules.nlg.multiwoz.sc_lstm.model.lm_deep import LMDeep
from convlab.modules.nlg.nlg import NLG
from convlab.modules.nlg.multiwoz.utils import USE_CUDA

DEFAULT_DIRECTORY = "models"
DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "nlg-sclstm-multiwoz.zip")
USE_CUDA = -1

def parse(is_user):
if is_user:
Expand Down Expand Up @@ -63,6 +63,7 @@ def __init__(self,

self.args, self.config = parse(is_user)
self.dataset = SimpleDatasetWoz(self.config)
global USE_CUDA
USE_CUDA = use_cuda

# get model hyper-parameters
Expand Down
1 change: 1 addition & 0 deletions convlab/modules/nlg/multiwoz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np

USE_CUDA = -1

def initWeights(n,d):
""" Initialization Strategy """
Expand Down

0 comments on commit 505e728

Please sign in to comment.