영넌 개발로그

[밑시딥 3] 자연스러운 코드로 2 - 복잡한 계산 그래프 본문

코딩/ML , Deep

[밑시딥 3] 자연스러운 코드로 2 - 복잡한 계산 그래프

영넌 2023. 9. 6. 14:49

올바른 역전파의 경우 D,B,C,A 나 D,C,B,A 순서로 역전파가 일어나야 함

 

현재 구현한 코드는 위의 그래프의 역전파를 올바르게 해내지 못함

처리할 함수의 후보를 funcs 리스트의 끝에 추가하고 다음에 처리할 함수를 그 리스트에서 '마지막' 원소를 꺼내고 있음

처리 순서가 D,C,A,B,A 로 C다음 A가 바로 이어지며 A의 역전파가 두 번 일어나게 된다.

 

문제 해결을 위해 '우선순위'가 필요

1. 계산 그래프를 분석하여 위상 정렬 알고리즘 사용 (Topological Sort) 

   노드의 연결 방법을 기초로 노드 정렬 가능

2. 순전파 때 어떤 함수가 어떤 변수를 만들어내는지 알 수 있음 

   세대 기록 가능

class Variable:
    def __init__(self,data):
        if data is not None:   
            if not isinstance(data, np.ndarray):
                raise TypeError("{}은 지원하지 않습니다.".format(type(data)))
        
        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0 #세대 수를 기록하는 변수

    def set_creator(self, func):
        self.creator = func
        self.generation = func.generation + 1 #세대 기록 (부모 +1)
    
    ...

class Function:
    def __call__(self, *inputs):
        xs = [x.data for x in inputs]
        ys = self.forward(*xs) 	
        if not isinstance(ys, tuple):
            ys = (ys, )
        outputs = [Variable[as_array(y)) for y in ys]
        
        #입력 변수가 둘 이상이라면 가장 큰 generation 의 수 선택
        self.generation = max([x.generation for x in inputs])
        
        for output in outputs:
            output.set_creator(self)
        self.inputs = inputs
        self.outputs = outputs
        

        return outputs if len(outputs) > 1 else outputs[0]
   ...

class Variable:
    def __init__(self,data):
        if data is not None:   
            if not isinstance(data, np.ndarray):
                raise TypeError("{}은 지원하지 않습니다.".format(type(data)))
        
        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0 #세대 수를 기록하는 변수

    def set_creator(self, func):
        self.creator = func
        self.generation = func.generation + 1 #세대 기록 (부모 +1)

    def backward(self):
        if self.grad is None:   
            self.grad = np.ones_like(self.data)

        ####변경 부분
            
        funcs = []
        #같은 함수 중복 추가가 안되도록 set으로 설정
        seen_set = set()

        #함수 리스트를 세대 순으로 정렬하는 역할
        def add_func(f):
            if f not in seen_Set:
                funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda x: x.generation)
        add_func(self.creator)
        ####
        
        while funcs:
            f = funcs.pop()
            
            gys = [output.grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstace(gxs, tuple):
                gxs = (gxs, )

            for x, gx in zip(f.inputs, gxs):
                
                if x.grad is None:
                    x.grad = gx
                else:
                    x.grad = x.grad + gx
                    
                if x.creator is not None:
                    add_func(x.creator) #funcs.append(x.creator)

    def cleargrad(self):
        self.grad =None

 

Comments