영넌 개발로그
[밑시딥 3] 자연스러운 코드로 2 - 복잡한 계산 그래프 본문
올바른 역전파의 경우 D,B,C,A 나 D,C,B,A 순서로 역전파가 일어나야 함
현재 구현한 코드는 위의 그래프의 역전파를 올바르게 해내지 못함
처리할 함수의 후보를 funcs 리스트의 끝에 추가하고 다음에 처리할 함수를 그 리스트에서 '마지막' 원소를 꺼내고 있음
처리 순서가 D,C,A,B,A 로 C다음 A가 바로 이어지며 A의 역전파가 두 번 일어나게 된다.
문제 해결을 위해 '우선순위'가 필요
1. 계산 그래프를 분석하여 위상 정렬 알고리즘 사용 (Topological Sort)
노드의 연결 방법을 기초로 노드 정렬 가능
2. 순전파 때 어떤 함수가 어떤 변수를 만들어내는지 알 수 있음
세대 기록 가능
class Variable:
def __init__(self,data):
if data is not None:
if not isinstance(data, np.ndarray):
raise TypeError("{}은 지원하지 않습니다.".format(type(data)))
self.data = data
self.grad = None
self.creator = None
self.generation = 0 #세대 수를 기록하는 변수
def set_creator(self, func):
self.creator = func
self.generation = func.generation + 1 #세대 기록 (부모 +1)
...
class Function:
def __call__(self, *inputs):
xs = [x.data for x in inputs]
ys = self.forward(*xs)
if not isinstance(ys, tuple):
ys = (ys, )
outputs = [Variable[as_array(y)) for y in ys]
#입력 변수가 둘 이상이라면 가장 큰 generation 의 수 선택
self.generation = max([x.generation for x in inputs])
for output in outputs:
output.set_creator(self)
self.inputs = inputs
self.outputs = outputs
return outputs if len(outputs) > 1 else outputs[0]
...
class Variable:
def __init__(self,data):
if data is not None:
if not isinstance(data, np.ndarray):
raise TypeError("{}은 지원하지 않습니다.".format(type(data)))
self.data = data
self.grad = None
self.creator = None
self.generation = 0 #세대 수를 기록하는 변수
def set_creator(self, func):
self.creator = func
self.generation = func.generation + 1 #세대 기록 (부모 +1)
def backward(self):
if self.grad is None:
self.grad = np.ones_like(self.data)
####변경 부분
funcs = []
#같은 함수 중복 추가가 안되도록 set으로 설정
seen_set = set()
#함수 리스트를 세대 순으로 정렬하는 역할
def add_func(f):
if f not in seen_Set:
funcs.append(f)
seen_set.add(f)
funcs.sort(key=lambda x: x.generation)
add_func(self.creator)
####
while funcs:
f = funcs.pop()
gys = [output.grad for output in f.outputs]
gxs = f.backward(*gys)
if not isinstace(gxs, tuple):
gxs = (gxs, )
for x, gx in zip(f.inputs, gxs):
if x.grad is None:
x.grad = gx
else:
x.grad = x.grad + gx
if x.creator is not None:
add_func(x.creator) #funcs.append(x.creator)
def cleargrad(self):
self.grad =None
'코딩 > ML , Deep' 카테고리의 다른 글
[밑시딥3] 자연스러운 코드로 4 - 변수 사용성 개선, 연산자 오버로드 (0) | 2023.09.06 |
---|---|
[밑시딥3] 자연스러운 코드로 3 - 메모리 관리, 절약 모드 (0) | 2023.09.06 |
[밑시딥3] 자연스러운 코드로 1 - 가변 길이 인수, 같은 변수 반복 사용 (0) | 2023.09.06 |
[밑시딥3] 미분 자동 계산 4 - 테스트 (0) | 2023.09.05 |
Comments